#include <algorithm>
#include <iostream>
#include <vector>

#include "Utils.hpp"

using namespace std;

// Sorts the closed index range [left, right] of u in ascending order with a
// top-down merge sort. Only elements inside [left, right] are touched.
//
// Sub-ranges whose span (right - left) is below kLeafSpan are handed directly
// to std::sort instead of being split further, so the recursion stops well
// before single elements. If left >= right the call is a no-op.
//
// u     - vector sorted in place.
// left  - index of the first element of the range (inclusive).
// right - index of the last element of the range (inclusive).
void mergeSortSequentialRecursive(vector<double> &u, int left, int right) {
    // Spans smaller than this are sorted directly rather than recursed into.
    constexpr int kLeafSpan = 32;

    if (left >= right) {
        return;  // Zero or one element: already sorted.
    }
    if (right - left < kLeafSpan) {
        std::sort(u.begin() + left, u.begin() + right + 1);
        return;
    }

    // Written as left + (right - left) / 2 rather than (left + right) / 2:
    // the sum can overflow int for ranges near the top of the int index space.
    const int mid = left + (right - left) / 2;
    mergeSortSequentialRecursive(u, left, mid);
    mergeSortSequentialRecursive(u, mid + 1, right);
    // Both halves [left, mid] and [mid + 1, right] are now sorted; merge them.
    std::inplace_merge(u.begin() + left, u.begin() + mid + 1, u.begin() + right + 1);
}

// Sorts all of u in ascending order using the sequential merge sort.
// Indices are int, so vectors with more than INT_MAX elements are not supported.
void mergeSortSequential(vector<double> &u) {
    // Nothing to do for 0 or 1 elements. Returning early also avoids evaluating
    // u.size() - 1 on an empty vector, which wraps to SIZE_MAX before being
    // narrowed to int.
    if (u.size() < 2) {
        return;
    }
    mergeSortSequentialRecursive(u, 0, static_cast<int>(u.size() - 1));
}

// Sorts the closed index range [left, right] of u in ascending order with a
// merge sort. For large ranges, the two halves are sorted by OpenMP tasks.
// Only elements inside [left, right] are touched; left >= right is a no-op.
//
// Intended to be called from inside an OpenMP parallel region (as
// mergeSortParallel does) so that idle team threads can execute the tasks.
// If OpenMP is not enabled, the pragmas are ignored and the sort runs
// sequentially.
//
// u     - vector sorted in place; shared by all tasks, which write disjoint ranges.
// left  - index of the first element of the range (inclusive).
// right - index of the last element of the range (inclusive).
void mergeSortParallelRecursive(vector<double> &u, int left, int right) {
    // Spans smaller than this are sorted directly with std::sort.
    constexpr int kLeafSpan = 32;
    // Spans smaller than this recurse without creating tasks. Spawning tasks
    // all the way down to kLeafSpan creates on the order of 10^5 tiny tasks for
    // a 5M-element input, and their scheduling overhead likely dominates the
    // work each one does. This is a tuning heuristic, not a correctness bound.
    constexpr int kTaskSpan = 1 << 14;

    if (left >= right) {
        return;  // Zero or one element: already sorted.
    }
    if (right - left < kLeafSpan) {
        std::sort(u.begin() + left, u.begin() + right + 1);
        return;
    }

    // Overflow-safe midpoint; (left + right) / 2 can overflow int.
    const int mid = left + (right - left) / 2;

    if (right - left >= kTaskSpan) {
        // The taskgroup waits for both child tasks and all their descendants,
        // so both halves are fully sorted before the merge below.
        #pragma omp taskgroup
        {
            #pragma omp task shared(u)
            mergeSortParallelRecursive(u, left, mid);

            #pragma omp task shared(u)
            mergeSortParallelRecursive(u, mid + 1, right);
        }
    } else {
        mergeSortParallelRecursive(u, left, mid);
        mergeSortParallelRecursive(u, mid + 1, right);
    }

    std::inplace_merge(u.begin() + left, u.begin() + mid + 1, u.begin() + right + 1);
}

// Sorts all of u in ascending order using OpenMP tasks.
//
// A single thread of the parallel team starts the recursion (omp single); the
// other threads of the team can execute the tasks it spawns. The barriers at
// the end of the single and parallel constructs ensure the sort has completed
// before this function returns. Indices are int, so vectors with more than
// INT_MAX elements are not supported.
void mergeSortParallel(vector<double> &u) {
    // Nothing to sort. Returning early also avoids starting a thread team for
    // trivial input, and avoids evaluating u.size() - 1 on an empty vector
    // (which wraps to SIZE_MAX before being narrowed to int).
    if (u.size() < 2) {
        return;
    }
    const int last = static_cast<int>(u.size() - 1);

    #pragma omp parallel
    {
        #pragma omp single
        {
            mergeSortParallelRecursive(u, 0, last);
        }
    }
}

int main() {
    vector<double> u = generateRandomVector(5000000);

    {
        auto uCopy = u;
        Stopwatch sw;
        sw.start();
        mergeSortSequential(uCopy);
        sw.stop();
        cout << "Sequential merge sort: " << sw.duration().count()
             << " ms, sorted correctly " << is_sorted(uCopy.begin(), uCopy.end()) << endl;
    }

    {
        auto uCopy = u;
        Stopwatch sw;
        sw.start();
        mergeSortParallel(uCopy);
        sw.stop();
        cout << "Parallel merge sort: " << sw.duration().count()
             << " ms, sorted correctly " << is_sorted(uCopy.begin(), uCopy.end()) << endl;
    }

    return 0;
}

