"""
GraphVisualizer: Non-interactive visualization of NetworkX graphs in Jupyter or VSCode.

Author: Vojtěch Drahý
Copyright (c) 2025 Vojtěch Drahý. All rights reserved.
"""

import networkx as nx
import matplotlib.pyplot as plt


class GraphVisualizer:
    """
    Simple class for static visualization of NetworkX graphs with customizable node/edge colors and labels.

    Drawing style (colors, labels, figure size, node size) is fixed at
    construction time; the layout algorithm is chosen on each call to ``show``.
    """

    # Layout names accepted by ``show``. Also embedded in the error raised for
    # an unknown name so callers can see the valid choices.
    SUPPORTED_LAYOUTS = ("spring", "circular", "kamada_kawai", "shell", "spectral")

    def __init__(self, graph: "nx.Graph", node_color="skyblue", edge_color="gray", with_labels=True, figsize=(8, 6), node_size=300):
        """
        Initialize the GraphVisualizer.

        Parameters
        ----------
        graph : nx.Graph
            The NetworkX graph to visualize. It is stored by reference and
            laid out on every ``show`` call, so later changes to the graph are
            reflected in subsequent drawings.
        node_color : str or list
            Color of the nodes (forwarded to ``nx.draw`` as ``node_color``).
        edge_color : str or list
            Color of the edges (forwarded to ``nx.draw`` as ``edge_color``).
        with_labels : bool
            Whether to draw node labels.
        figsize : tuple
            ``(width, height)`` of the matplotlib figure, in inches.
        node_size : int
            Size of the nodes (forwarded to ``nx.draw`` as ``node_size``).
        """
        self.graph = graph
        self.node_color = node_color
        self.edge_color = edge_color
        self.with_labels = with_labels
        self.figsize = figsize
        self.node_size = node_size

    def _get_layout(self, layout, seed=42):
        """
        Return node positions for ``self.graph`` according to the chosen layout.

        Parameters
        ----------
        layout : str
            One of ``SUPPORTED_LAYOUTS``.
        seed : int or None
            Seed forwarded to ``nx.spring_layout`` so the spring layout is
            reproducible across calls; ignored by every other layout.

        Returns
        -------
        dict
            Mapping of node to position, as returned by the NetworkX layout
            function.

        Raises
        ------
        ValueError
            If ``layout`` is not one of ``SUPPORTED_LAYOUTS``.
        """
        if layout == "spring":
            return nx.spring_layout(self.graph, seed=seed)
        if layout == "circular":
            return nx.circular_layout(self.graph)
        if layout == "kamada_kawai":
            return nx.kamada_kawai_layout(self.graph)
        if layout == "shell":
            return nx.shell_layout(self.graph)
        if layout == "spectral":
            return nx.spectral_layout(self.graph)
        # Keep the original "Unknown layout: ..." prefix, but tell the caller
        # which names would have been accepted.
        raise ValueError(
            f"Unknown layout: {layout}. "
            f"Supported layouts: {', '.join(self.SUPPORTED_LAYOUTS)}"
        )

    def show(self, layout="spring", *, seed=42):
        """
        Render the graph using matplotlib with a specified layout.

        Parameters
        ----------
        layout : str
            Layout type: 'spring', 'circular', 'kamada_kawai', 'shell', 'spectral'.
        seed : int or None, keyword-only
            Seed for the spring layout (default 42, giving the same picture on
            every call); ignored by the other layouts.

        Raises
        ------
        ValueError
            If ``layout`` is unknown. This is raised before any figure is created.
        """
        # Compute positions first so an invalid layout fails before a figure is opened.
        pos = self._get_layout(layout, seed=seed)
        plt.figure(figsize=self.figsize)
        # No ``ax`` is passed, so nx.draw draws onto the current figure (the one just created).
        nx.draw(
            self.graph,
            pos,
            node_color=self.node_color,
            edge_color=self.edge_color,
            with_labels=self.with_labels,
            node_size=self.node_size,
        )
        plt.show()
