#!/usr/bin/env python
"""ASTutils provides utilities for working with AST objects.
"""
# (c) 2005 Chad Whitacre
# This program is beerware. If you like it, buy me a beer someday.
# No warranty is expressed or implied.
# Module metadata: the release version string and the author's name.
__version__ = '0.2'
__author__ = 'Chad Whitacre '
import parser, token, symbol
from os import linesep
from pprint import pformat
from StringIO import StringIO
class ASTutilsException(Exception):
    """Error raised by the ASTutils methods when they are given input they
    cannot work with (an unrecognized tree type, output format or nodetype).
    """
class ASTutils:
"""This class holds four utilities for working with Python syntax trees.
Syntax trees are the output of the Python parser, which is mimicked in the
standard library's parser module. The parser module trades in AST objects,
which represent "abstract syntax trees." The parser module can convert ASTs
to list and tuple representations. I believe that these are considered
"concrete syntax trees," or CSTs. The compiler module also works with ASTs,
but I believe at a higher level. This module uses parser, but not compiler.
Where an 'st' argument is called for in this module, you may provide either
an AST object, or a list or tuple as produced by parser.ast2list and
parser.ast2tuple. This module uses the term "cst fragment" to refer to a
fragment of a syntax tree sequence that does not represent a well-formed
Python structure, e.g., the tree does not begin with one of: single_input,
file_input, or eval_input. Unless otherwise mentioned, syntax trees provided
as 'st' arguments must be well-formed; they may not be fragments.
For the record, I implemented these as classmethods rather than as
module-level functions because they call each other and I didn't want to
think about the order they were defined in.
"""
def _standardize_st(self, st, format='tuple'):
"""Given a syntax tree and a desired format, return the tree in that
format.
"""
# convert the incoming ast/cst into an AST
if type(st) is type(parser.suite('')):
ast = st
else:
if type(st) in (type(()), type([])):
ast = parser.sequence2ast(st)
else:
raise ASTutilsException, "incoming type unrecognized: " +\
repr(type(st))
# return the tree in the desired format
formats = { 'tuple' : ast.totuple
, 'list' : ast.tolist
, 'ast' : lambda: ast
}
outgoing = formats.get(format.lower())
if outgoing is None:
raise ASTutilsException, "requested format unrecognized: " + format
return outgoing()
_standardize_st = classmethod(_standardize_st)
def ast2read(self, st):
"""Given a syntax tree, return a more human-readable representation of
the tree than is returned by parser.ast2list and parser.ast2tuple.
Usage:
>>> import parser
>>> ast = parser.suite("print 'hello world'")
>>> print parser.ast2list(ast)
[257, [266, [267, [268, [271, [1, 'print'], [298, [299, [300, [301, [303, [304, [305, [306, [307, [308, [309, [310, [311, [3, "'hello world'"]]]]]]]]]]]]]]]], [4, '']]], [0, '']]
>>> print ASTutils.ast2read(ast)
['file_input',
['stmt',
['simple_stmt',
['small_stmt',
['print_stmt',
['NAME', 'print'],
['test',
['and_test',
['not_test',
['comparison',
['expr',
['xor_expr',
['and_expr',
['shift_expr',
['arith_expr',
['term',
['factor',
['power',
['atom', ['STRING', "'hello world'"]]]]]]]]]]]]]]]],
['NEWLINE', '']]],
['ENDMARKER', '']]
"""
# define our recursive function
def walk(cst):
"""Given an AST list, recursively walk it and replace the nodes with
human-readable equivalents.
"""
for node in cst:
if type(node) is type([]):
# we have a list of subnodes; recurse
walk(node)
else:
# we have an actual node; interpret it and store the result
if type(node) is type(0):
if node < 256:
readable_node = token.tok_name[node]
else:
readable_node = symbol.sym_name[node]
else:
readable_node = node
cst[cst.index(node)] = readable_node
# ggg!
TREE = self._standardize_st(st, 'list')
walk(TREE)
return pformat(TREE)
ast2read = classmethod(ast2read)
def ast2text(self, st):
"""Given a syntax tree, return an approximation of the source code that
generated it. The approximation will only differ from the original in
non-essential whitespace and missing comments.
Usage:
>>> import parser
>>> from ASTutils import ASTutils
>>> ast = parser.suite("print 'hello world'")
>>> print ASTutils.ast2text(ast)
print 'hello world'
>>> # Here's an example of whitespace differences (this also tests
>>> # multiple indent levels):
>>>
>>> from os import linesep as lf
>>> block = "def foo():"+lf+" if 1:"+lf+" return True"
>>> # note no whitespace around parens/colons and one-space indents
>>>
>>> ast = parser.suite(block)
>>> text = ASTutils.ast2text(ast)
>>>
>>> # account for the fact that I trim trailing spaces in my editor
>>> text = linesep.join([l.rstrip() for l in text.split(linesep)])
>>> print text
def foo ( ) :
if 1 :
return True
>>> # note extra spacing around parens/colons and four-space indents
"""
class walker:
TEXT = ''
INDENT_LEVEL = 0
def walk(self, cst):
"""Given an AST tuple (a CST?), recursively walk it and
assemble the nodes back into a text code block.
"""
for node in cst:
if type(node) is type(()):
# we have a tuple of subnodes; recurse
self.walk(node)
else:
# we have an actual node; interpret it and store the
# result
text = '' # default
if type(node) is type(''):
if node <> '':
node += ' ' # insert some whitespace
if not node.startswith('#'):
text = node # ignore comments
elif node == token.NEWLINE:
text = linesep + self.indent()
elif node == token.INDENT:
self.INDENT_LEVEL += 1
text = ' '
elif node == token.DEDENT:
self.INDENT_LEVEL -= 1
self.dedent()
self.TEXT += text
walk = classmethod(walk)
def indent(self):
if self.INDENT_LEVEL == 0:
return ''
else:
return self.INDENT_LEVEL * ' '
indent = classmethod(indent)
def dedent(self):
self.TEXT = self.TEXT[:-4]
dedent = classmethod(dedent)
cst = self._standardize_st(st, 'tuple')
walker.walk(cst)
# trim a possible trailing newline and/or space; this is necessary to
# make the doctest work
output = walker.TEXT
if output.endswith(linesep): output = output.rstrip(linesep)
if output.endswith(' '): output = output[:-1]
return output
ast2text = classmethod(ast2text)
def getnodes(self, st, nodetype):
"""Given an AST object or a cst fragment (as list or tuple), and a
string or int nodetype, return the first instance of the desired
nodetype as a cst fragment, or None if the nodetype is not found.
Usage:
>>> import parser, symbol
>>> ast = parser.suite("print 'hello world'")
>>> ASTutils.getnodes(ast, 'print_stmt')
[(271, (1, 'print'), (298, (299, (300, (301, (303, (304, (305, (306, (307, (308, (309, (310, (311, (3, "'hello world'")))))))))))))))]
>>> ASTutils.getnodes(ast, symbol.pass_stmt)
[]
>>> ASTutils.getnodes(ast, -1) # bad data
Traceback (most recent call last):
...
ASTutilsException: nodetype '-1' is not in symbol or token tables
>>> ASTutils.getnodes(ast, 'foo') # bad data
Traceback (most recent call last):
...
ASTutilsException: nodetype '-1' is not in symbol or token tables
"""
# we don't call _standardize_st because we want to accept fragments
ast = parser.suite('')
if type(st) is type(ast):
cst = st.totuple()
else:
cst = st
# standardize the incoming nodetype to a symbol or token int
if type(nodetype) is type(''):
symtype = getattr(symbol, nodetype, '')
if symtype:
nodetype = symtype
else:
toktype = getattr(token, nodetype, '')
if toktype:
nodetype = toktype
else:
nodetype = -1 # bad data
# validate the input
valid_ints = symbol.sym_name.keys() + token.tok_name.keys()
if nodetype not in valid_ints:
raise ASTutilsException, "nodetype '%s' " % nodetype +\
"is not in symbol or token tables"
# define our recursive function
class walker:
NODES = []
def walk(self, cst, nodetype):
for node in cst:
if type(node) in (type(()), type([])):
candidate = self.walk(node, nodetype)
else:
candidate = cst
if candidate is not None:
if candidate[0] == nodetype:
self.NODES.append(candidate)
walk = classmethod(walk)
# ggg!
walker.walk(cst, nodetype)
return walker.NODES
getnodes = classmethod(getnodes)
def hasnode(self, cst, nodetype):
"""Given an AST object or a cst fragment (either in list or tuple form),
and a nodetype (either as a string or an int), return a boolean.
Usage:
>>> import parser, symbol
>>> ast = parser.suite("print 'hello world'")
>>> ASTutils.hasnode(ast, 'print_stmt')
True
>>> ast = parser.suite("if 1: print 'hello world'")
>>> ASTutils.hasnode(ast, symbol.pass_stmt)
False
>>> ASTutils.hasnode(ast, symbol.print_stmt)
True
"""
return len(self.getnodes(cst, nodetype)) > 0
hasnode = classmethod(hasnode)
# When executed as a script, run the doctests embedded in the docstrings.
if __name__ == "__main__":
    from doctest import testmod
    testmod()