#!/usr/bin/env python3
import argparse
import numpy as np
from typing import Tuple
from utils import read_input, save_list_to_path

# Command-line interface: a required input path, an optional output path and two flags.
parser = argparse.ArgumentParser(
    description="Process input and output file paths.",
)
parser.add_argument(
    "input_file_path",
    type=str,
    help="Path to the input file",
)
parser.add_argument(
    "--plot",
    action="store_true",
    help="Plot the data",
)
# These arguments will be set appropriately by Brute, even if you change them.
parser.add_argument(
    "output_file_path",
    nargs="?",
    default=None,
    type=str,
    help="Path to the output file",
)
parser.add_argument(
    "--brute",
    action="store_true",
    help="Evaluation in Brute",
)

def learn_classifier(x: np.ndarray, y: np.ndarray, loss_matrix: np.ndarray, histogram_edges: np.ndarray) -> np.ndarray:
    """
    Learn a classifier that minimizes the empirical loss based on the provided histogram edges and loss matrix.

    The classifier works by binning the observations `x` according to `histogram_edges`.
    Each bin is assigned the single class label that minimizes the summed empirical loss
    of the training samples falling into that bin (empirical risk minimization per bin).

    Edge Handling:
    - The `histogram_edges` array defines the boundaries for the bins. The edges are interpreted using the
      inclusive-exclusive rule: for each bin `i`, an observation `x` is assigned to bin `i` if
      `histogram_edges[i] <= x < histogram_edges[i+1]`.
    - The exception is the final bin, where observations equal to the last edge `histogram_edges[-1]` are included
      in the last bin.
    - A bin containing no training samples has zero risk for every label and is assigned label 0
      (`np.argmin` returns the first minimizer; ties in general resolve to the lowest label).

    Args:
        x (np.ndarray): Float array representing the observations; the values range from min(histogram_edges) to max(histogram_edges).
        y (np.ndarray): Integer array representing the true class labels; label is an integer from 0 to Y-1.
        loss_matrix (np.ndarray): Loss represented as a (Y,Y) shape matrix. The value loss_matrix[y,yy] represents the loss incurred when the true label is y and prediction is yy.
        histogram_edges: (np.ndarray): Float array representing the edges of the histogram.

    Returns:
        np.ndarray: Float array of length len(histogram_edges) - 1; entry k is the class label
        predicted for bin k (stored as float, the form consumed by `predict`).

    Raises:
        ValueError: If `x` and `y` differ in size, or an observation lies outside
            [histogram_edges[0], histogram_edges[-1]].
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y).astype(int).ravel()
    histogram_edges = np.asarray(histogram_edges, dtype=float)
    loss_matrix = np.asarray(loss_matrix, dtype=float)
    if x.size != y.size:
        raise ValueError("x and y must contain the same number of samples")

    num_bins = histogram_edges.size - 1
    num_classes = loss_matrix.shape[0]

    # Same binning rule as `predict`: digitize is 1-based, and values equal to the last
    # edge are pulled into the final bin instead of an overflow bin.
    bin_indices = np.digitize(x, histogram_edges) - 1
    bin_indices[x == histogram_edges[-1]] = num_bins - 1
    # Out-of-range values (or NaN) would otherwise wrap around (-1) or index past the end.
    if np.any((bin_indices < 0) | (bin_indices >= num_bins)):
        raise ValueError("all observations must lie within [histogram_edges[0], histogram_edges[-1]]")

    # counts[k, c] = number of training samples in bin k whose true label is c.
    counts = np.zeros((num_bins, num_classes))
    np.add.at(counts, (bin_indices, y), 1)

    # risk[k, c'] = sum_c counts[k, c] * loss_matrix[c, c'], i.e. the total empirical
    # loss incurred by predicting label c' for every sample in bin k.
    risk = counts @ loss_matrix

    # Pick the risk-minimizing label per bin; kept as float to match the weight array type.
    weights = np.argmin(risk, axis=1).astype(float)
    return weights

def predict(x: np.ndarray, weights: np.ndarray, histogram_edges: np.ndarray) -> np.ndarray:
    """
    Predict class labels for the given observations based on the learned weights and histogram edges.

    Each observation is assigned to a bin using the inclusive-exclusive rule
    `histogram_edges[i] <= x < histogram_edges[i+1]`, except that values equal to the last
    edge fall into the last bin. The prediction is the weight stored for that bin.

    Args:
        x (np.ndarray): Float array representing the observations.
        weights (np.ndarray): Array of weights for each bin.
        histogram_edges (np.ndarray): Float array representing the edges of the histogram.

    Returns:
        np.ndarray: Integer array of predicted class labels (empty when `x` is empty).
    """
    x = np.asarray(x)
    weights = np.asarray(weights)

    # Digitize the observations into bins using the histogram edges
    # `digitize` returns 1-based index, so subtract 1 for 0-based bins
    bin_indices = np.digitize(x, histogram_edges) - 1
    # Manually assign values equal to the rightmost edge to the last bin
    bin_indices[x == histogram_edges[-1]] = len(histogram_edges) - 2
    # Selecting weights[bin] is exactly what a one-hot row times `weights` computes, but
    # in O(n) instead of materializing a K x K identity matrix per observation.
    pred_y = weights[bin_indices]
    return pred_y.astype(int)
   
def generalization_bound(true_y: np.ndarray, pred_y: np.ndarray, histogram_edges: np.ndarray, loss_matrix: np.ndarray, delta: float) -> float:
    """
    Compute the upper bound on the true risk.

    Uses Hoeffding's inequality with a union bound over the finite class of histogram
    classifiers (H = Y**K of them for K bins and Y classes). With probability at least
    1 - delta, the true risk is at most R_train + epsilon, where
    epsilon = (max(loss) - min(loss)) * sqrt((log(2H) - log(delta)) / (2n)).

    Args:
        true_y (np.ndarray): Integer array representing the true class labels; label is an integer from 0 to Y-1.
        pred_y (np.ndarray): Integer array representing the predicted class labels; label is an integer from 0 to Y-1.
        histogram_edges: (np.ndarray): Float array representing the edges of the histogram.
        loss_matrix (np.ndarray): Loss represented as a (Y,Y) shape matrix. The value loss_matrix[y,yy] represents the loss incurred when the true label is y and prediction is yy.
        delta (float): Scalar from (0,1) representing the probability of failure.

    Returns:
        float: The upper bound on the true risk.

    Raises:
        ValueError: If the label arrays differ in size or are empty, or delta is not in (0, 1).
    """
    true_y = np.asarray(true_y).astype(int)
    pred_y = np.asarray(pred_y).astype(int)
    loss_matrix = np.asarray(loss_matrix, dtype=float)
    if np.size(true_y) != np.size(pred_y):
        raise ValueError("true_y and pred_y must contain the same number of labels")
    n = np.size(true_y)
    if n == 0:
        raise ValueError("at least one training sample is required")
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must lie in the open interval (0, 1)")

    # Number of histogram bins
    K = np.size(histogram_edges) - 1
    # Number of classes
    num_classes = np.shape(loss_matrix)[0]

    # log(2 * num_classes**K) computed in log-space: num_classes**K overflows a float
    # (and 2*H becomes inf) already for moderately many bins.
    log_2_H = np.log(2.0) + K * np.log(num_classes)
    # Empirical (training) risk: mean loss of the predictions on the training samples.
    R_train = np.mean(loss_matrix[true_y, pred_y])
    # Hoeffding's inequality assumes losses in [a, b]; the deviation scales with (b - a).
    loss_range = np.max(loss_matrix) - np.min(loss_matrix)
    # Solve 2H * exp(-2 n eps^2 / (b - a)^2) = delta for eps.
    epsilon = loss_range * np.sqrt((log_2_H - np.log(delta)) / (2.0 * n))
    # With probability at least 1 - delta, the true risk is below this value.
    R_UB = R_train + epsilon
    return float(R_UB)


def estimation_error_bound(n_training_samples: float, histogram_edges: np.ndarray, loss_matrix: np.ndarray, delta: float) -> float:
    """
    Compute the estimation error bound, which specifies how much the true risk of the trained histogram classifier
    can deviate from the true risk of the best histogram classifier.

    With probability at least 1 - delta, every classifier's empirical risk is within
    epsilon = (max(loss) - min(loss)) * sqrt((log(2H) - log(delta)) / (2n)) of its true risk
    (Hoeffding + union bound over H = Y**K histogram classifiers). The empirical risk minimizer
    is then at most 2 * epsilon worse than the best classifier in the class.

    Args:
        n_training_samples (float): Number of training samples.
        histogram_edges: (np.ndarray): Float array representing the edges of the histogram.
        loss_matrix (np.ndarray): Loss represented as a (Y,Y) shape matrix. The value loss_matrix[y,yy] represents the loss incurred when the true label is y and prediction is yy.
        delta (float): Scalar from (0,1) representing the probability of failure.

    Returns:
        float: The upper bound on the estimation error.

    Raises:
        ValueError: If n_training_samples is not positive or delta is not in (0, 1).
    """
    if n_training_samples <= 0:
        raise ValueError("n_training_samples must be positive")
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must lie in the open interval (0, 1)")
    loss_matrix = np.asarray(loss_matrix, dtype=float)

    # Number of histogram bins
    K = np.size(histogram_edges) - 1
    # Number of classes
    num_classes = np.shape(loss_matrix)[0]

    # log(2 * num_classes**K) in log-space; computing num_classes**K directly overflows.
    log_2_H = np.log(2.0) + K * np.log(num_classes)
    # Uniform deviation bound, scaled by the range of the loss values.
    loss_range = np.max(loss_matrix) - np.min(loss_matrix)
    uniform_eps = loss_range * np.sqrt((log_2_H - np.log(delta)) / (2.0 * n_training_samples))
    # R(h_trained) - R(h_best) <= [R - R_emp](h_trained) + [R_emp - R](h_best) <= 2 * uniform_eps,
    # since the trained classifier minimizes the empirical risk.
    epsilon = 2.0 * uniform_eps
    return float(epsilon)
    
    
def main(args):
    """
    Train the histogram classifier on one input instance and report its risk bounds.

    Reads the instance with `read_input`, optionally plots stacked per-class histograms of the
    observations, learns the classifier, predicts on the training data and computes both bounds.
    With `--brute` the weights followed by the two bounds are written via `save_list_to_path`;
    otherwise the results are compared against the matching file under `solutions`.

    Args:
        args (argparse.Namespace): Parsed command-line arguments produced by `parser`.

    Raises:
        ValueError: If the reference solution file contains no usable numbers.
    """
    data = read_input(args.input_file_path)  # Read the input data as a dictionary

    if args.plot:
        # Plot the data samples (matplotlib is imported lazily; it is only needed here)
        import matplotlib.pyplot as plt
        plt.figure()
        samples = data['x']
        labels = data['y']
        # 21 evenly spaced edges -> 20 bins covering [min, max] inclusive. np.arange(min, max, step)
        # would stop one step short of max, silently leaving the largest samples out of the plot.
        plot_edges = np.linspace(np.min(samples), np.max(samples), 21)
        plt.hist([samples[labels == label] for label in np.unique(labels)],
                 bins=plot_edges,
                 stacked=True, alpha=0.5, edgecolor='k')
        plt.xlabel('Value')
        plt.ylabel('Count (Stacked)')
        plt.title('Histograms of the observed classes')
        plt.show()

    w = learn_classifier(x=data['x'],
                         y=data['y'],
                         loss_matrix=data['loss_matrix'],
                         histogram_edges=data['histogram_edges'])

    pred_y = predict(x=data['x'],
                     weights=w,
                     histogram_edges=data['histogram_edges'])

    R_UB = generalization_bound(true_y=data['y'],
                                pred_y=pred_y,
                                histogram_edges=data['histogram_edges'],
                                loss_matrix=data['loss_matrix'],
                                delta=data['delta'])

    eps_UB = estimation_error_bound(n_training_samples=data['y'].size,
                                    histogram_edges=data['histogram_edges'],
                                    loss_matrix=data['loss_matrix'],
                                    delta=data['delta'])

    print(f"The trained histogram classifier achieves true error of at most {np.round(R_UB,3)} with probability at least {1-data['delta']}")
    print(f"The trained histogram classifier achieves true error that differs from the best histogram classifier by at most {np.round(eps_UB,3)} with probability at least {1-data['delta']}")

    if args.brute:
        save_list_to_path(w.tolist() + [R_UB, eps_UB], args.output_file_path)  # Save the result
        return

    print("Comparing with reference solution")
    solution_path = args.input_file_path.replace('instances', 'solutions').replace('.json', '')
    reference = None
    with open(solution_path, 'r') as f:
        # Keep the last non-blank line; a trailing empty line must not replace the data with [].
        for line in f:
            if line.strip():
                reference = list(map(float, line.split()))
    if reference is None or len(reference) < 2:
        raise ValueError(f"Reference solution {solution_path} contains no usable data")

    reference_w = reference[:-2]
    reference_R_UB = reference[-2]
    reference_eps_UB = reference[-1]

    # Check shapes first: allclose would broadcast a length-1 reference or raise on a mismatch.
    if np.shape(w) == np.shape(reference_w) and np.allclose(w, reference_w, rtol=1e-03, atol=1e-03):
        print("learn_classifier: \t\tTest OK")
    else:
        print("learn_classifier: \t\tTest Failed")

    if np.allclose(R_UB, reference_R_UB, rtol=1e-03, atol=1e-03):
        print("generalization_bound: \t\tTest OK")
    else:
        print(reference_R_UB, R_UB)
        print("generalization_bound: \t\tTest Failed")

    if np.allclose(eps_UB, reference_eps_UB, rtol=1e-03, atol=1e-03):
        print("estimation_error_bound: \tTest OK")
    else:
        print("estimation_error_bound: \tTest Failed")
            

if __name__ == "__main__":
    # Script entry point: parse the CLI and hand the namespace straight to main().
    main(parser.parse_args())
