class Node:
    """One vertex of a labelled binary tree.

    Attributes:
        key: value stored in the node.
        color: colour code of the node.
        leftlabel: label of the left child.
        rightlabel: label of the right child.
    """

    def __init__(self, key, color, leftlabel, rightlabel):
        # Payload of the node.
        self.key, self.color = key, color
        # Children are referenced by label rather than by object.
        self.leftlabel, self.rightlabel = leftlabel, rightlabel


def countPairsInList(listOfNodes):
    """Count same-colour neighbours that share a key.

    The list is scanned once. For each colour (0 = white, 1 = black) the key
    of the most recently seen node of that colour is remembered; whenever the
    next node of the same colour carries an equal key, one pair is counted.
    Nodes of the other colour lying in between do not break a pair, and nodes
    whose colour is neither 0 nor 1 are ignored.

    Args:
        listOfNodes: iterable of objects exposing ``key`` and ``color``.

    Returns:
        The number of pairs found (0 for an empty list).
    """
    # A unique sentinel instead of a magic number: it compares unequal to
    # every real key, so no key value (e.g. -1) can produce a false pair
    # on the first node of a colour.
    unseen = object()
    last_key = {0: unseen, 1: unseen}

    pairs = 0
    for node in listOfNodes:
        color = node.color
        if color not in last_key:
            continue  # only colours 0 and 1 take part in pairing
        if last_key[color] == node.key:
            pairs += 1
        last_key[color] = node.key
    return pairs


def findPairs(nodes, rootlabel):
    """Sum the pair counts of every level of the tree.

    The tree is walked level by level starting from the root. Each level is
    handed, in left-to-right order, to ``countPairsInList`` and the results
    are added up.

    Args:
        nodes: sequence mapping a label to its node object.
        rootlabel: label of the root node.

    Returns:
        Total number of pairs over all levels.
    """
    total = 0
    level = [nodes[rootlabel]]
    while level:
        total += countPairsInList(level)

        # Collect the children of this level, left child before right child.
        children = []
        for parent in level:
            for label in (parent.leftlabel, parent.rightlabel):
                if label != 0:  # label 0 marks a missing child
                    children.append(nodes[label])
        level = children
    return total

# ____________________________________________________________________________
#                              M A I N


# Read the tree description from standard input.
# First line: number of nodes and the label of the root.
N, rootlabel = (int(tok) for tok in input().split())

# One slot per label; sized N + 1 so that index N is valid
# (presumably labels run 1..N -- confirm against the input format).
nodes = [None for _ in range(N + 1)]

# Each following line: label, key, left child label, right child label, color.
for _ in range(N):
    fields = [int(tok) for tok in input().split()]
    label, key, left_label, right_label, color = fields
    nodes[label] = Node(key, color, left_label, right_label)

# Report the total number of pairs.
print(findPairs(nodes, rootlabel))









