import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader


def test_dataset_implementation(dataset_class, config=None):
    """
    Test function to verify that the FruitDataset is working correctly
    by loading and displaying some sample images with their masks

    Args:
        dataset_class: The dataset class to test (e.g., FruitDataset)
        config: Configuration dictionary for the dataset
    """
    # Default configuration for testing
    if config is None:
        config = {
            'data_dir': 'FruitSegmentationDataset/FruitSeg30/',
            'image_size': [64, 64],
            'cache_images': False
        }

    print("Testing FruitDataset implementation...")
    print(f"Loading dataset from: {config['data_dir']}")

    # Create dataset instance
    try:
        dataset = dataset_class(data_dir=config['data_dir'], config=config)
        print(f"✓ Dataset loaded successfully!")
        print(f"✓ Total samples: {len(dataset)}")

        # Get unique fruit classes
        unique_fruits = sorted(set(dataset.metadata))
        print(f"✓ Number of fruit classes: {len(unique_fruits)}")
        print(f"✓ Fruit classes: {unique_fruits[:5]}..." if len(
            unique_fruits) > 5 else f"✓ Fruit classes: {unique_fruits}")

    except Exception as e:
        print(f"✗ Error loading dataset: {e}")
        return

    # Test loading individual samples
    print("\nTesting sample loading...")
    try:
        # Load first few samples
        for i in range(min(3, len(dataset))):
            image, mask, metadata, label, idx = dataset[i]
            print(f"Sample {i}:")
            print(f"  - Fruit type: {metadata}")
            print(f"  - Label: {label.item()}")
            print(f"  - Image shape: {image.shape}")
            print(f"  - Mask shape: {mask.shape}")
            print(f"  - Image dtype: {image.dtype}")
            print(f"  - Mask dtype: {mask.dtype}")
            print(f"  - Image range: [{image.min():.3f}, {image.max():.3f}]")
            print(f"  - Mask range: [{mask.min():.3f}, {mask.max():.3f}]")
            print()

        print("✓ Sample loading test passed!")

    except Exception as e:
        print(f"✗ Error loading samples: {e}")
        return

    # Visualize some samples
    print("Visualizing sample images and masks...")
    try:
        fig, axes = plt.subplots(2, 6, figsize=(15, 5))
        fig.suptitle(
            'Dataset Samples: Images (top) and Masks (bottom)', fontsize=14)

        # Show 6 random samples
        sample_indices = np.random.choice(
            len(dataset), min(6, len(dataset)), replace=False)

        for i, idx in enumerate(sample_indices):
            image, mask, metadata, label, _ = dataset[idx]

            # Convert tensors to numpy for visualization
            # Image: convert from CHW to HWC format
            img_np = image.cpu().numpy().transpose(1, 2, 0)
            mask_np = mask.cpu().numpy().squeeze()  # Remove channel dimension

            # Display image
            axes[0, i].imshow(img_np)
            axes[0, i].set_title(f'{metadata}', fontsize=10)
            axes[0, i].axis('off')

            # Display mask
            axes[1, i].imshow(mask_np, cmap='gray')
            axes[1, i].set_title(f'Mask (Label: {label.item()})', fontsize=10)
            axes[1, i].axis('off')

        plt.tight_layout()
        plt.show()

        print("✓ Visualization completed!")

    except Exception as e:
        print(f"✗ Error in visualization: {e}")

    print("\n" + "="*50)
    print("Dataset implementation test completed!")
    print("If you see images and masks displayed above, your dataset is working correctly.")


def get_device():
    """
    Pick the torch device for training: CUDA first, then Apple MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if ``torch.cuda.is_available()``. Otherwise
        ``mps`` when running on macOS (``platform.system() == "Darwin"``) and
        this torch build exposes an available MPS backend. Otherwise ``cpu``.
    """
    import platform
    if torch.cuda.is_available():
        return torch.device("cuda")
    # The torch.backends.mps module only exists in newer PyTorch releases.
    # Look it up defensively so older installs fall back to CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if (platform.system() == "Darwin" and mps_backend is not None
            and mps_backend.is_available()):
        return torch.device("mps")  # Metal backend for Mac
    return torch.device("cpu")


def test_model_with_data(dataset_class, model_class, config=None):
    """
    Smoke-test a multi-task model end-to-end with the dataset and a DataLoader.

    Steps (progress is printed to stdout):
      1. Build the dataset and a shuffled DataLoader with batch size 4.
      2. Build the model on ``get_device()``.
      3. Run one batch through ``model(...)``.
      4. Run the same batch through ``backbone``, ``classification_head``,
         ``segmentation_head`` and ``get_embedding`` individually.
      5. Check the output shapes.
      6. Run up to three batches.
      7. Print predictions for two samples.

    Any exception is caught, reported with a traceback, and not re-raised.

    Args:
        dataset_class: The dataset class (e.g., FruitDataset). It is called as
            ``dataset_class(data_dir=..., config=...)``, and items must unpack
            as ``(image, mask, metadata, label, idx)``.
        model_class: The model class (e.g., MyModel). It is constructed with
            ``input_size``, ``hidden_size`` and ``output_size`` keywords, and
            its forward pass must return
            ``(classification_logits, segmentation_logits, embeddings)``.
        config: Configuration dictionary for the dataset. When omitted, a 64x64
            configuration pointing at ``FruitSegmentationDataset/FruitSeg30/``
            is used.
    """
    # Default configuration for testing
    if config is None:
        config = {
            'data_dir': 'FruitSegmentationDataset/FruitSeg30/',
            'image_size': [64, 64],
            'cache_images': False
        }

    print("🧪 Testing Model with Dataset and DataLoader...")
    print("=" * 60)

    try:
        # Step 1: Create dataset
        print("📁 Creating dataset...")
        dataset = dataset_class(data_dir=config['data_dir'], config=config)
        print(f"✓ Dataset created with {len(dataset)} samples")
        if len(dataset) == 0:
            # Otherwise next(iter(dataloader)) raises a message-less StopIteration.
            raise ValueError("Dataset is empty; cannot draw a batch")

        # Step 2: Create dataloader
        print("\n🔄 Creating dataloader...")
        batch_size = 4
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print(f"✓ DataLoader created with batch size {batch_size}")
        print(f"✓ Total batches: {len(dataloader)}")

        # Step 3: Create model
        print("\n🧠 Creating model...")
        device = get_device()
        print(f"✓ Using device: {device}")

        # Sorted class names. They give num_classes here and are reused in
        # Step 9 to map predicted indices back to names.
        unique_fruits = sorted(set(dataset.metadata))
        num_classes = len(unique_fruits)
        model = model_class(input_size=64*64*3,
                            hidden_size=128, output_size=num_classes)
        model = model.to(device)
        print(f"✓ Model created for {num_classes} classes")

        # Step 4: Test single batch
        print("\n🔍 Testing single batch...")
        model.eval()  # Set to evaluation mode

        batch_images, batch_masks, batch_metadata, batch_labels, _ = next(
            iter(dataloader))
        # The model's parameters live on `device`. Inputs must be moved there
        # too, or the forward pass fails with a device mismatch on CUDA/MPS.
        batch_images = batch_images.to(device)

        print("✓ Batch loaded successfully:")
        print(f"  - Images shape: {batch_images.shape}")
        print(f"  - Masks shape: {batch_masks.shape}")
        print(f"  - Labels shape: {batch_labels.shape}")
        print(f"  - Batch size: {len(batch_metadata)}")

        # Step 5: Forward pass through model
        print("\n⚡ Testing forward pass...")
        with torch.no_grad():
            classification_output, segmentation_output, embeddings = model(
                batch_images)

        print("✓ Forward pass successful:")
        print(
            f"  - Classification output shape: {classification_output.shape}")
        print(f"  - Segmentation output shape: {segmentation_output.shape}")
        print(f"  - Embeddings shape: {embeddings.shape}")

        # Step 6: Test model components (inference only, no autograd graph)
        print("\n🔧 Testing model components...")
        with torch.no_grad():
            backbone_features = model.backbone(batch_images)
            print(f"✓ Backbone output shape: {backbone_features.shape}")

            class_pred = model.classification_head(backbone_features)
            print(f"✓ Classification head output shape: {class_pred.shape}")

            seg_pred = model.segmentation_head(backbone_features)
            print(f"✓ Segmentation head output shape: {seg_pred.shape}")

            embeddings_test = model.get_embedding(batch_images)
            print(f"✓ get_embedding() output shape: {embeddings_test.shape}")

        # Step 7: Verify output dimensions
        print("\n✅ Verifying output dimensions...")
        # Use the real batch length. A batch is smaller than `batch_size` when
        # the dataset has fewer samples than that.
        actual_batch = batch_images.shape[0]
        expected_class_shape = (actual_batch, num_classes)
        expected_seg_shape = (actual_batch, 1, 64, 64)
        expected_emb_shape = (actual_batch, 64*8*8)

        assert classification_output.shape == expected_class_shape, f"Classification shape mismatch: {classification_output.shape} != {expected_class_shape}"
        assert segmentation_output.shape == expected_seg_shape, f"Segmentation shape mismatch: {segmentation_output.shape} != {expected_seg_shape}"
        assert embeddings.shape == expected_emb_shape, f"Embeddings shape mismatch: {embeddings.shape} != {expected_emb_shape}"

        print("✓ All output dimensions are correct!")

        # Step 8: Test multiple batches (first 3 only)
        print("\n🔄 Testing multiple batches...")
        batch_count = 0
        with torch.no_grad():
            for images, _, _, _, _ in dataloader:
                model(images.to(device))
                batch_count += 1
                if batch_count >= 3:
                    break

        print(f"✓ Successfully processed {batch_count} batches")

        # Step 9: Visualize predictions
        print("\n📊 Testing prediction visualization...")
        with torch.no_grad():
            # Get a small batch for visualization
            sample_images, _, sample_metadata, _, _ = next(
                iter(DataLoader(dataset, batch_size=2, shuffle=True)))
            class_pred, seg_pred, _ = model(sample_images.to(device))

            # Apply sigmoid to segmentation predictions to get probabilities
            seg_pred_prob = torch.sigmoid(seg_pred)

            # Get class predictions
            predicted_classes = torch.argmax(class_pred, dim=1)

            print("✓ Sample predictions:")
            for i, true_fruit in enumerate(sample_metadata):
                pred_class_idx = predicted_classes[i].item()
                # NOTE(review): assumes label index k corresponds to the k-th
                # name in sorted(set(dataset.metadata)). Confirm against the
                # dataset's label encoding.
                pred_fruit = unique_fruits[pred_class_idx] if pred_class_idx < len(
                    unique_fruits) else "Unknown"

                print(
                    f"  Sample {i+1}: True={true_fruit}, Predicted={pred_fruit}")
                print(
                    f"    Segmentation range: [{seg_pred_prob[i].min():.3f}, {seg_pred_prob[i].max():.3f}]")

        print("\n" + "=" * 60)
        print("🎉 ALL TESTS PASSED! Model works correctly with Dataset and DataLoader!")
        print("🚀 Your model is ready for training!")

    except Exception as e:
        print(f"\n❌ ERROR: {str(e)}")
        print("Please check your model, dataset, or dataloader implementation.")
        import traceback
        traceback.print_exc()


def test_triplet_loss(compute_triplet_loss_func, dataset_class, model_class, config=None):
    """
    Comprehensive test for triplet loss implementation

    ``compute_triplet_loss_func`` is called as ``f(triplets, margin=...)``.
    ``triplets`` is a list of ``(anchor, positive, negative)`` 1-D tensors. The
    function must return a scalar tensor, since ``.item()`` and ``.backward()``
    are called on the result.

    Checks, in order:
      1. Hand-made embeddings: the loss is non-negative.
      2. Positive identical to anchor: the loss is 0.
      3. Negative identical to anchor, positive orthogonal: the loss is 2.0.
         With margin 1.0 this value matches a cosine distance (1 - cos_sim).
         A Euclidean distance would give about 2.414, and a squared Euclidean
         distance 3.0.
      4. A batch of 5 random triplets: the loss is non-negative.
      5. Triplets mined from one batch of real model embeddings. This is skipped
         with a warning if the batch has no usable same/different class pair.
      6. Gradients reach the anchor, positive and negative inputs.

    A failed check raises ``AssertionError``.

    Args:
        compute_triplet_loss_func: The triplet loss function to test
        dataset_class: The dataset class (e.g., FruitDataset); called as
            ``dataset_class(data_dir=..., config=...)``
        model_class: The model class (e.g., MyModel)
        config: Configuration dictionary for the dataset

    Returns:
        True when every check passes.
    """
    print("Testing Triplet Loss Implementation...")
    print("=" * 50)

    device = get_device()
    # Absolute tolerance for comparing float losses with exact expected values.
    # Test 2 sits exactly on the hinge boundary (0 - 1 + 1 under cosine
    # distance), so any rounding would break an exact `== 0.0` check.
    tol = 1e-5

    # Default configuration for testing
    if config is None:
        config = {
            'data_dir': 'FruitSegmentationDataset/FruitSeg30/',
            'image_size': [64, 64],
            'cache_images': False
        }

    # Test 1: Basic functionality with known embeddings
    print("Test 1: Basic functionality with known embeddings")

    anchor = torch.tensor([[1.0, 0.0, 0.0]], device=device)
    positive = torch.tensor(
        [[0.8, 0.6, 0.0]], device=device)  # Similar to anchor
    negative = torch.tensor(
        [[0.0, 0.0, 1.0]], device=device)  # Orthogonal to anchor

    triplets = [(anchor[0], positive[0], negative[0])]

    loss = compute_triplet_loss_func(triplets, margin=1.0)
    print(f"✓ Basic triplet loss (margin=1.0): {loss.item():.4f}")

    # A hinge-style triplet loss can never be negative
    assert loss.item() >= 0, "Triplet loss should be non-negative"

    # Test 2: Perfect case - positive and anchor are identical
    print("\nTest 2: Perfect positive case")
    identical_positive = anchor.clone()
    triplets_perfect = [(anchor[0], identical_positive[0], negative[0])]
    loss_perfect = compute_triplet_loss_func(triplets_perfect, margin=1.0)
    print(f"✓ Loss with identical anchor-positive: {loss_perfect.item():.4f}")
    assert abs(loss_perfect.item()) <= tol, "Loss should be zero when positive is identical to anchor"

    # Test 3: Worst case - negative is closer than positive
    print("\nTest 3: Worst case scenario")
    close_negative = torch.tensor(
        [[1.0, 0.0, 0.0]], device=device)  # Same as anchor
    far_positive = torch.tensor(
        [[0.0, 1.0, 0.0]], device=device)  # Orthogonal to anchor

    triplets_worst = [(anchor[0], far_positive[0], close_negative[0])]
    loss_worst = compute_triplet_loss_func(triplets_worst, margin=1.0)
    print(f"✓ Loss with close negative: {loss_worst.item():.4f}")
    # 2.0 = (1 - cos 90°) - (1 - cos 0°) + margin, i.e. cosine distance.
    assert abs(loss_worst.item() - 2.0) <= tol, "Loss should be 2 when negative is same as anchor and positive is perpendicular"

    # Test 4: Test with multiple triplets
    print("\nTest 4: Multiple triplets")

    anchors = torch.randn(5, 10, device=device)
    positives = anchors + 0.1 * \
        torch.randn(5, 10, device=device)  # Similar to anchors
    negatives = torch.randn(5, 10, device=device)  # Random, likely different

    triplets_batch = [(anchors[i], positives[i], negatives[i])
                      for i in range(5)]
    loss_batch = compute_triplet_loss_func(triplets_batch, margin=0.5)
    print(f"✓ Batch triplet loss (5 triplets): {loss_batch.item():.4f}")
    assert loss_batch.item() >= 0, "Batch triplet loss should be non-negative"

    # Test 5: Test with real model embeddings
    print("\nTest 5: Real model embeddings")

    # Keyword construction, matching the other test helpers in this file
    dataset = dataset_class(data_dir=config['data_dir'], config=config)
    dataloader = DataLoader(dataset, batch_size=6, shuffle=True)

    # Create model and get embeddings
    num_classes = len(set(dataset.metadata))
    model = model_class(input_size=64*64*3, hidden_size=128,
                        output_size=num_classes).to(device)
    model.eval()

    with torch.no_grad():
        # Get a batch of data
        images, _, _, labels, _ = next(iter(dataloader))
        images = images.to(device)
        labels = labels.to(device)

        embeddings = model.get_embedding(images)

        # Mine one triplet per anchor: first other same-label sample as the
        # positive, first different-label sample as the negative.
        triplets_real = []
        for i, anchor_emb in enumerate(embeddings):
            anchor_label = labels[i]

            pos_indices = (labels == anchor_label).nonzero(as_tuple=True)[0]
            other_pos = pos_indices[pos_indices != i]
            if len(other_pos) == 0:  # Need at least 2 samples of same class
                continue

            neg_indices = (labels != anchor_label).nonzero(as_tuple=True)[0]
            if len(neg_indices) == 0:
                continue

            triplets_real.append(
                (anchor_emb, embeddings[other_pos[0]], embeddings[neg_indices[0]]))

        if triplets_real:
            loss_real = compute_triplet_loss_func(triplets_real, margin=1.0)
            print(f"✓ Real model triplet loss: {loss_real.item():.4f}")
            print(f"✓ Number of valid triplets created: {len(triplets_real)}")
            assert loss_real.item() >= 0, "Real model triplet loss should be non-negative"
        else:
            print(
                "⚠ Could not create valid triplets from batch (need multiple samples per class)")

    # Test 6: Gradient computation
    print("\nTest 6: Gradient computation")

    anchor_grad = torch.randn(3, requires_grad=True, device=device)
    positive_grad = torch.randn(3, requires_grad=True, device=device)
    negative_grad = torch.randn(3, requires_grad=True, device=device)

    triplets_grad = [(anchor_grad, positive_grad, negative_grad)]
    loss_grad = compute_triplet_loss_func(triplets_grad, margin=1.0)

    loss_grad.backward()

    print(f"✓ Loss with gradients: {loss_grad.item():.4f}")
    # Verify gradients exist before dereferencing them; otherwise a missing
    # grad would surface as an AttributeError instead of a clear assertion.
    for role, tensor in (("Anchor", anchor_grad),
                         ("Positive", positive_grad),
                         ("Negative", negative_grad)):
        assert tensor.grad is not None, f"Gradients should be computed ({role.lower()} has none)"
        print(f"✓ {role} gradient norm: {tensor.grad.norm().item():.4f}")

    print("\n" + "=" * 50)
    print("✅ All triplet loss tests passed!")
    print("✅ Triplet loss implementation is working correctly!")

    return True
