# -----------------------------------------------------------------------------
# Copyright (c) 2025
# Author: Vojtěch Drahý
#
# Permission is hereby granted to use, modify, and distribute this code
# for personal, academic, and research purposes under a proper citation.
# -----------------------------------------------------------------------------

import matplotlib.pyplot as plt
import numpy as np
from netgraph import Graph


def draw_mealy(mealy, states, alphabet, initial_state=None):
    """
    Visualize a Mealy machine using netgraph.

    - Nodes are placed evenly on the unit circle, in the order given by
      ``states`` (followed by any transition targets not listed there), so
      the picture is the same on every run.
    - Nodes are colored.
    - Edges inherit the color of the source node.
    - Edge labels inherit the color of their edge.
    - Bidirectional edges are separated by a small vertical label offset.
    - Self-loops get their label above the loop rather than on the node.

    Parameters
    ----------
    mealy : mapping
        Transition table, indexed as ``mealy[state][symbol]``; each entry
        unpacks to a ``(next_state, output)`` pair.
    states : iterable
        States of the machine. Every state is looked up in ``mealy`` for
        every symbol of ``alphabet``. States forming a bidirectional pair
        are compared with ``<`` to pick the side of the label offset.
    alphabet : iterable
        Input symbols.
    initial_state : optional
        If not ``None``, the node of this state is circled in black. It is
        looked up in the node positions, so it must be a drawn node.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """

    # Collect the "input/output" labels per (source, target) pair. All inputs
    # that lead from the same source to the same target share one edge.
    labels_per_edge = {}
    for s in states:
        for a in alphabet:
            nxt, out = mealy[s][a]
            labels_per_edge.setdefault((s, nxt), []).append(f"{a}/{out}")

    merged_labels = {
        edge: ", ".join(lbls) for edge, lbls in labels_per_edge.items()
    }

    # One entry per distinct edge (dicts keep insertion order), so the same
    # (source, target) pair is never handed to netgraph more than once.
    edges = list(merged_labels)

    # Ordered, duplicate-free node list: sources first (they follow the order
    # of `states`), then targets that never occur as a source. Iterating a
    # set here would make the layout depend on string hash randomization.
    nodes = list(dict.fromkeys(
        [u for u, _ in edges] + [v for _, v in edges]
    ))
    n = len(nodes)

    # Deterministic color assignment for nodes
    base_colors = [
        "tab:blue", "tab:red", "tab:green",
        "tab:purple", "tab:orange", "tab:brown",
        "tab:pink", "tab:gray", "tab:olive", "tab:cyan"
    ]

    node_color_map = {
        s: base_colors[i % len(base_colors)]
        for i, s in enumerate(states)
    }
    # Transition targets missing from `states` still need a color, because
    # they are drawn as nodes too; continue cycling through the palette.
    for node in nodes:
        if node not in node_color_map:
            node_color_map[node] = base_colors[
                len(node_color_map) % len(base_colors)
            ]

    # Edge color = color of the source node
    edge_colors = {(u, v): node_color_map[u] for (u, v) in edges}

    # Node colors on which white text is used instead of black
    dark_colors = {"tab:blue", "tab:red", "tab:purple"}

    fig, ax = plt.subplots(figsize=(15, 15))

    # Evenly spaced positions on the unit circle
    pos = {
        node: (np.cos(2 * np.pi * i / n), np.sin(2 * np.pi * i / n))
        for i, node in enumerate(nodes)
    }

    # Draw the graph with straight edges
    graph_artist = Graph(
        edges,
        node_layout=pos,
        edge_layout="straight",
        node_color=node_color_map,
        node_size=5.0,
        node_edge_width=0.06,
        edge_color=edge_colors,
        edge_width=2.0,
        arrows=True,
        ax=ax,
    )

    # From here on use the positions as reported by the graph artist
    pos = graph_artist.node_positions

    # Draw node labels inside the nodes
    for node in nodes:
        x, y = pos[node]
        text_color = "white" if node_color_map[node] in dark_colors else "black"
        ax.text(
            x, y,
            node,
            fontsize=15,
            ha="center", va="center",
            color=text_color
        )

    # Draw edge labels
    for (u, v), label in merged_labels.items():
        x1, y1 = pos[u]
        x2, y2 = pos[v]

        if u == v:
            # Self-loop: label above the loop
            xm = x1
            ym = y1 + 0.12
        else:
            # Regular edge: midpoint of the segment
            xm = (x1 + x2) / 2
            ym = (y1 + y2) / 2

            # Slight offset for bidirectional pairs, in opposite directions
            # for the two edges so their labels do not coincide
            if (v, u) in merged_labels:
                ym += 0.015 if u < v else -0.015

        ax.text(
            xm, ym,
            label,
            fontsize=15,
            color=edge_colors[(u, v)],
            ha="center", va="center"
        )

    # Highlight initial state with an unfilled black ring around its node
    if initial_state is not None:
        x, y = pos[initial_state]
        ax.scatter([x], [y], s=1300,
                   facecolors="none", edgecolors="black", linewidths=3)

    ax.axis("off")
    plt.show()
