#include <chrono>
#include <cmath>
#include <iostream>

#include <omp.h>

using namespace std;
using namespace std::chrono;

// Naive sequential matrix product for square size x size matrices.
// Each C[i][j] is overwritten with the dot product of row i of A and
// column j of B; previous contents of C are ignored.
void matrix_mul(float **A, float **B, float **C, int size) {
    for (int row = 0; row < size; ++row) {
        const float *lhs_row = A[row];
        float *out_row = C[row];
        for (int col = 0; col < size; ++col) {
            float dot = 0.0f;
            int k = 0;
            // B is walked down one column, i.e. across its row pointers.
            while (k < size) {
                dot += lhs_row[k] * B[k][col];
                ++k;
            }
            out_row[col] = dot;
        }
    }
}

// Naive matrix product for square size x size matrices, parallelized with OpenMP.
// Each C[i][j] is overwritten with the dot product of row i of A and
// column j of B.
void matrix_mul_parallel(float **A, float **B, float **C, int size) {
    // collapse(2) fuses the row and column loops into one iteration space of
    // size * size output elements that is divided among the threads. The two
    // loops must stay perfectly nested for the clause to apply.
    #pragma omp parallel for collapse(2) shared(A, B, C)
    for (int row = 0; row < size; ++row) {
        for (int col = 0; col < size; ++col) {
            const float *lhs_row = A[row];
            float dot = 0.0f;
            for (int k = 0; k != size; ++k) {
                dot += lhs_row[k] * B[k][col];
            }
            C[row][col] = dot;
        }
    }
}

// Sequential matrix product that takes the right-hand operand transposed.
// BT[j][k] is read where the plain version reads B[k][j], so both operands
// are traversed along their rows. Each C[i][j] is overwritten.
void matrix_mul_transposed(float **A, float **BT, float **C, int size) {
    for (int row = 0; row < size; ++row) {
        const float *lhs_row = A[row];
        float *out_row = C[row];
        for (int col = 0; col < size; ++col) {
            // Row `col` of BT holds what would be column `col` of B.
            const float *rhs_row = BT[col];
            float dot = 0.0f;
            int k = 0;
            while (k < size) {
                dot += lhs_row[k] * rhs_row[k];
                ++k;
            }
            out_row[col] = dot;
        }
    }
}

// OpenMP-parallel matrix product that takes the right-hand operand transposed.
// BT[j][k] is read where the plain version reads B[k][j]. Each C[i][j] is
// overwritten.
void matrix_mul_parallel_transposed(float **A, float **BT, float **C, int size) {
    // collapse(2) fuses the row and column loops into one iteration space that
    // is divided among the threads; the loops must stay perfectly nested.
    #pragma omp parallel for collapse(2) shared(A, BT, C)
    for (int row = 0; row < size; ++row) {
        for (int col = 0; col < size; ++col) {
            const float *lhs_row = A[row];
            const float *rhs_row = BT[col];
            float dot = 0.0f;
            for (int k = 0; k != size; ++k) {
                dot += lhs_row[k] * rhs_row[k];
            }
            C[row][col] = dot;
        }
    }
}

// Blocked (tiled) matrix multiplication, sequential.
// Adds A * B to C, where BT holds B transposed (BT[j][k] is used as B[k][j]).
// Because the result is accumulated, C must be zero-filled by the caller to
// obtain the plain product.
//   A, BT, C    size x size matrices given as arrays of row pointers
//   size        matrix dimension; nothing is done when it is not positive
//   block_size  tile edge length; a value that is not positive or that exceeds
//               size is replaced by size (a single tile spanning the matrix)
void block_matrix_mul(float **A, float **BT, float **C, int size, int block_size) {
    if (size <= 0) {
        return;
    }
    // A step of zero or less would never advance the tile loops below.
    if (block_size <= 0 || block_size > size) {
        block_size = size;
    }

    // Iterate over the tiles of C
    for (int i = 0; i < size; i += block_size) {
        // Clip the last tile in each dimension to the matrix edge
        const int i_block_end = (size - i < block_size) ? size : i + block_size;
        for (int j = 0; j < size; j += block_size) {
            const int j_block_end = (size - j < block_size) ? size : j + block_size;
            // Walk the matching tiles of A and BT along the shared dimension
            for (int k = 0; k < size; k += block_size) {
                const int k_block_end = (size - k < block_size) ? size : k + block_size;

                // Multiply tile A[i..][k..] by tile BT[j..][k..] and add to C[i..][j..]
                for (int i_block = i; i_block < i_block_end; i_block++) {
                    for (int j_block = j; j_block < j_block_end; j_block++) {
                        float sum = 0.0f;
                        for (int k_block = k; k_block < k_block_end; k_block++) {
                            sum += A[i_block][k_block] * BT[j_block][k_block];
                        }
                        // Partial dot product of this k tile only, hence +=
                        C[i_block][j_block] += sum;
                    }
                }
            }
        }
    }
}

// Blocked (tiled) matrix multiplication, parallelized with OpenMP.
// Adds A * B to C, where BT holds B transposed (BT[j][k] is used as B[k][j]).
// Because the result is accumulated, C must be zero-filled by the caller to
// obtain the plain product.
//   A, BT, C    size x size matrices given as arrays of row pointers
//   size        matrix dimension; nothing is done when it is not positive
//   block_size  tile edge length; a value that is not positive or that exceeds
//               size is replaced by size (a single tile, so no work is shared)
void block_matrix_mul_parallel(float **A, float **BT, float **C, int size, int block_size) {
    if (size <= 0) {
        return;
    }
    // A step of zero or less would never advance the tile loops below.
    if (block_size <= 0 || block_size > size) {
        block_size = size;
    }

    // The (i, j) tile loops are collapsed and shared among threads. Every
    // tile of C is produced by exactly one iteration, so the += below needs
    // no synchronization. The two loops must stay perfectly nested.
    #pragma omp parallel for collapse(2)
    for (int i = 0; i < size; i += block_size) {
        for (int j = 0; j < size; j += block_size) {
            // Clip the last tile in each dimension to the matrix edge
            const int i_block_end = (size - i < block_size) ? size : i + block_size;
            const int j_block_end = (size - j < block_size) ? size : j + block_size;
            for (int k = 0; k < size; k += block_size) {
                const int k_block_end = (size - k < block_size) ? size : k + block_size;

                for (int i_block = i; i_block < i_block_end; i_block++) {
                    for (int j_block = j; j_block < j_block_end; j_block++) {
                        float sum = 0.0f;
                        for (int k_block = k; k_block < k_block_end; k_block++) {
                            sum += A[i_block][k_block] * BT[j_block][k_block];
                        }
                        // Partial dot product of this k tile only, hence +=
                        C[i_block][j_block] += sum;
                    }
                }
            }
        }
    }
}

// Blocked (tiled) matrix multiplication, sequential, with the tile loops
// ordered i, k, j so that one tile of A is reused against every tile of BT
// before moving on.
// Adds A * B to C, where BT holds B transposed (BT[j][k] is used as B[k][j]).
// Because the result is accumulated, C must be zero-filled by the caller to
// obtain the plain product.
//   A, BT, C    size x size matrices given as arrays of row pointers
//   size        matrix dimension; nothing is done when it is not positive
//   block_size  tile edge length; a value that is not positive or that exceeds
//               size is replaced by size (a single tile spanning the matrix)
void block_A_shared_matrix_mul(float **A, float **BT, float **C, int size, int block_size) {
    if (size <= 0) {
        return;
    }
    // A step of zero or less would never advance the tile loops below.
    if (block_size <= 0 || block_size > size) {
        block_size = size;
    }

    // Iterate over the tiles of A
    for (int i = 0; i < size; i += block_size) {
        // Clip the last tile in each dimension to the matrix edge
        const int i_block_end = (size - i < block_size) ? size : i + block_size;
        for (int k = 0; k < size; k += block_size) {
            const int k_block_end = (size - k < block_size) ? size : k + block_size;
            // The A tile (i, k) stays fixed while j sweeps the tiles of BT
            for (int j = 0; j < size; j += block_size) {
                const int j_block_end = (size - j < block_size) ? size : j + block_size;

                for (int i_block = i; i_block < i_block_end; i_block++) {
                    for (int j_block = j; j_block < j_block_end; j_block++) {
                        float sum = 0.0f;
                        for (int k_block = k; k_block < k_block_end; k_block++) {
                            sum += A[i_block][k_block] * BT[j_block][k_block];
                        }
                        // Partial dot product of this k tile only, hence +=
                        C[i_block][j_block] += sum;
                    }
                }
            }
        }
    }
}

// Blocked (tiled) matrix multiplication, parallelized with OpenMP, with the
// tile loops ordered i, k, j so that one tile of A is reused against every
// tile of BT before moving on.
// Adds A * B to C, where BT holds B transposed (BT[j][k] is used as B[k][j]).
// Because the result is accumulated, C must be zero-filled by the caller to
// obtain the plain product.
//   A, BT, C    size x size matrices given as arrays of row pointers
//   size        matrix dimension; nothing is done when it is not positive
//   block_size  tile edge length; a value that is not positive or that exceeds
//               size is replaced by size (a single tile, so no work is shared)
void block_A_shared_matrix_mul_parallel(float **A, float **BT, float **C, int size, int block_size) {
    if (size <= 0) {
        return;
    }
    // A step of zero or less would never advance the tile loops below.
    if (block_size <= 0 || block_size > size) {
        block_size = size;
    }

    // Only the outermost (i) tile loop is shared among threads. Each
    // iteration writes to its own band of rows of C, so the += below needs
    // no synchronization.
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < size; i += block_size) {
        // Clip the last tile in each dimension to the matrix edge
        const int i_block_end = (size - i < block_size) ? size : i + block_size;
        for (int k = 0; k < size; k += block_size) {
            const int k_block_end = (size - k < block_size) ? size : k + block_size;
            // The A tile (i, k) stays fixed while j sweeps the tiles of BT
            for (int j = 0; j < size; j += block_size) {
                const int j_block_end = (size - j < block_size) ? size : j + block_size;

                for (int i_block = i; i_block < i_block_end; i_block++) {
                    for (int j_block = j; j_block < j_block_end; j_block++) {
                        float sum = 0.0f;
                        for (int k_block = k; k_block < k_block_end; k_block++) {
                            sum += A[i_block][k_block] * BT[j_block][k_block];
                        }
                        // Partial dot product of this k tile only, hence +=
                        C[i_block][j_block] += sum;
                    }
                }
            }
        }
    }
}

// Allocates a size x size matrix as an array of `size` row pointers, each
// pointing to its own array of `size` floats, with every element set to 0.0.
// Both levels are allocated with new[]; the caller owns the result and must
// delete[] every row and then the row-pointer array.
float** create_empty_matrix(int size) {
    float **m = new float *[size];
    for (int i = 0; i < size; i++) {
        // The trailing () value-initializes the row; a plain new float[size]
        // would leave the elements indeterminate.
        m[i] = new float[size]();
    }
    return m;
}

// Builds a size x size matrix via create_empty_matrix and fills it, row by
// row, with rand() / RAND_MAX, i.e. values in the range [0, 1].
// The generator is not seeded here.
float** create_random_matrix(int size) {
    float **result = create_empty_matrix(size);
    const float scale = static_cast<float>(RAND_MAX);
    for (int row = 0; row < size; ++row) {
        float *values = result[row];
        for (int col = 0; col < size; ++col) {
            values[col] = static_cast<float>(rand()) / scale;
        }
    }
    return result;
}

// Returns a newly allocated size x size matrix (from create_empty_matrix)
// holding the transpose of `matrix`: result[i][j] == matrix[j][i].
// The input matrix is only read.
float** create_transposed_matrix(float **matrix, int size) {
    float **transposed = create_empty_matrix(size);
    int row = 0;
    while (row < size) {
        float *out_row = transposed[row];
        for (int col = 0; col < size; ++col) {
            // Swap the index roles: output row <- input column.
            out_row[col] = matrix[col][row];
        }
        ++row;
    }
    return transposed;
}

// Overwrites every element of the size x size matrix with 0.0.
// Rows are distributed among OpenMP threads; each row is cleared by one thread.
void reset_matrix(float **matrix, int size) {
    #pragma omp parallel for
    for (int row = 0; row < size; ++row) {
        float *cursor = matrix[row];
        float *const row_end = cursor + size;
        while (cursor != row_end) {
            *cursor = 0.0f;
            ++cursor;
        }
    }
}

// Writes the size x size matrix to standard output, one row per line.
// Every element is followed by a single space, and each line is ended
// (and flushed) with endl.
void print_matrix(float **matrix, int size) {
    int row = 0;
    while (row < size) {
        const float *values = matrix[row];
        for (int col = 0; col != size; ++col) {
            std::cout << values[col] << " ";
        }
        std::cout << std::endl;
        ++row;
    }
}

// Compares two size x size matrices element by element.
// Returns true when every pair of elements differs by at most 0.01 in
// absolute value. On the first mismatch the indices and the two values are
// printed to standard output and false is returned.
// A difference that is NaN (for example when either element is NaN) counts
// as a mismatch.
bool matrix_equal(float **a, float **b, int size) {
    const double tolerance = 0.01;
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            // std::fabs is used explicitly: an unqualified abs() can bind to
            // the integer overload and truncate the difference.
            const float diff = std::fabs(a[i][j] - b[i][j]);
            // Negated <= so that a NaN difference fails the check as well.
            if (!(diff <= tolerance)) {
                std::cout << i << " " << j << '\n';
                std::cout << a[i][j] << " " << b[i][j] << '\n';
                return false;
            }
        }
    }
    return true;
}

int main(int argc, char **argv) {
    int size, block_size;
    cout << "Enter size of Matrix: ";
    cin >> size;
    cout << "Enter size of block: ";
    cin >> block_size;

    cout << "Init has begun" << endl;
    float **A = create_random_matrix(size);
    float **B = create_random_matrix(size);
    float **BT = create_transposed_matrix(B, size);
    float **C = create_empty_matrix(size);

    float **S = create_empty_matrix(size);

    // function to run algorithm approach, measure execution time and print it
    auto measure_approach = [](auto approach_function, string approach_name, auto... function_args) {
        cout << approach_name << ": ";

        high_resolution_clock::time_point t1 = high_resolution_clock::now();

        approach_function(function_args...);

        high_resolution_clock::time_point t2 = high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
        cout << static_cast <float> (duration) / 1000000.0f << " seconds" << endl;

    };

    #pragma omp parallel
    {
        int tid = omp_get_thread_num();
        if (tid == 0)
            cout << "Number of available threads: " << omp_get_num_threads() << endl;
    }

    for (int it = 0; it < 1; it++) {
        measure_approach(matrix_mul, "Sequential", A, B, S, size);
        measure_approach(matrix_mul_parallel, "Parallel", A, B, C, size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;

        measure_approach(matrix_mul_transposed, "Sequential transposed", A, BT, C, size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;
        measure_approach(matrix_mul_parallel_transposed, "Parallel transposed", A, BT, C, size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;

        reset_matrix(C, size); // uncomment when you interested in right result, implementation sum in C
        measure_approach(block_matrix_mul, "Block Sequential (transposed)", A, BT, C, size, block_size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;
        reset_matrix(C, size); // uncomment when you interested in right result, implementation sum in C
        measure_approach(block_matrix_mul_parallel, "Block parallel (transposed)", A, BT, C, size, block_size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;

        reset_matrix(C, size); // uncomment when you interested in right result, implementation sum in C
        measure_approach(block_A_shared_matrix_mul, "Block A shared Sequential (transposed)", A, BT, C, size, block_size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;
        reset_matrix(C, size); // uncomment when you interested in right result, implementation sum in C
        measure_approach(block_A_shared_matrix_mul_parallel, "Block A shared parallel (transposed)", A, BT, C, size, block_size);
        cout << "Correct " << matrix_equal(S, C, size) << endl;
    }
    return 0;
}
