import enum
from typing import Generator, Optional
import graphviz
import random

class TraversalType(enum.Enum):
    """ Order in which a depth-first traversal reports nodes (see `Node.iterate_nodes`). """

    # enum.auto() numbers members 1, 2, 3 in declaration order.
    PREORDER = enum.auto()
    INORDER = enum.auto()
    POSTORDER = enum.auto()

class Node:
    """ A node in a binary tree. Has explicitly named left and right child nodes.

    I.e., the tree is represented by explicit neighbour relationship: every node
    holds (optional) references to its left and right child. """

    # Class-wide count of constructed nodes. It also provides the default key,
    # so nodes created without an explicit key receive consecutive integers.
    COUNTER = 0

    def __init__(self, key:Optional[int] = None):
        """ The constructor. Sets up the internal structure of the node.

        :param key: Information stored in the node. When omitted, the current
            value of ``Node.COUNTER`` is used; keys stay unique only as long as
            no explicit keys are mixed in.
        """

        if key is None:
            key = Node.COUNTER

        self.left:Optional[Node] = None
        """ Reference to the left child node. """

        self.right:Optional[Node] = None
        """ Reference to the right child node. """

        self.key:int = key
        """ Information held by this Node. """

        # Count every constructed node, whether or not an explicit key was given.
        Node.COUNTER += 1

    def is_leaf(self)->bool:
        """ Computes whether the node is a leaf node. """
        raise NotImplementedError

    def is_regular(self)->bool:
        """ Computes whether the node is regular.

        NOTE(review): the previous docstring was copied from `is_leaf`; the
        intended definition of "regular" (presumably: exactly zero or two
        children) should be confirmed. `to_graphviz` draws regular nodes as
        hexagons. """
        raise NotImplementedError

    def iterate_nodes(self, order:TraversalType=TraversalType.INORDER, current_depth:int=0) -> Generator[tuple[int, "Node"], None, None]:
        """ Uses a generator (via `yield` keyword) to produce a sequence of nodes in given ordering. Each node is annotated with its depth w.r.t. the caller. """
        raise NotImplementedError

    def iterate_leaves(self) -> Generator[tuple[int, "Node"], None, None]:
        """ Same as `iterate_nodes`, only limited to leaves. Use `yield from` idiom. """
        raise NotImplementedError

    def count_nodes(self) -> int:
        """ How many nodes are in the sub tree rooted at this node. """
        raise NotImplementedError

    def count_leaves(self) -> int:
        """ How many leaf nodes are in the sub tree rooted at this node. """
        raise NotImplementedError

    def count_depth(self) -> int:
        """ How deep is this sub-tree. """
        raise NotImplementedError

    def count_depth_min(self) -> int:
        """ Depth of the shallowest empty child slot in this sub-tree. """
        raise NotImplementedError

    def is_balanced(self, strict:bool=True)->bool:
        """ Computes whether the node is a root of a balanced (sub) tree. If strict == False, only leaves which are actually present need to be considered. """
        raise NotImplementedError

    def flip_subtrees(self) -> None:
        """ Recursively interchanges all left and right sub-trees. """
        raise NotImplementedError

    def append_unbalanced(self, depth:int):
        """ Appends a maximally unbalanced tree. """
        raise NotImplementedError

    def append_rnd_subtree(self, depth:int, branch_probability:float=0.5) -> "Node":
        """ Branches a random subtree of at most `depth` further levels from this node.

        Each child slot (left first, then right) is filled independently with
        probability `branch_probability`, and the same probability is applied
        at every level of the recursion. When `depth > 0`, existing children
        are replaced or cleared.

        :param depth: Maximum number of levels to add below this node; a value
            <= 0 leaves the node untouched.
        :param branch_probability: Probability in [0, 1] that a child slot is filled.
        :returns: `self`, so calls can be chained.
        """
        if depth <= 0:
            return self
        assert 0 <= branch_probability <= 1

        # Left is drawn (and fully built) before right, so a given random seed
        # always produces the same tree.
        self.left = Node._random_child(depth, branch_probability)
        self.right = Node._random_child(depth, branch_probability)
        return self

    @staticmethod
    def _random_child(depth:int, branch_probability:float) -> Optional["Node"]:
        """ Returns a fresh random subtree for one child slot, or None with probability 1 - branch_probability. """
        if random.random() < branch_probability:
            # Propagate branch_probability; previously deeper levels fell back to 0.5.
            return Node().append_rnd_subtree(depth - 1, branch_probability)
        return None

    def print_horizontal(self, depth:int=0) -> None:
        """ Prints the sub tree as an indented outline, one node per line.

        Nodes come in the default order of `iterate_nodes`, each indented by the
        depth it reports, and are annotated with the per-node query results.

        :param depth: Extra indentation (in spaces) added to every line; the
            default 0 keeps this node flush left.
        """
        for node_depth, node in self.iterate_nodes():
            print(" "*(depth + node_depth) + "|> " + ", ".join([
                f"key={node.key}",
                f"depth={node.count_depth()}",
                f"mindepth={node.count_depth_min()}",
                # The two columns must differ: is_balanced() defaults to strict=True.
                f"is_balanced={node.is_balanced(strict=False)}",
                f"is_balanced(strict)={node.is_balanced(strict=True)}",
                f"is_leaf={node.is_leaf()}",
                f"is_regular={node.is_regular()}"
                ]))

    def to_graphviz(self, parent:Optional[str] = None, dot:Optional[graphviz.Digraph] = None) -> graphviz.Digraph:
        """ Draws the sub tree rooted at this node into a graphviz digraph.

        Styling: label colour is black for balanced nodes, otherwise dark green
        for leaves and red for everything else; leaves are filled green; regular
        nodes are hexagons. Each empty child slot is drawn as a plain-text "x".

        NOTE(review): graphviz node ids are `str(key)`, so keys must be unique
        within the tree or distinct nodes will be merged.

        :param parent: Graphviz id of the parent node; when given, an edge
            parent -> this node is added.
        :param dot: Digraph to draw into; a new one is created when omitted.
        :returns: The digraph that was drawn into.
        """
        if dot is None:
            dot = graphviz.Digraph()
        node_id = str(self.key)

        attrs = {"fontcolor": "red"}
        if self.is_leaf():
            attrs["style"] = "filled"
            attrs["fillcolor"] = "green"
            attrs["fontcolor"] = "darkgreen"
        if self.is_balanced():
            attrs["fontcolor"] = "black"
        if self.is_regular():
            attrs["shape"] = "hexagon"
        dot.node(node_id, node_id, **attrs)
        if parent is not None:
            dot.edge(parent, node_id)

        # Left before right, to keep the emitted graph source order stable.
        for child, side in ((self.left, "l"), (self.right, "r")):
            if child is not None:
                child.to_graphviz(parent=node_id, dot=dot)
            else:
                # Placeholder id = parent key + side suffix; cannot equal an int key.
                dot.node(node_id + side, "x", shape="plaintext")
                dot.edge(node_id, node_id + side)
        return dot

if __name__ == "__main__":
    import sys

    # A fixed default seed keeps the demo output (and rendered graph)
    # reproducible; pass an integer argument to explore other random trees.
    DEFAULT_SEED = 8076561028915218315
    seed = int(sys.argv[1]) if len(sys.argv) > 1 else DEFAULT_SEED
    random.seed(seed)
    print(f"seed was {seed}")

    binary_tree_root = Node()
    binary_tree_root.append_rnd_subtree(5)
    binary_tree_root.print_horizontal()
    binary_tree_root.to_graphviz().render("graph")
    print(binary_tree_root.count_leaves())
    for order in TraversalType:
        print(order)
        # Print keys: Node defines no __str__, so str(node) would only show an address.
        print(" ".join(str(node.key) for _, node in binary_tree_root.iterate_nodes(order)))

