Question

What type of tree traversal does ast use (specifically ast.NodeVisitor())? When I created a stack and pushed each traversed node onto it, the result looked like a breadth-first traversal, meaning that the order depended on each node's level in the tree.

For example, the tree looks like this:

Module
  Assign
    Name
      Store
    Call
      Attribute
        Str
        Load

and the stack looks like

[Module,Assign,Name,Call,Store,Attribute,Str,Load]

Example code:

import ast

stack = []
class a(ast.NodeTransformer):
    def visit_Num(self, node):
        stack.append(node)
        ...
        return node

    ...                      #this is all the other visit_*() functions

    def visit_Str(self, node):
        stack.append(node)
        ...
        return node

if __name__ == "__main__":
    with open('some_file.py', 'r') as pt:
        tree = ast.parse(pt.read())  # ast.parse() expects source text, not a file object
    new_tree = a()
    new_tree_edit = ast.fix_missing_locations(new_tree.visit(tree)) # I have tried with and without calling fix_missing_locations and got the same results.
    print(stack)

Solution

The ast.walk() function walks the tree breadth-first; see the ast.py source:

def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node

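Child nodes are appended to the back of a queue, and the next node to be traversed is always taken from the front of that queue, so every node on one level is yielded before any node on the next level.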
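To see that order concretely, the sketch below runs ast.walk() over a one-line assignment with the same shape as the tree in the question. It assumes Python 3.8 or later, where the string literal parses as a Constant node rather than the Str node shown in the question. For contrast, walk_depth_first is a hypothetical helper (not part of the ast module) that differs from walk() only in using a LIFO list instead of a FIFO deque.

```python
import ast

tree = ast.parse("x = 'a'.upper()")

# ast.walk(): nodes are taken from the front of a FIFO queue, so each
# level of the tree is finished before the next one starts (breadth-first).
bfs_order = [type(node).__name__ for node in ast.walk(tree)]
print(bfs_order)
# ['Module', 'Assign', 'Name', 'Call', 'Store', 'Attribute', 'Constant', 'Load']


def walk_depth_first(node):
    """Hypothetical helper: like ast.walk(), but with a LIFO stack."""
    todo = [node]
    while todo:
        node = todo.pop()  # take the most recently added node
        # Push children in reverse so the first child is popped next.
        todo.extend(reversed(list(ast.iter_child_nodes(node))))
        yield node


dfs_order = [type(node).__name__ for node in walk_depth_first(tree)]
print(dfs_order)
# ['Module', 'Assign', 'Name', 'Store', 'Call', 'Attribute', 'Constant', 'Load']
```

The breadth-first list matches the stack in the question (with Constant in place of Str), which suggests that order comes from a queue-based walk rather than from NodeVisitor's recursion.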

If you want a depth-first traversal, use a subclass of ast.NodeVisitor instead; it walks the tree recursively. NodeVisitor.visit() calls NodeVisitor.generic_visit() unless a more node-specific visitor method is defined, and NodeVisitor.generic_visit() calls NodeVisitor.visit() again for each child node:

class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)
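A minimal sketch of that recursion in action, assuming Python 3.8 or later (so the string literal is a Constant node). OrderRecorder is a hypothetical class name; it overrides visit() to record each node's class name both before and after delegating to NodeVisitor.visit(), which in turn recurses through generic_visit().

```python
import ast


class OrderRecorder(ast.NodeVisitor):
    """Hypothetical demo class: records node class names in two orders."""

    def __init__(self):
        self.pre = []   # names recorded before the children are visited
        self.post = []  # names recorded after all children are visited

    def visit(self, node):
        self.pre.append(type(node).__name__)
        # NodeVisitor.visit() dispatches to visit_<Class> or generic_visit(),
        # and generic_visit() calls self.visit() on every child node.
        result = super().visit(node)
        self.post.append(type(node).__name__)
        return result


recorder = OrderRecorder()
recorder.visit(ast.parse("x = 'a'.upper()"))
print(recorder.pre)
# ['Module', 'Assign', 'Name', 'Store', 'Call', 'Attribute', 'Constant', 'Load']
print(recorder.post)
# ['Store', 'Name', 'Constant', 'Load', 'Attribute', 'Call', 'Assign', 'Module']
```

Where the work happens relative to the recursive call decides between pre-order and post-order, but either way a branch is finished before its siblings are started; neither order groups nodes level by level.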

If you do subclass NodeVisitor (or its derived class, NodeTransformer), remember to also call super(YourClass, self).generic_visit(node) in your specific visit_* methods so the traversal continues into that node's children; otherwise those children are never visited.
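The effect of leaving out that call can be sketched with a NodeTransformer like the one in the question, again assuming Python 3.8 or later. Recorder and its descend flag are hypothetical names for illustration. Because this class also overrides generic_visit() to record nodes, visit_Assign calls super().generic_visit() directly, so the Assign node is not recorded twice.

```python
import ast


class Recorder(ast.NodeTransformer):
    """Hypothetical demo: records node class names in visit order."""

    def __init__(self, descend):
        super().__init__()
        self.descend = descend  # hypothetical flag: traverse below Assign?
        self.stack = []

    def generic_visit(self, node):
        # Runs for every node type that has no visit_* method of its own.
        self.stack.append(type(node).__name__)
        return super().generic_visit(node)

    def visit_Assign(self, node):
        self.stack.append(type(node).__name__)
        if self.descend:
            # Go straight to NodeTransformer's child traversal, which
            # returns the (possibly modified) node.
            return super().generic_visit(node)
        return node  # the Assign node's children are never visited


source = "x = 'a'.upper()"
for descend in (False, True):
    recorder = Recorder(descend)
    recorder.visit(ast.parse(source))
    print(descend, recorder.stack)
# False ['Module', 'Assign']
# True ['Module', 'Assign', 'Name', 'Store', 'Call', 'Attribute', 'Constant', 'Load']
```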

Licensed under: CC-BY-SA with attribution