Automatic code changing in Python with the ast module
Tiago Antao
Posted on December 5, 2021
In this article we are going to explore Python Abstract Syntax Trees. Our example will use the Python ast
module to help convert legacy Python 2 code to Python 3.
One of the changes that Python 3 brings versus Python 2 is that, if a class redefines the __eq__ method without also defining __hash__, its __hash__ method is automatically set to None. This means that instances of the class are rendered unhashable (and thus unusable in many settings, for example as dictionary keys or set members). This makes sense, as the hash function is closely related to object equality: objects that compare equal must have equal hashes. But if you have Python 2 code that is nonetheless working and you are planning to move to Python 3, then automatically adding a __hash__ method to recover object hashability might be a pragmatically acceptable solution.
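A minimal Python 3 sketch of the behaviour just described; Point and PatchedPoint are hypothetical example classes. Defining __eq__ alone leaves __hash__ set to None, while an explicit __hash__ that delegates to the parent class (the same trick the patching code later in this article generates) makes instances hashable again.

```python
class Point:
    # Defines __eq__ only, so Python 3 implicitly sets Point.__hash__ to None
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)


class PatchedPoint:
    # Same equality, plus an explicit __hash__ delegating to the parent class
    # (object here), which hashes by identity rather than by value
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, PatchedPoint) and (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        return super(PatchedPoint, self).__hash__()


print(Point.__hash__)  # None
try:
    {Point(1, 2)}  # building a set needs a hash
except TypeError as err:
    print('TypeError:', err)

p = PatchedPoint(1, 2)
print(p in {p})  # True
```

Note that PatchedPoint(1, 2) == PatchedPoint(1, 2) holds while the two identity-based hashes will generally differ, which breaks the usual rule that equal objects hash equally; this is why the approach is only pragmatic. Also, the __hash__ has to live in the same class as __eq__: a subclass calling super() would reach the parent's __hash__, which is None.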
This not-so-perfect example sets the stage for a mini-tutorial on traversing a Python source tree, finding classes that define a new __eq__ method but lack a new __hash__ method, and then patching each such class to have one.
This is, thus, mostly an excuse for a brief exploration of the ast module of Python...
The ast module in Python
The ast module allows you to access the Abstract Syntax Tree of a Python program. This can be useful for program analysis and program transformation, among other things. Let's start with the AST representation of a Hello World program:
# Typical hello world example
print('Hello World')
Getting the AST for this file (saved as hello.py) is trivial:
import ast
tree = ast.parse(open('hello.py').read())
Now if you decide to explore the tree object, it might seem daunting at first, but don't let appearances fool you: it is actually quite simple. The type of tree will be ast.Module. If you ask for its fields by doing tree._fields (as you can with all AST nodes), you will see that a Module has a body (i.e. tree.body). The body attribute will have, you guessed it, the body of the file:
print(tree.body)
[<_ast.Expr object at 0x7f5731c59bb0>]
The body of a module is composed of a list of statements. In our case we have a single node: an Expr, i.e. an expression used as a statement (Expression is another type of node, do not confuse them). Notice that the AST representation does not include the comments; these are "lost".
OK, on to our Call: the Expr node stores the actual expression in its value field, and here that expression is a Call. What is in a call? Well, you call a function with arguments, so a call is a function name plus a set of arguments:
print(tree.body[0].value._fields)
('func', 'args', 'keywords')
On Python versions before 3.5 this tuple also included starargs and kwargs. You will notice the func plus fields for the arguments. This is because, as you know, Python has quite a rich way of passing arguments (positional arguments, named arguments, ...). For now let's concentrate on func and args only. func is a Name with an attribute called id (the function name). args is a list with a single element, a string constant (an ast.Constant node on Python 3.8+, ast.Str on older versions):
>>> print(tree.body[0].value.func.id)
print
>>> print(tree.body[0].value.args[0].value)
Hello World
This is actually quite simple.
[Figure: graphical representation of the AST for print('Hello World')]
The best way to find all types of nodes is to check the abstract grammar documentation.
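Besides reading the grammar, you can let the ast module describe a tree for you. The sketch below (assuming Python 3.8+, where string literals are ast.Constant nodes) parses the Hello World program from a string rather than from hello.py, follows the path described above, and uses ast.dump and ast.walk to list what is in the tree.

```python
import ast

source = "# Typical hello world example\nprint('Hello World')\n"
tree = ast.parse(source)

# The module body holds one statement: an Expr wrapping a Call
stmt = tree.body[0]
call = stmt.value
print(type(stmt).__name__, type(call).__name__)  # Expr Call
print(call.func.id)        # print
print(call.args[0].value)  # Hello World

# ast.dump shows the whole tree as text; note that the comment is gone
print(ast.dump(tree))

# ast.walk visits every node, handy for a quick census of node types
print(sorted({type(node).__name__ for node in ast.walk(tree)}))
```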
Second attempt (a function definition)
Let's have a look at the AST for a function definition:
def print_hello(who):
    print('Hello %s' % who)
If you get a tree for this code, you will notice that the body is still composed of a single statement:
print(tree.body)
[<_ast.FunctionDef object at 0x7f5731c59bb0>]
At the top level there is only a single statement; the print function
call belongs to the function definition proper. A function definition
has name, args, body, decorator_list and returns (recent Python versions
add a few more, such as type_comment). name is a plain string,
print_hello: no more indirections here, easy. args cannot be very
simple because it has to accommodate the fairly rich Python syntax for
argument specification, so it is an arguments object whose fields are
args, vararg, kwonlyargs, kw_defaults, kwarg and defaults (plus
posonlyargs since Python 3.8). In our case we just have a simple
positional argument, so we can find it in args.args:
print(tree.body[0].args.args)
[<_ast.arg object at 0x7f5731c59bb0>]
Each element of that list is an arg object, which has an arg field
(tree.body[0].args.args[0].arg starts to sound a bit ludicrous, but such
is life) holding a string (who). Now, the function body can be found in
the body field of the function definition:
print(tree.body[0].body)
[<_ast.Expr object at ...>]
The body holds an Expr wrapping a Call, exactly as discussed above for
the print call.
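To check the function-definition fields just described, here is a short sketch that parses the same function from a string instead of a file; the variable names are only for illustration.

```python
import ast

source = (
    "def print_hello(who):\n"
    "    print('Hello %s' % who)\n"
)
func = ast.parse(source).body[0]

print(type(func).__name__, func.name)     # FunctionDef print_hello
print([a.arg for a in func.args.args])    # ['who']
print(func.decorator_list, func.returns)  # [] None
# The body holds the print(...) call, wrapped in an Expr as before
print(type(func.body[0].value).__name__)  # Call
```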
To finalize this version, I just want to present the sub-tree inside
print('Hello %s' % who). Notice the BinOp node for the % operator:
[Figure: AST sub-tree of print('Hello %s' % who), with a BinOp node for the % operator]
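The same sub-tree can be inspected in code. A minimal sketch, assuming Python 3.8+ (where the string operand is an ast.Constant with a value field):

```python
import ast

call = ast.parse("print('Hello %s' % who)").body[0].value
binop = call.args[0]

# The single argument is a BinOp: left operand, operator node, right operand
print(type(binop).__name__)     # BinOp
print(binop.left.value)         # Hello %s
print(type(binop.op).__name__)  # Mod
print(binop.right.id)           # who
```

The operator is itself a (field-less) node, ast.Mod, rather than a string, which makes it easy to test with isinstance.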
OK, before we go to the final version, one very last thing:
# Let's look at two extra properties of the print line...
print(tree.body[0].body[0].lineno, tree.body[0].body[0].col_offset)
2 4
Yes, you can retrieve the line number (counted from 1) and the column offset (counted from 0) of a statement...
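Since Python 3.8, nodes also record where they end (end_lineno and end_col_offset), and ast.get_source_segment uses these positions to recover the exact source text of a node. A small sketch on the same function, assuming Python 3.8+:

```python
import ast

source = (
    "def print_hello(who):\n"
    "    print('Hello %s' % who)\n"
)
stmt = ast.parse(source).body[0].body[0]

# lineno is 1-based; col_offset is 0-based (counted in UTF-8 bytes)
print(stmt.lineno, stmt.col_offset)          # 2 4
# Python 3.8+ also records where the node ends...
print(stmt.end_lineno, stmt.end_col_offset)  # 2 27
# ...which lets ast.get_source_segment slice out the original text
print(ast.get_source_segment(source, stmt))  # print('Hello %s' % who)
```

Positions like these are exactly what a source-patching tool needs, since the AST itself cannot be written back with comments and formatting intact.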
Third attempt (a class definition)
class Hello:
    def __init__(self, who):
        self.who = who

    def print_hello(self):
        print('Hello %s' % self.who)
It should not come as a shock that the module only has a single statement:
print(tree.body)
[<_ast.ClassDef object at ...>]
And yes, as you might expect by now, the ClassDef object has a name
and a body attribute (things start making sense and are consistent).
Indeed most things are as expected: the class body has two statements
(two FunctionDef nodes). There are only a couple of conceptually new
things, and these are visible in the line:
self.who = who
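Here we have two new features: the assignment =, represented by an
Assign node with targets and value fields, and the compound name
self.who, represented by an Attribute node (its value is the Name self
and its attr is the string who).
Notice that targets is a list; this is because you can chain
assignments like:
x = y = 0
Tuple unpacking such as x, y = 1, 2 is different: it produces a single
target, a Tuple node.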
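A small sketch to see the difference between chained assignment and tuple unpacking, and the shape of the Attribute node; describe_targets is a hypothetical helper written only for this illustration.

```python
import ast

def describe_targets(code):
    # Node type of every target of the first statement (assumed an Assign)
    assign = ast.parse(code).body[0]
    return [type(t).__name__ for t in assign.targets]

print(describe_targets("x = y = 0"))       # ['Name', 'Name']
print(describe_targets("x, y = 1, 2"))     # ['Tuple']
print(describe_targets("self.who = who"))  # ['Attribute']

# self.who is an Attribute whose value is the Name 'self'
attr = ast.parse("self.who = who").body[0].targets[0]
print(attr.value.id, attr.attr)            # self who
```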
There is much more that can be said. Processing ASTs requires a
recursive mindset; indeed, if you are not used to thinking recursively, I
suspect that might be your biggest hurdle in doing things with the ast
module. And with Python things can get very nested indeed (nested
functions, nested classes, lambda functions, ...).
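As a small illustration of that recursive mindset, the hypothetical helper below walks a tree by hand with ast.iter_child_nodes and builds dotted names ("Outer.Inner.method") for every class and function, however deeply nested. It is a sketch, not part of the solution that follows.

```python
import ast

def qualified_names(node, prefix=''):
    """Recursively list dotted names of every class and function under node."""
    names = []
    for child in ast.iter_child_nodes(node):
        if isinstance(child, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            name = prefix + child.name
            names.append(name)
            # Recurse into the definition with an extended prefix
            names.extend(qualified_names(child, name + '.'))
        else:
            # Other nodes (if blocks, loops, ...) may still contain definitions
            names.extend(qualified_names(child, prefix))
    return names

source = '''
class Outer:
    class Inner:
        def __eq__(self, other):
            return True
    def method(self):
        def helper():
            pass
'''
print(qualified_names(ast.parse(source)))
# ['Outer', 'Outer.Inner', 'Outer.Inner.__eq__', 'Outer.method', 'Outer.method.helper']
```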
A complete solution
This code was run a few years ago. While it is for Python 2, the general thought process is the same with Python 3.
OK, let's switch gears completely and apply this to a concrete case. The
Abjad project helps composers build up complex pieces of music notation
in an iterative and incremental way. At the time it was Python 2 only,
and I was volunteering some of my time to help it support Python 3. It
is a highly object-oriented, very well documented project with a massive
load of test cases (kudos to the main developers for this). Some of the
classes do have __eq__ methods defined but lack __hash__ methods, so a
pragmatic (though not totally rigorous) solution is to add the __hash__
methods required by Python 3.
A caveat here: the processing has to be done in Python 2, so the code
below is Python 2. The idea is to generate Python 2 code with hashes
that can be automatically translated by 2to3 (no need for monkey
patching after 2to3).
First thing, we have to traverse all the files:
import os

def traverse_dir(my_dir):
    # Recursively visit every entry, handing each .py file to process_file
    for element in os.listdir(my_dir):
        path = os.path.join(my_dir, element)
        if os.path.isdir(path):
            traverse_dir(path)
        elif os.path.isfile(path) and element.endswith('.py'):
            process_file(path)
Nothing special here, just a plain traversal of a directory structure. Now
we want to find all classes that have an __eq__ method but not a
__hash__ method:
import ast

def get_classes(tree):
    # Will not work for nested classes: only tree.body is inspected
    my_classes = []
    for expr in tree.body:
        if isinstance(expr, ast.ClassDef):
            my_classes.append(expr)
    return my_classes
The function above will return all classes from a list of statements
(typically a module body).
def get_class_methods(tree):
    # Function definitions found directly in the class body
    my_methods = []
    for expr in tree.body:
        if isinstance(expr, ast.FunctionDef):
            my_methods.append(expr)
    return my_methods
The function above will return all function definitions from a
ClassDef
object.
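To see the two helpers at work without touching any file, they can be applied to a source string. In this sketch the helpers are repeated (in a compact, isinstance-based form) so that the block runs on its own, and classes_needing_hash is a hypothetical name for the same check that process_file performs below.

```python
import ast

def get_classes(tree):
    # Top-level classes only, as in the article
    return [expr for expr in tree.body if isinstance(expr, ast.ClassDef)]

def get_class_methods(tree):
    return [expr for expr in tree.body if isinstance(expr, ast.FunctionDef)]

def classes_needing_hash(source):
    """Names of top-level classes that define __eq__ but not __hash__."""
    result = []
    for my_class in get_classes(ast.parse(source)):
        names = [method.name for method in get_class_methods(my_class)]
        if '__eq__' in names and '__hash__' not in names:
            result.append(my_class.name)
    return result

source = '''
class A:
    def __eq__(self, other):
        return True

class B:
    def __eq__(self, other):
        return True
    def __hash__(self):
        return 0

class C:
    pass
'''
print(classes_needing_hash(source))  # ['A']
```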
import ast
import shutil

def process_file(my_file):
    # Keep a backup: patch() reads from it and rewrites my_file
    shutil.copyfile(my_file, my_file + '.bak')
    with open(my_file) as f:
        tree = ast.parse(f.read())
    my_classes = get_classes(tree)
    patches = {}
    for my_class in my_classes:
        methods = get_class_methods(my_class)
        has_eq = '__eq__' in [method.name for method in methods]
        has_hash = '__hash__' in [method.name for method in methods]
        if has_eq and not has_hash:
            # Line before which the new __hash__ will be inserted
            lineno = compute_patch(methods)
            patches[lineno] = my_class.name
    patch(my_file, patches)
This is the main loop applied to each file: we get all the available
classes; for each class we get all available methods, and if there is an
__eq__ method with no __hash__ method then a patch is computed. All the
patches for the file are then applied together. The first thing that we
need is to know where to patch the code:
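def compute_patch(methods):
    # Decide before which method's line the new __hash__ will be inserted
    names = [method.name for method in methods]
    names.append('__hash__')
    try:
        names.remove('__init__')  # __init__ is left out of the ordering
    except ValueError:
        pass
    names.sort()
    try:
        method_after = names[names.index('__hash__') + 1]
    except IndexError:
        # __hash__ sorts last: fall back to the last existing method
        method_after = names[-2]
    for method in methods:
        if method.name == method_after:
            # On Python 3.8+ lineno is the 'def' line; start at the first
            # decorator instead so the patch never lands between them
            return min([method.lineno] + [d.lineno for d in method.decorator_list])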
The main point here is to decide to which line to apply the patch. It
has of course to be inside the class, but we want to be slightly more
elegant than that: we want to put the method in lexicographical order
with all the others, leaving __init__ out of the ordering (so __hash__
is inserted just before the method whose name follows __hash__
alphabetically, e.g. __repr__). Now we can patch:
def patch(my_file, patches):
    # Rewrite my_file from its .bak copy, inserting a __hash__ method before
    # every line number found in patches (assumes four-space indentation)
    with open(my_file + '.bak') as f, open(my_file, 'w') as w:
        for lineno, l in enumerate(f, 1):
            if lineno in patches:
                w.write("""    def __hash__(self):
        r'''Hashes my class.
        Required to be explicitly re-defined on Python 3 if __eq__ changes.
        Returns integer.
        '''
        return super(%s, self).__hash__()
""" % patches[lineno])
            w.write(l)
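The whole pipeline can also be exercised without touching the file system. The sketch below is a hypothetical, string-in/string-out Python 3.8+ variant (add_missing_hashes, first_line and HASH_TEMPLATE are not part of the original code). It follows the same ordering idea as compute_patch, but when no method name follows __hash__ it appends the method after the class's last line (using end_lineno) instead of falling back to the last method. Like the original, it assumes four-space indentation.

```python
import ast

# Text inserted for each patched class (assumes four-space indentation)
HASH_TEMPLATE = (
    "    def __hash__(self):\n"
    "        return super(%s, self).__hash__()\n"
    "\n"
)


def first_line(method):
    # Python 3.8+: lineno is the 'def' line, so step back to the first decorator
    return min([method.lineno] + [d.lineno for d in method.decorator_list])


def add_missing_hashes(source):
    """Return source with a __hash__ added to top-level classes defining only __eq__."""
    lines = source.splitlines(True)
    if lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'
    inserts = {}  # 0-based line index -> text to insert before that line
    for node in ast.parse(source).body:
        if not isinstance(node, ast.ClassDef):
            continue
        methods = [m for m in node.body if isinstance(m, ast.FunctionDef)]
        names = {m.name for m in methods}
        if '__eq__' not in names or '__hash__' in names:
            continue
        # Same ordering idea as compute_patch: skip __init__, compare by name
        later = [m for m in methods if m.name != '__init__' and m.name > '__hash__']
        if later:
            target = min(later, key=lambda m: m.name)
            index = first_line(target) - 1  # insert before that method
        else:
            index = node.end_lineno  # nothing later: append after the class
        inserts[index] = HASH_TEMPLATE % node.name
    out = []
    for i, line in enumerate(lines):
        out.append(inserts.get(i, ''))
        out.append(line)
    out.append(inserts.get(len(lines), ''))
    return ''.join(out)


before = '''class Note:
    def __init__(self, pitch):
        self.pitch = pitch

    def __eq__(self, other):
        return self.pitch == other.pitch

    @property
    def name(self):
        return str(self.pitch)
'''

after = add_missing_hashes(before)
print(after)
namespace = {}
exec(after, namespace)
note = namespace['Note'](60)
print(note in {note})  # True
```

Here __hash__ lands just above @property rather than between the decorator and def name, which is what a plain lineno would give on Python 3.8+. Running the function again on its own output changes nothing, since the class now defines __hash__.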
Final notes
This could be done in many other ways. Indeed this is not even the standard way (that would be coding a 2to3 fixer). The practical solution is not general (for instance, it does not support nested classes). Also,
it has problems with comments, as we lose them in the AST: for example, a comment sitting just above the method chosen as insertion point ends up above the new __hash__ instead. But, for the
practical purpose (patching Abjad) it was good enough (you can see the result here).
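For the nested-class limitation, one possible starting point (a sketch, not what was used for Abjad) is to collect ClassDef nodes with ast.walk, which visits every node in the tree, instead of looping over tree.body only. get_all_classes is a hypothetical name; a patcher built on it would also have to indent the inserted method according to each class's nesting level, for instance from the methods' col_offset.

```python
import ast

def get_all_classes(tree):
    # Unlike get_classes, ast.walk reaches classes at any depth:
    # inside other classes, functions, if blocks, ...
    return [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]

source = '''
class Outer:
    class Inner:
        def __eq__(self, other):
            return True

def factory():
    class Local:
        pass
    return Local
'''
tree = ast.parse(source)
print(sorted(c.name for c in get_all_classes(tree)))  # ['Inner', 'Local', 'Outer']
```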
For further reading, you might want to have a look at this Stack Overflow question.