# ============================================================================
#                         N O D E
# ============================================================================

class Node:
    """One node of a binary tree.

    Attributes (set in this order by the constructor):
        left  -- left child ``Node`` or ``None``  (mandatory)
        right -- right child ``Node`` or ``None`` (mandatory)
        key   -- the value stored in this node    (mandatory)
        tag   -- optional one-character label, used here only for demonstration
    """

    def __init__(self, key):
        # Mandatory fields: both child links start out empty.
        self.left = self.right = None
        self.key = key

        # Optional field: a single character, for demonstration purposes only.
        self.tag = ' '

# ============================================================================
#                         B I N A R Y   T R E E
# ============================================================================

class BinaryTree:
    """Binary tree rooted at ``self.root`` plus recursive node-counting methods.

    Construction (``randomTree``, ``addNode``, ``addNodes``) and drawing
    (``display``) are implemented in the separate files ``_treebuild_`` and
    ``_treedisplay_``; importing them inside the class body below makes them
    methods of this class.

    Every ``count*`` method takes the root of the (sub)tree to examine and
    treats ``None`` as an empty tree, for which it returns 0.
    """

    # ....... Constructor ...................................................
    # Create the initial binary tree with one node -- the root.
    def __init__( self, key = 0 ):
        self.root = Node( key )

    # ........................................................................
    #       S E R V I C E     F U N C T I O N S
    # ...... IMPORTED ........................................................
    # The following BinaryTree methods are part of this class; their code is
    # stored in separate files to improve the readability of this file.
    # Functions bound in a class body become methods of that class.
    from _treebuild_   import randomTree, addNode, addNodes
    from _treedisplay_ import display
    # def randomTree( self, node, depth ):
    # def addNode( self, parentKey, nodeKey ):
    # def addNodes( self, nodePairs ):          # pair == (parentKey, nodeKey)
    # def display( self ):
    # For purely technical reasons, some private helper functions are imported as
    # well (presumably used by the methods above -- TODO confirm).
    from _treebuild_   import _addnoder
    from _treedisplay_ import _setXcoord, _countNodes
    # .......................................................................


    # ........................................................................
    #   C U S T O M   F U N C T I O N (S),   T E S T E D  I N   M A I N (below)
    # ........................................................................

    @staticmethod
    def _isLeaf(node):
        """Return True if *node* (not None) has neither a left nor a right child."""
        return node.left is None and node.right is None

    def countAllNodes(self, node):
        """Return the total number of nodes in the subtree rooted at *node*."""
        if node is None:
            return 0
        return 1 + self.countAllNodes(node.left) + self.countAllNodes(node.right)

    def countInternalNodes(self, node):
        """Return the number of internal (non-leaf) nodes in the subtree at *node*."""
        if node is None or self._isLeaf(node):
            return 0  # an empty subtree or a leaf contributes nothing
        return (1 + self.countInternalNodes(node.left)
                  + self.countInternalNodes(node.right))

    def countLeaves(self, node, *, verbose=True):
        """Return the number of leaves in the subtree rooted at *node*.

        If *verbose* is true (the default, which keeps the original output),
        each internal node prints the leaf counts of its left and right
        subtrees, tracing the recursion for demonstration. Pass
        ``verbose=False`` to count without printing.
        """
        if node is None:
            return 0
        if self._isLeaf(node):
            return 1  # this is a leaf
        totalL = self.countLeaves(node.left, verbose=verbose)
        totalR = self.countLeaves(node.right, verbose=verbose)
        if verbose:
            print( "in node", node.key, " the number of leaves in L and R subtree is:", totalL, totalR)
        return totalL + totalR

    def countNodesWith1child(self, node):
        """Return the number of nodes with exactly one child in the subtree at *node*."""
        if node is None:
            return 0
        below = (self.countNodesWith1child(node.left)
                 + self.countNodesWith1child(node.right))
        # Exactly one child <=> exactly one of the two links is None.
        hasOneChild = (node.left is None) != (node.right is None)
        return below + (1 if hasOneChild else 0)

    def countNodesWith2children(self, node):
        """Return the number of nodes with two children in the subtree at *node*."""
        if node is None:
            return 0
        if node.left is not None and node.right is not None:
            return (1 + self.countNodesWith2children(node.left)
                      + self.countNodesWith2children(node.right))
        return (self.countNodesWith2children(node.left)
                + self.countNodesWith2children(node.right))

    # A slightly more compact version of the previous function.
    def countNodesWith2children_b(self, node):
        """Return the number of nodes with two children in the subtree at *node*."""
        if node is None:
            return 0
        countBelow = (self.countNodesWith2children_b(node.left)
                      + self.countNodesWith2children_b(node.right))
        if node.left is not None and node.right is not None:
            countBelow += 1
        return countBelow

    # =================================
    #    End of class BinaryTree
    # =================================


# ............................................................................
#                M A I N   P R O G R A M
# ............................................................................

# Run the demonstration only when this file is executed as a script, so that
# importing it (e.g. to reuse BinaryTree) does not build and print a tree.
if __name__ == "__main__":
    t = BinaryTree( )

    print( "Random tree" )
    t.randomTree( t.root, 4 )  # 2nd param is maximum depth of the tree

    print( "Display ")
    t.display()


    # example function calls

    print( "The number of all nodes is", t.countAllNodes(t.root) )

    print( "The number of leaves is", t.countLeaves(t.root) )

    print( "The number of internal nodes is", t.countInternalNodes(t.root) )

    print( "The number of nodes with 1 child is", t.countNodesWith1child(t.root) )

    print( "The number of nodes with 2 children is", t.countNodesWith2children(t.root) )



























