# =============================================================================
#  WebsiteGraphGenerator
#  ---------------------
#  Generates a reproducible synthetic web graph with node and edge attributes.
#
#  (c) 2025 Vojtěch Drahý
#  All rights reserved.
#
# =============================================================================

import random
import numpy as np
import networkx as nx


class WebsiteGraphGenerator:
    """
    A class for generating reproducible synthetic directed graphs representing
    websites and their interconnections.

    Each node corresponds to a website with attributes such as category, region,
    language, visit counts, and creation year. Edges are directed hyperlinks
    between sites. Which targets a site links to is biased by the source and
    target categories and by the target's age, but the edges themselves carry
    no attributes.

    All randomness is drawn from generators owned by the instance. Output
    therefore depends only on the constructor arguments and on how many graphs
    this instance has already produced. It does not depend on other code that
    uses the global ``random`` / ``numpy.random`` state, or on other instances.

    Attributes
    ----------
    num_nodes : int
        Number of nodes (websites) in the generated graph.
    seed : int
        Random seed to ensure reproducibility.
    categories : list[str]
        Possible website categories.
    regions : list[str]
        Possible regions.
    languages : list[str]
        Possible language codes.

    Methods
    -------
    generate_graph() -> nx.DiGraph:
        Generates and returns a directed graph with the specified configuration.
    """

    # Sampling weights, aligned index-for-index with ``categories`` / ``regions``.
    _CATEGORY_WEIGHTS = (0.2, 0.15, 0.2, 0.15, 0.15, 0.15)
    _REGION_WEIGHTS = (0.5, 0.3, 0.2)

    # Inclusive range for sampled creation years, and the year against which
    # a domain's age is measured when biasing link targets.
    _MIN_CREATION_YEAR = 1998
    _MAX_CREATION_YEAR = 2024
    _REFERENCE_YEAR = 2025

    # Bounds applied to the Poisson-drawn out-degree. It is additionally capped
    # by the number of available targets in ``_add_edges``.
    _MIN_OUT_DEGREE = 2
    _MAX_OUT_DEGREE = 15

    # Category -> domain suffix / name prefix used by ``_gen_domain``.
    _DOMAIN_SUFFIXES = {
        'news': '.com',
        'social': '.net',
        'blog': '.blog',
        'shop': '.com',
        'edu': '.org',
        'forum': '.io',
    }
    _NAME_PREFIXES = {
        'news': 'news-site',
        'social': 'social',
        'blog': 'author',
        'shop': 'shop',
        'edu': 'edu',
        'forum': 'forum',
    }

    # Source category -> (target categories that get boosted, multiplier).
    # Source categories not listed here link without category bias.
    _LINK_BIAS = {
        'news': (frozenset({'news', 'social'}), 2.0),
        'shop': (frozenset({'blog', 'forum'}), 1.8),
        'social': (frozenset({'news', 'blog'}), 1.5),
    }

    def __init__(self, num_nodes: int = 20, seed: int = 42) -> None:
        """
        Initialize the generator and its random number generators.

        Parameters
        ----------
        num_nodes : int
            Number of websites to generate per graph. Values below 2 yield
            graphs without edges, because no node has another node to link to.
        seed : int
            Seed for the instance-owned random number generators.
        """
        self.num_nodes = num_nodes
        self.seed = seed

        # Kept for backward compatibility. Earlier versions drew every sample
        # from the global RNGs, and callers may rely on them being seeded here.
        random.seed(seed)
        np.random.seed(seed)

        # Instance-owned RNGs. They are seeded exactly like the global ones
        # above, so they reproduce the streams the global-state version
        # consumed. Unlike the globals, they cannot be disturbed by other code
        # or by other generator instances.
        self._rng = random.Random(seed)
        self._np_rng = np.random.RandomState(seed)

        # Predefined categories and metadata
        self.categories = ['news', 'social', 'blog', 'shop', 'edu', 'forum']
        self.regions = ['US', 'EU', 'ASIA']
        self.languages = ['en', 'cz', 'de', 'fr', 'zh']

        # Running per-category counter used to make domain names unique
        self._counters = {c: 1 for c in self.categories}

        # Mean of the Poisson distribution for the out-degree of each category
        self._category_out_lambda = {
            'news': 6,
            'social': 8,
            'blog': 4,
            'shop': 3,
            'edu': 4,
            'forum': 5
        }

    def _gen_domain(self, category: str) -> str:
        """
        Generate a synthetic domain name based on the category type.

        Parameters
        ----------
        category : str
            The website category.

        Returns
        -------
        str
            A domain name that is unique for this instance
            (e.g., "news-site1.com").

        Raises
        ------
        KeyError
            If ``category`` is not one of the known categories.
        """
        prefix = self._NAME_PREFIXES[category]
        suffix = self._DOMAIN_SUFFIXES[category]
        number = self._counters[category]
        self._counters[category] += 1
        return f"{prefix}{number}{suffix}"

    def generate_graph(self) -> nx.DiGraph:
        """
        Generate a directed graph of websites and the hyperlinks between them.

        Each call draws fresh samples from this instance's RNGs. Repeated calls
        on one instance therefore return different graphs, with new,
        non-colliding domain names. Two instances created with the same
        arguments return identical sequences of graphs.

        Returns
        -------
        nx.DiGraph
            A directed graph with the following structure:

            - Nodes are domain names. Each carries ``category``, ``region``,
              ``language``, ``visits_per_day`` and ``creation_year``
              attributes.
            - Edges are hyperlinks and have no attributes.
            - There are no self-loops.
            - Each node's out-degree is between 2 and 15, capped at
              ``num_nodes - 1``.
        """
        G = nx.DiGraph()
        self._add_nodes(G)
        self._add_edges(G)
        return G

    def _add_nodes(self, G: nx.DiGraph) -> None:
        """Add ``num_nodes`` websites with randomly sampled attributes to ``G``."""
        rng = self._rng
        for _ in range(self.num_nodes):
            # NOTE: the order of the draws below defines the random stream.
            # Reordering them changes every graph produced for a given seed.
            cat = rng.choices(self.categories, weights=self._CATEGORY_WEIGHTS)[0]
            domain = self._gen_domain(cat)
            region = rng.choices(self.regions, weights=self._REGION_WEIGHTS)[0]
            # About 20% of sites are forced to English. The rest pick a
            # language uniformly, which can also yield 'en'.
            language = rng.choice(self.languages) if rng.random() > 0.2 else 'en'
            visits = int(self._np_rng.lognormal(mean=8, sigma=1.2))
            creation_year = rng.randint(
                self._MIN_CREATION_YEAR, self._MAX_CREATION_YEAR
            )

            G.add_node(
                domain,
                category=cat,
                region=region,
                language=language,
                visits_per_day=visits,
                creation_year=creation_year
            )

    def _add_edges(self, G: nx.DiGraph) -> None:
        """Give every node in ``G`` a random set of outgoing links to other nodes."""
        all_nodes = list(G.nodes)
        for src in all_nodes:
            cat = G.nodes[src]['category']
            # Out-degree ~ Poisson(category mean), clamped to the degree bounds...
            k = self._np_rng.poisson(self._category_out_lambda[cat])
            k = max(self._MIN_OUT_DEGREE, min(k, self._MAX_OUT_DEGREE))

            targets = [t for t in all_nodes if t != src]
            # ...and never more than the number of possible targets. Sampling
            # is without replacement, so numpy raises ValueError otherwise
            # (e.g. num_nodes of 1 or 2).
            k = min(k, len(targets))
            if k == 0:
                continue

            p = self._target_probabilities(G, cat, targets)
            idxs = self._np_rng.choice(len(targets), size=k, replace=False, p=p)
            G.add_edges_from((src, targets[i]) for i in idxs)

    def _target_probabilities(
        self, G: nx.DiGraph, src_category: str, targets: list
    ) -> np.ndarray:
        """
        Return normalized probabilities for linking to each of ``targets``.

        The weight of a target is the category bias for ``src_category``
        (1.0 if none applies) multiplied by ``1 + age / 30``. Older domains
        are therefore slightly more likely to receive links.
        """
        boosted, factor = self._LINK_BIAS.get(src_category, (frozenset(), 1.0))
        weights = np.array(
            [
                (factor if G.nodes[t]['category'] in boosted else 1.0)
                * (1 + (self._REFERENCE_YEAR - G.nodes[t]['creation_year']) / 30)
                for t in targets
            ],
            dtype=float,
        )

        total = np.sum(weights)
        # Defensive: with the current constants every weight is positive and
        # finite. If that ever stops holding, fall back to uniform sampling.
        if total > 0 and np.isfinite(total):
            p = weights / total
            if np.all(np.isfinite(p)):
                return p
        return np.ones(len(targets)) / len(targets)



# Example usage:
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    generator = WebsiteGraphGenerator(num_nodes=20, seed=42)
    graph = generator.generate_graph()
    # Seed the layout explicitly. The drawing is then reproducible on its own
    # and does not rely on whatever global numpy RNG state the generator
    # happened to leave behind.
    pos = nx.spring_layout(graph, seed=generator.seed)
    nx.draw(graph, pos=pos, with_labels=True)
    plt.show()
