Question

Please consider the (example) code below before I get to my specific question regarding the visitor pattern in Python:

class Node:
    """Base tree node: holds children, has a check, and accepts visitors."""

    def __init__(self):
        # Child nodes, in the order they were added.
        self.children = list()

    def add(self, node):
        """Attach *node* as the last child of this node."""
        self.children += [node]

    def check(self):
        """Print this node's type label and report success."""
        print("Node")
        return True

    def accept(self, visitor):
        """Hand this node to *visitor*; the visitor's return value is discarded."""
        visitor.visit(self)

class NodeA(Node):
    """Subclass of Node whose check prints 'NodeA' and passes."""

    def check(self):
        """Announce this node and report success."""
        label = "NodeA"
        print(label)
        passed = True
        return passed
class NodeB(Node):
    """Subclass of Node whose check prints 'NodeB' and passes."""

    def check(self):
        """Announce this node and report success."""
        label = "NodeB"
        print(label)
        passed = True
        return passed


class NodeA_A(NodeA):
    """Subclass of NodeA whose check prints 'NodeA_A' and passes."""

    def check(self):
        """Announce this node and report success."""
        label = "NodeA_A"
        print(label)
        passed = True
        return passed
class NodeA_B(NodeA):
    """Subclass of NodeA whose check prints 'NodeA_B' and passes."""

    def check(self):
        """Announce this node and report success."""
        label = "NodeA_B"
        print(label)
        passed = True
        return passed

class NodeA_A_A(NodeA_A):
    """Subclass of NodeA_A whose check prints 'NodeA_A_A' and fails."""

    def check(self):
        """Announce this node and report failure."""
        label = "NodeA_A_A"
        print(label)
        passed = False
        return passed

class NodeRunner:
    """Visitor that checks each node and then descends into its children."""

    def visit(self, node):
        """Check *node* first (pre-order), then visit each child in insertion order."""
        node.check()
        for child in node.children:
            child.accept(self)

if __name__ == "__main__":
    # Tree under test:
    #   Node
    #   +-- NodeA
    #   |   +-- NodeA_A
    #   |   |   +-- NodeA_A_A
    #   |   +-- NodeA_B
    #   +-- NodeB
    root = Node()
    node_a, node_b = NodeA(), NodeB()
    node_a_a, node_a_b = NodeA_A(), NodeA_B()
    node_a_a_a = NodeA_A_A()

    for child in (node_a, node_b):
        root.add(child)
    for child in (node_a_a, node_a_b):
        node_a.add(child)
    node_a_a.add(node_a_a_a)

    runner = NodeRunner()
    runner.visit(root)

When I run it, it traverses all the node classes recursively (depth-first) and prints the following:

Node
NodeA
NodeA_A
NodeA_A_A
NodeA_B
NodeB

This is all fine, but now to my question. You may have noticed that each check method returns a Boolean (let's say this is a complicated method in reality).

In the example above, every check method in the Node classes returns True except the one in NodeA_A_A. I would like to record this somehow during the visit so that I can fail all of its base classes (i.e. the ancestor nodes above it in the tree).

This is hard to explain, so let me illustrate:

  • if NodeA_A_A returns False, then I would like to fail NodeA_A, NodeA and Node, regardless of what these classes return.
  • if NodeB returns False, then I would like to fail Node, regardless of what the other classes return.

So if a child class fails anywhere (its check method returns False), I would like to fail all of its base classes.

Does anyone have any ideas?

Was it helpful?

Solution 2

I used the visitor pattern to visit all the nodes. One visitor visits all the nodes and runs their checks; the other visitor bubbles the results up to the parents. The code and output are provided below:

class Node(object):
    """A tree node that can run a self-check and be visited.

    Attributes:
        children: child nodes attached with ``add``, in insertion order.
        result: outcome of the most recent ``check`` (``None`` until checked).
    """

    def __init__(self):
        self.children = []
        self.result = None

    def add(self, node):
        """Append *node* as a child of this node."""
        self.children.append(node)

    def check(self):
        """Run this node's check, store the outcome in ``result``, and return it."""
        self.result = True
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print("Node: result:%s" % self.result)
        return self.result

    def accept(self, visitor):
        """Hand this node to *visitor* by calling ``visitor.visit(self)``."""
        visitor.visit(self)

class Node_A(Node):
    """Node subclass whose check always passes and records that in ``result``."""

    # No __init__ override: the inherited Node.__init__ already sets up
    # ``children`` and ``result``.

    def check(self):
        """Run this node's check, store the outcome, and return it."""
        self.result = True
        print("Node_A: result:%s" % self.result)
        return self.result

class Node_A_A(Node_A):
    """Node_A subclass whose check always passes and records that in ``result``."""

    # No __init__ override: the inherited constructor is sufficient.

    def check(self):
        """Run this node's check, store the outcome, and return it."""
        self.result = True
        print("Node_A_A: result:%s" % self.result)
        return self.result

class Node_A_B(Node_A):
    """Node_A subclass whose check always passes and records that in ``result``."""

    # No __init__ override: the inherited constructor is sufficient.

    def check(self):
        """Run this node's check, store the outcome, and return it."""
        self.result = True
        print("Node_A_B: result:%s" % self.result)
        return self.result

class Node_A_A_A(Node_A_A):
    """Node_A_A subclass whose check always passes and records that in ``result``."""

    # No __init__ override: the inherited constructor is sufficient.

    def check(self):
        """Run this node's check, store the outcome, and return it."""
        self.result = True
        print("Node_A_A_A: result:%s" % self.result)
        return self.result
class Node_A_A_B(Node_A_A):
    """Node_A_A subclass whose check always fails and records that in ``result``."""

    # No __init__ override: the inherited constructor is sufficient.

    def check(self):
        """Run this node's check, store the (failing) outcome, and return it."""
        self.result = False
        print("Node_A_A_B: result:%s" % self.result)
        return self.result

class Node_A_B_A(Node_A_B):
    """Node_A_B subclass whose check always passes and records that in ``result``."""

    # No __init__ override: the inherited constructor is sufficient.

    def check(self):
        """Run this node's check, store the outcome, and return it."""
        self.result = True
        print("Node_A_B_A: result:%s" % self.result)
        return self.result



class NodeRunner:
    """Visitor that runs every node's check, with children checked before their parent."""

    def visit(self, node):
        """Post-order: recurse into all children, then check *node* itself."""
        for child in node.children:
            child.accept(self)
        node.check()

class NodeChecker:
    """Visitor that propagates failures from children up to their parents.

    Intended to run after every node's ``result`` has been populated (for
    example by ``NodeRunner``). Leaf nodes keep their own result. A node with
    children passes only if its own result passed *and* every child (after
    propagation) passed.
    """

    def visit(self, node):
        """Post-order visit: settle all children first, then fold them into *node*."""
        results = []
        for child in node.children:
            child.accept(self)
            results.append(child.result)
        if node.children:
            # Combine with the node's own outcome instead of overwriting it;
            # otherwise a parent whose own check failed would be reported as
            # passing whenever all of its children pass.
            node.result = node.result and all(results)


if __name__ == "__main__":
    # Tree used in this example (node_a_a_b is the only failing check):
    #   node
    #     node_a
    #       node_a_a
    #         node_a_a_a
    #         node_a_a_b
    #       node_a_b
    #         node_a_b_a
    node = Node()
    node_a = Node_A()
    node_a_a = Node_A_A()
    node_a_b = Node_A_B()

    node_a_a_a = Node_A_A_A()
    node_a_a_b = Node_A_A_B()

    node_a_b_a = Node_A_B_A()

    node.add(node_a)

    node_a.add(node_a_a)
    node_a_a.add(node_a_a_a)
    node_a_a.add(node_a_a_b)

    node_a.add(node_a_b)
    node_a_b.add(node_a_b_a)

    # Pass 1: run every node's own check (children before parents).
    print("-------------------")
    NodeRunner().visit(node)
    print("-------------------")
    # Pass 2: bubble failures from children up to their ancestors.
    NodeChecker().visit(node)
    print("-------------------")
    # print(...) with a single argument works on both Python 2 and 3.
    for name, visited in [
        ("node_a_a_a", node_a_a_a),
        ("node_a_a_b", node_a_a_b),
        ("node_a_a", node_a_a),
        ("node_a_b_a", node_a_b_a),
        ("node_a_b", node_a_b),
        ("node_a", node_a),
        ("node", node),
    ]:
        print("%s: result: %s" % (name, visited.result))

The output is provided below:

-------------------
Node_A_A_A: result:True
Node_A_A_B: result:False
Node_A_A: result:True
Node_A_B_A: result:True
Node_A_B: result:True
Node_A: result:True
Node: result:True
-------------------
-------------------
node_a_a_a: result: True
node_a_a_b: result: False
node_a_a: result: False
node_a_b_a: result: True
node_a_b: result: True
node_a: result: False
node: result: False

OTHER TIPS

It seems that what you are asking about is not really the visitor pattern, but how to implement a depth-first search algorithm. Here is my solution to your question:

class Node:
    """Tree node whose ``accept`` forwards the visitor's return value."""

    def __init__(self):
        # Child nodes, in insertion order.
        self.children = list()

    def add(self, node):
        """Attach *node* as the last child of this node."""
        self.children += [node]

    def check(self):
        """Print this node's type label and report success."""
        print("Node")
        return True

    def accept(self, visitor):
        """Call ``visitor.visit(self)`` and return whatever it returns."""
        outcome = visitor.visit(self)
        return outcome

class NodeA(Node):
    """Subclass of Node whose check prints 'NodeA' and passes."""

    def check(self):
        """Announce this node and report success."""
        name = "NodeA"
        print(name)
        ok = True
        return ok
class NodeB(Node):
    """Subclass of Node whose check prints 'NodeB' and passes."""

    def check(self):
        """Announce this node and report success."""
        name = "NodeB"
        print(name)
        ok = True
        return ok


class NodeA_A(NodeA):
    """Subclass of NodeA whose check prints 'NodeA_A' and passes."""

    def check(self):
        """Announce this node and report success."""
        name = "NodeA_A"
        print(name)
        ok = True
        return ok
class NodeA_B(NodeA):
    """Subclass of NodeA whose check prints 'NodeA_B' and passes."""

    def check(self):
        """Announce this node and report success."""
        name = "NodeA_B"
        print(name)
        ok = True
        return ok

class NodeA_A_A(NodeA_A):
    """Subclass of NodeA_A whose check prints 'NodeA_A_A' and fails."""

    def check(self):
        """Announce this node and report failure."""
        name = "NodeA_A_A"
        print(name)
        ok = False
        return ok

class NodeRunner:
    """Depth-first visitor that checks every node and reports overall success."""

    def visit(self, node):
        """Visit *node*'s subtree children-first and return whether all of it passed.

        Every child subtree is visited (there is no short-circuit on the first
        failure), then *node* itself is checked. Returns False if any child
        subtree or the node's own check failed, otherwise True.
        """
        # A list (not a generator) so every child is visited before all() runs.
        children_ok = all([child.accept(self) for child in node.children])
        node_ok = node.check()
        return bool(node_ok) and children_ok

if __name__ == "__main__":
    # Tree under test (NodeA_A_A's check fails, so every ancestor fails):
    #   n
    #     n1 (NodeA)
    #       n11 (NodeA_A)
    #         n111 (NodeA_A_A)
    #       n12 (NodeA_B)
    #     n2 (NodeB)
    n = Node()
    n1 = NodeA()
    n2 = NodeB()
    n11 = NodeA_A()
    n12 = NodeA_B()
    n111 = NodeA_A_A()

    n.add(n1)
    n.add(n2)

    n1.add(n11)
    n1.add(n12)

    n11.add(n111)

    v = NodeRunner()
    # The Python 2 print statement was a SyntaxError on Python 3;
    # single-argument print(...) works on both.
    print(v.visit(n))
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top