#include <iostream>
#include <vector>

#include "Utils.hpp"

using namespace std;

// Returns the sum of every element of `matrix`, computed on a single thread.
//
// Rows are visited in order and each row is read front to back, so the
// additions happen in the same order on every call. Rows may have different
// lengths; an empty matrix (or a matrix of empty rows) sums to 0.0.
double matrixSumSequential(const vector<vector<double>> &matrix) {
    double sum = 0.0;
    // Range-based loops avoid comparing a signed index with the unsigned size().
    for (const vector<double> &row : matrix) {
        for (const double value : row) {
            sum += value;
        }
    }

    return sum;
}

// Returns the sum of every element of `matrix` in two single-threaded passes:
// first one partial sum per row is stored in a temporary vector, then those
// partial sums are added together.
//
// Rows may have different lengths; an empty matrix sums to 0.0.
double matrixSumSequentialRowSums(const vector<vector<double>> &matrix) {
    vector<double> rowSums(matrix.size(), 0.0);
    for (size_t i = 0; i < matrix.size(); i++) {
        // Accumulate in a local and store once per row instead of updating
        // rowSums[i] on every addition.
        double rowSum = 0.0;
        for (const double value : matrix[i]) {
            rowSum += value;
        }
        rowSums[i] = rowSum;
    }

    double sum = 0.0;
    for (const double rowSum : rowSums) {
        sum += rowSum;
    }

    return sum;
}

// Sums every element of `matrix` with the rows shared out by an OpenMP
// `parallel for`.
//
// All threads add into the same accumulator, so each individual addition is
// placed in a critical section that lets only one thread update it at a time.
double matrixSumParallelForAndCritical(const vector<vector<double>> &matrix) {
    const int rowCount = matrix.size();

    double total = 0.0;
    #pragma omp parallel for
    for (int row = 0; row < rowCount; row++) {
        const vector<double> &values = matrix[row];
        for (const double value : values) {
            // Guard the shared accumulator; one critical section per element.
            #pragma omp critical
            total += value;
        }
    }

    return total;
}

// Sums every element of `matrix` in two OpenMP-parallel passes.
//
// Pass 1 distributes the rows over threads and stores one partial sum per row;
// each row's slot is written by exactly one loop iteration, so no
// synchronisation is needed there. Pass 2 adds the per-row sums together with
// a `+` reduction on `sum`.
//
// Rows may have different lengths; an empty matrix sums to 0.0.
double matrixSumParallelForAndReduction(const vector<vector<double>> &matrix) {
    // Signed loop bound computed once, instead of comparing an int index with
    // the unsigned size() on every iteration.
    const int numRows = static_cast<int>(matrix.size());

    vector<double> rowSums(matrix.size(), 0.0);
    #pragma omp parallel for
    for (int i = 0; i < numRows; i++) {
        // Accumulate in a thread-local variable and store once per row, so
        // threads do not keep writing to neighbouring rowSums slots (which may
        // sit on the same cache line) from inside the inner loop.
        double rowSum = 0.0;
        for (const double value : matrix[i]) {
            rowSum += value;
        }
        rowSums[i] = rowSum;
    }

    double sum = 0.0;
    #pragma omp parallel for reduction(+:sum)
    for (int i = 0; i < numRows; i++) {
        sum += rowSums[i];
    }

    return sum;
}

// Sums every element of `matrix` with a single OpenMP loop nest that is
// collapsed over rows and columns and combined with a `+` reduction on `sum`.
//
// The collapsed nest uses one column count (taken from the first row) for all
// rows, so it is only used when every row has that length. For an empty matrix
// the result is 0.0; for a matrix whose rows differ in length the function
// falls back to a per-row parallel reduction so that no element is skipped and
// no row is read past its end.
double matrixSumParallelForAndCollapseAndReduction(const vector<vector<double>> &matrix) {
    // matrix[0] does not exist for an empty matrix.
    if (matrix.empty()) {
        return 0.0;
    }

    const int numRows = static_cast<int>(matrix.size());
    const int numCols = static_cast<int>(matrix[0].size());

    // One cheap pass over the row headers to decide whether the shared column
    // bound is valid for every row.
    bool rectangular = true;
    for (const vector<double> &row : matrix) {
        if (row.size() != matrix[0].size()) {
            rectangular = false;
            break;
        }
    }

    double sum = 0.0;
    if (rectangular) {
        #pragma omp parallel for collapse(2) reduction(+:sum)
        for (int i = 0; i < numRows; i++) {
            for (int j = 0; j < numCols; j++) {
                sum += matrix[i][j];
            }
        }
    } else {
        // Ragged input: parallelise over rows only and let each row use its
        // own length.
        #pragma omp parallel for reduction(+:sum)
        for (int i = 0; i < numRows; i++) {
            for (const double value : matrix[i]) {
                sum += value;
            }
        }
    }

    return sum;
}

// Runs `sumFn` once on `matrix`, timing just that call with a Stopwatch, and
// prints one line of the form "<label>: <duration> ms, sum <result>" to
// standard output.
static void timeAndReport(const char *label,
                          double (*sumFn)(const vector<vector<double>> &),
                          const vector<vector<double>> &matrix) {
    Stopwatch sw;
    sw.start();
    const double sum = sumFn(matrix);
    sw.stop();
    cout << label << ": " << sw.duration().count() << " ms, sum " << sum << endl;
}

// Generates one matrix and runs each summation variant on it in turn,
// reporting the measured duration and the computed sum for every variant so
// the results can be compared against each other.
int main() {
    // The sum of this matrix is 5.50037e+08.
    vector<vector<double>> matrix = generateRandomMatrix(10000, 10000, 42);

    timeAndReport("Sequential", matrixSumSequential, matrix);
    timeAndReport("Sequential using row sums", matrixSumSequentialRowSums, matrix);
    timeAndReport("Parallel for + critical section", matrixSumParallelForAndCritical, matrix);
    timeAndReport("Parallel for + reduction", matrixSumParallelForAndReduction, matrix);
    timeAndReport("Parallel for + collapse + reduction", matrixSumParallelForAndCollapseAndReduction, matrix);

    return 0;
}
