#include <iostream>
#include <vector>
#include <cstdlib>
#include <ctime>
#include <mpi.h>

using namespace std;

// Fills the top-left n x n region of `matrix` with random whole numbers
// in the range 0..9 (stored as float).
//
// matrix: destination. When it has fewer than n rows, or one of its first n
//         rows has fewer than n columns, it is grown so the writes stay in
//         bounds. Elements outside the n x n region are left untouched.
// n:      matrix dimension. Values <= 0 make the call a no-op.
//
// The generator is seeded from the clock on the first call only; re-seeding
// on every call would make two calls within the same second yield the same
// matrix.
void generateMatrix(std::vector<std::vector<float>>& matrix, int n) {
    if (n <= 0) {
        return;
    }

    static bool seeded = false;
    if (!seeded) {
        std::srand(static_cast<unsigned int>(std::time(nullptr)));
        seeded = true;
    }

    const std::size_t dim = static_cast<std::size_t>(n);
    if (matrix.size() < dim) {
        matrix.resize(dim);
    }

    for (std::size_t i = 0; i < dim; ++i) {
        std::vector<float>& row = matrix[i];
        if (row.size() < dim) {
            row.resize(dim);
        }
        for (std::size_t j = 0; j < dim; ++j) {
            row[j] = static_cast<float>(std::rand() % 10);  // whole number in 0..9
        }
    }
}

// Writes the top-left n x n region of `matrix` to standard output.
// Each element is followed by a single space and each row ends with a
// flushed newline. Works for any element type T that can be streamed to
// std::cout. Nothing is printed when n <= 0.
template <typename T>
void printMatrix(const std::vector<std::vector<T>>& matrix, int n) {
    int row = 0;
    while (row < n) {
        const std::vector<T>& cells = matrix[row];
        for (int col = 0; col < n; ++col) {
            std::cout << cells[col] << " ";
        }
        std::cout << std::endl;
        ++row;
    }
}

// Prompts for the matrix size on standard input and validates it.
// Intended to be called on the root rank only.
//
// size: number of MPI processes the rows will be split across.
// Returns the size n when it is a positive integer divisible by `size`;
// otherwise prints a message to stderr and returns 0, which the caller
// broadcasts as a "stop" signal to every rank.
static int readMatrixSize(int size) {
    std::cout << "Assume that number of rows is divisible by number of processes" << std::endl;
    std::cout << "Enter the size of the matrix: ";

    int n = 0;
    if (!(std::cin >> n)) {
        std::cerr << "Error: the matrix size must be an integer" << std::endl;
        return 0;
    }
    if (n <= 0) {
        std::cerr << "Error: the matrix size must be positive" << std::endl;
        return 0;
    }
    if (n % size != 0) {
        // An uneven split would leave trailing rows undistributed and would
        // later name a broadcast root that does not exist.
        std::cerr << "Error: the matrix size (" << n
                  << ") must be divisible by the number of processes ("
                  << size << ")" << std::endl;
        return 0;
    }
    return n;
}

// Applies one elimination step to rows [first_row, row_count) of a row-major
// chunk: each row has (its element in column `col`) * pivot_row subtracted.
//
// local_rows: row-major chunk with rows of length n, updated in place.
// pivot_row:  pivot row of length n, as broadcast for this step.
static void eliminateRows(std::vector<float>& local_rows,
                          const std::vector<float>& pivot_row,
                          int first_row, int row_count, int n, int col) {
    for (int r = first_row; r < row_count; ++r) {
        const std::size_t base = static_cast<std::size_t>(r) * n;
        const float factor = local_rows[base + col];
        for (int k = 0; k < n; ++k) {
            local_rows[base + k] -= factor * pivot_row[k];
        }
    }
}

// Entry point: row-block-distributed forward elimination over MPI.
//
// Rank 0 reads the matrix size, generates and prints a random matrix,
// scatters contiguous blocks of rows to all ranks, and finally gathers and
// prints the eliminated matrix. Returns 0 on success and EXIT_FAILURE when
// the entered size is rejected.
//
// NOTE: there is no row exchange. When a pivot element is exactly 0 the pivot
// row is broadcast without normalization and column i is not cleared in the
// rows below. The normalized pivot row is only used for elimination; the
// owner's stored copy of that row keeps its original scale.
int main(int argc, char** argv) {
    MPI_Init(&argc, &argv);  // Initialize MPI environment

    int rank = 0;
    int size = 1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    // Only the root handles user input; 0 means "invalid input, stop".
    int n = 0;
    if (rank == 0) {
        n = readMatrixSize(size);
    }

    // Every rank learns the size (or the stop signal) from the root, so all
    // of them leave together instead of blocking in a later collective.
    MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
    if (n <= 0) {
        MPI_Finalize();
        return EXIT_FAILURE;
    }

    const std::size_t dim = static_cast<std::size_t>(n);
    const int rows_per_process = n / size;      // exact: n % size == 0 was validated
    const int chunk_len = n * rows_per_process;  // element count per rank

    // The full matrix exists only on the root; the send buffer of
    // MPI_Scatter is not read on the other ranks.
    std::vector<float> flat_matrix;
    if (rank == 0) {
        std::vector<std::vector<float>> matrix(dim, std::vector<float>(dim));
        generateMatrix(matrix, n);

        std::cout << "Original matrix:\n";
        printMatrix(matrix, n);

        // Flatten row by row into the layout MPI_Scatter expects.
        flat_matrix.reserve(dim * dim);
        for (const std::vector<float>& row : matrix) {
            flat_matrix.insert(flat_matrix.end(), row.begin(), row.end());
        }
    }

    // Each rank receives rows_per_process consecutive rows.
    std::vector<float> local_rows(static_cast<std::size_t>(chunk_len));
    MPI_Scatter(flat_matrix.data(), chunk_len, MPI_FLOAT,
                local_rows.data(), chunk_len, MPI_FLOAT, 0, MPI_COMM_WORLD);

    // Buffer that carries the pivot row of the current step.
    std::vector<float> pivot_row(dim);

    for (int i = 0; i < n; ++i) {
        // Global row i lives on rank `owner` at local index `local_pivot`.
        const int owner = i / rows_per_process;
        const int local_pivot = i % rows_per_process;

        if (rank == owner) {
            const std::size_t offset = static_cast<std::size_t>(local_pivot) * n;
            pivot_row.assign(local_rows.begin() + offset,
                             local_rows.begin() + offset + n);

            // Scale the copy so its i-th element becomes 1 (skipped for a zero pivot).
            const float pivot_value = pivot_row[i];
            if (pivot_value != 0) {
                for (float& value : pivot_row) {
                    value /= pivot_value;
                }
            }
        }

        // One collective call serves both the sender and the receivers.
        MPI_Bcast(pivot_row.data(), n, MPI_FLOAT, owner, MPI_COMM_WORLD);

        // Only rows that lie below global row i are updated:
        //   rank >  owner : every local row
        //   rank == owner : local rows after the pivot row
        //   rank <  owner : none
        int first_row = rows_per_process;
        if (rank > owner) {
            first_row = 0;
        } else if (rank == owner) {
            first_row = local_pivot + 1;
        }
        eliminateRows(local_rows, pivot_row, first_row, rows_per_process, n, i);
    }

    // Collect the updated rows on the root; the receive buffer is only
    // needed there.
    std::vector<float> gathered_matrix;
    if (rank == 0) {
        gathered_matrix.resize(dim * dim);
    }
    MPI_Gather(local_rows.data(), chunk_len, MPI_FLOAT,
               gathered_matrix.data(), chunk_len, MPI_FLOAT, 0, MPI_COMM_WORLD);

    if (rank == 0) {
        // Rebuild the 2-D layout printMatrix expects.
        std::vector<std::vector<float>> result_matrix(dim);
        for (std::size_t r = 0; r < dim; ++r) {
            result_matrix[r].assign(gathered_matrix.begin() + r * dim,
                                    gathered_matrix.begin() + (r + 1) * dim);
        }

        std::cout << "Result matrix after Gauss elimination:\n";
        printMatrix(result_matrix, n);
    }

    MPI_Finalize();

    return 0;
}