Automatic code changing in Python with the ast module

tiagoantao

Tiago Antao

Posted on December 5, 2021


In this article we are going to explore Python Abstract Syntax Trees. Our example will use the Python ast module to help convert legacy Python 2 code to Python 3.

One of the changes that Python 3 brings versus Python 2 is that, if you redefine the __eq__ method of a class without also defining __hash__, its __hash__ method automatically becomes None. This means that instances of your class are rendered unhashable (and thus unusable in many settings, such as dictionary keys or set members). This makes sense, as the hash function is closely related to object equality. But if you have Python 2 code that is nonetheless working and you are planning to move to Python 3, then automatically adding a __hash__ method to recover object hashability might be a pragmatically acceptable solution.
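The behaviour is easy to check on Python 3. In this minimal sketch, Point is a made-up class that defines __eq__ but not __hash__:

```python
class Point(object):
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return (self.x, self.y) == (other.x, other.y)


# Defining __eq__ without __hash__ makes Python 3 set __hash__ to None
print(Point.__hash__)  # None

try:
    hash(Point(1, 2))
except TypeError as error:
    # Instances can no longer be hashed (so no dict keys, no set members)
    print(error)
```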

This not-so-perfect example sets the stage for a mini-tutorial on traversing a Python source tree, finding classes that define __eq__ but lack __hash__, and then patching each such class to have one.
This is, thus, mostly an excuse for a brief exploration of the ast module of Python...

The ast module in Python

The ast module allows you to access the Abstract Syntax Tree of a Python program. This can be useful for program analysis and program transformation, among other things. Let's start with the AST representation of a Hello World program:

# Typical hello world example
print('Hello World')  

Getting the AST tree for this is trivial:

import ast
with open('hello.py') as f:
    tree = ast.parse(f.read())

Now, if you decide to explore the tree object, it might seem daunting at first, but don't let appearances fool you: it is actually quite simple. The type of tree will be ast.Module. If you ask for its fields with tree._fields (as you can with any AST node), you will see that a Module has a body (i.e. tree.body). The body attribute holds, you guessed it, the body of the file:

print(tree.body)
[<_ast.Expr object at 0x7f5731c59bb0>]

The body of a module is a list of statements. In our case we have a single node: an Expr (Expression is another type of node; do not confuse them). An Expr is a statement that wraps an expression in its value field, and here that expression is a Call. Notice that the AST representation does not include comments: these are "lost".
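A quick way to confirm both points, sketched here for Python 3.9+ (where ast.unparse is available): the single statement is an Expr whose value is a Call, and turning the tree back into source drops the comment.

```python
import ast

source = "# Typical hello world example\nprint('Hello World')\n"
tree = ast.parse(source)

statement = tree.body[0]
print(type(statement).__name__)        # Expr
print(type(statement.value).__name__)  # Call

# ast.unparse (Python 3.9+) rebuilds source code from the tree: the comment is gone
regenerated = ast.unparse(tree)
print(regenerated)  # print('Hello World')
```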

OK, back to our Call. What is in a call? Well, you call a function with arguments so a call is a function name plus a set of arguments:

print(tree.body[0].value._fields)
('func', 'args', 'keywords')

You will notice func plus a couple of fields for arguments. This is because, as you know, Python has quite a rich way of passing arguments (positional arguments, named arguments, ...); on Python versions before 3.5, Call also had separate starargs and kwargs fields. For now let's concentrate on func and args only. func is a Name node with an attribute called id (the function name). args is a list with a single element, a string constant (on Python 3.8+ a Constant node, whose value attribute holds the string; older versions used a Str node with an s attribute):

>>> print(tree.body[0].value.func.id)
print
>>> print(tree.body[0].value.args[0].value)
Hello World

This is actually quite simple; here is a graphical representation:

The AST for print('Hello World')

The best way to find all types of nodes is to check the abstract grammar documentation.
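Another practical way to discover node types is to let ast.dump print a whole tree. A sketch (the indent argument needs Python 3.9+, and the exact text varies a little between versions, for example in whether empty fields are shown):

```python
import ast

tree = ast.parse("print('Hello World')")

# One-line dump, handy for quick checks
flat = ast.dump(tree)

# Indented dump (Python 3.9+), much easier to read for larger trees
print(ast.dump(tree, indent=4))
```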

Second attempt (a function definition)

Let's have a look at the AST for a function definition:

def print_hello(who):
    print('Hello %s' % who)  

If you get a tree for this code, you will notice that the body is still composed of a single statement:

print(tree.body)
[<_ast.FunctionDef object at 0x7f5731c59bb0>]

At the top level there is only a single statement: the print function call belongs to the function definition proper. A function definition has, among others, the fields name, args, body, decorator_list and returns. name is the string 'print_hello'; no more indirection here, easy. args cannot be very simple because it has to accommodate the fairly rich Python syntax for argument specification, so its fields are args, vararg, kwonlyargs, kw_defaults, kwarg and defaults (plus posonlyargs on Python 3.8+). In our case we just have a simple positional argument, so we can find it in args.args:

print(tree.body[0].args.args)
[<_ast.arg object at 0x7f5731c59bb0>]

Each element of args.args is an arg node with an arg field (tree.body[0].args.args[0].arg starts to sound a bit ludicrous, but such is life), which is the string 'who'. Now, the function body can be found in the body field of the function definition:

print(tree.body[0].body)
[<_ast.Expr object at 0x7f5731c59bb0>]

The body is a single Expr wrapping the print Call, just as in the Hello World example.
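To see how the other fields of args are used, here is a sketch with a richer, made-up signature f that covers the main argument kinds (keyword-only arguments require Python 3):

```python
import ast

source = "def f(a, b=1, *args, c, d=2, **kwargs):\n    pass\n"
func = ast.parse(source).body[0]
spec = func.args  # an ast.arguments node

print(func.name)                             # f
print([arg.arg for arg in spec.args])        # ['a', 'b']
print(spec.vararg.arg)                       # args
print([arg.arg for arg in spec.kwonlyargs])  # ['c', 'd']
print(spec.kwarg.arg)                        # kwargs
# defaults covers trailing positional arguments; kw_defaults has one entry
# per keyword-only argument, with None where there is no default
print(len(spec.defaults), len(spec.kw_defaults))  # 1 2
```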

To finalize this version, I just want to present the sub-tree for print('Hello %s' % who). Notice the BinOp node for the % operator:

Ops on an AST tree
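The same sub-tree can be inspected directly. A sketch assuming Python 3.8+, where the string literal is a Constant node:

```python
import ast

source = "def print_hello(who):\n    print('Hello %s' % who)\n"
call = ast.parse(source).body[0].body[0].value  # the print(...) Call
binop = call.args[0]                            # 'Hello %s' % who

print(type(binop).__name__)     # BinOp
print(type(binop.op).__name__)  # Mod (the % operator)
print(binop.left.value)         # Hello %s
print(binop.right.id)           # who
```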

OK, before we go to the final version, one very last thing:

# Let's look at two extra properties of the print line...
print(tree.body[0].body[0].lineno, tree.body[0].body[0].col_offset)
2 4

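Yes, you can retrieve the line number (starting at 1) and the column offset (starting at 0) of a statement. The line number is what we will rely on later to decide where to patch the source.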
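Combined with ast.walk, which visits every node of the tree, this gives a quick map of where each statement starts. A minimal sketch; since ast.walk does not promise any particular order, the results are sorted by position:

```python
import ast

source = (
    "def print_hello(who):\n"
    "    print('Hello %s' % who)\n"
)
tree = ast.parse(source)

# Keep only statements (subclasses of ast.stmt) and record where they start
positions = [
    (type(node).__name__, node.lineno, node.col_offset)
    for node in ast.walk(tree)
    if isinstance(node, ast.stmt)
]
positions.sort(key=lambda item: item[1:])  # order by (lineno, col_offset)
print(positions)  # [('FunctionDef', 1, 0), ('Expr', 2, 4)]
```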

Third attempt (a class definition)

class Hello:
    def __init__(self, who):
        self.who = who

    def print_hello(self):
        print('Hello %s' % self.who)  

It should not come as a shock that the module only has a single statement:

print(tree.body)
[<_ast.ClassDef object at 0x7f5731c59bb0>]

And yes, as you might expect by now, the ClassDef object has a name and a body attribute (things start making sense and are consistent). Indeed most things are as expected: the class body has two statements (two FunctionDef nodes). There are only a couple of conceptually new things, and these are visible in the line:

self.who = who

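Here we have two new features: the assignment, represented by an Assign node with targets and value fields, and the compound name self.who, which is an Attribute node (its value is the Name self and its attr is the string 'who').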

The AST for self.who = who
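The Assign and Attribute nodes can be checked the same way as before; a minimal sketch:

```python
import ast

assign = ast.parse("self.who = who").body[0]
target = assign.targets[0]

print(type(assign).__name__)         # Assign
print(type(target).__name__)         # Attribute
print(target.value.id, target.attr)  # self who
print(assign.value.id)               # who (the right-hand side is a plain Name)
```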

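Notice that targets is a list. This is because assignments can be chained:

x = y = 0

Here the Assign node has two targets, x and y. Tuple unpacking, as in x, y = 1, 2, is a different case: it has a single target, a Tuple node holding x and y.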
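A quick check of how targets is populated for a chained assignment versus tuple unpacking:

```python
import ast

chained = ast.parse("x = y = 0").body[0]
unpacking = ast.parse("x, y = 1, 2").body[0]

# Chained assignment: one Assign node with two targets
print([target.id for target in chained.targets])  # ['x', 'y']

# Tuple unpacking: a single target, which is a Tuple node
print(len(unpacking.targets))                         # 1
print(type(unpacking.targets[0]).__name__)            # Tuple
print([elt.id for elt in unpacking.targets[0].elts])  # ['x', 'y']
```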

There is much more that can be said. Processing ASTs requires a recursive mindset; indeed, if you are not used to thinking recursively, I suspect that might be your biggest hurdle in doing things with the ast module. And with Python things can get very nested indeed (nested functions, nested classes, lambda functions, ...).
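As an illustration of that recursive mindset, here is a small sketch built on ast.NodeVisitor (NestingVisitor is a made-up name). generic_visit does the recursion into child nodes, and the visitor tracks how deeply each class or function is nested:

```python
import ast


class NestingVisitor(ast.NodeVisitor):
    """Record (depth, kind, name) for every class and function, however nested."""

    def __init__(self):
        self.depth = 0
        self.found = []

    def _visit_scope(self, node):
        self.found.append((self.depth, type(node).__name__, node.name))
        self.depth += 1
        self.generic_visit(node)  # recurse into the children of this node
        self.depth -= 1

    # Classes and (non-async) functions get the same treatment
    visit_ClassDef = _visit_scope
    visit_FunctionDef = _visit_scope


source = '''
class Outer:
    class Inner:
        def method(self):
            def helper():
                pass
'''
visitor = NestingVisitor()
visitor.visit(ast.parse(source))
for depth, kind, name in visitor.found:
    print('    ' * depth + kind, name)
```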

A complete solution

This code was run a few years ago. While it is for Python 2, the general thought process is the same with Python 3.

OK, let's switch gears completely and apply this to a concrete case. The Abjad project helps composers build up complex pieces of music notation in an iterative and incremental way. It is currently Python 2 only, and I am volunteering some of my time to help it support Python 3. It is a highly object-oriented, very well documented code base with a massive load of test cases (kudos to the main developers for this). Some of the classes do have __eq__ methods defined but lack __hash__ methods, so a pragmatic (though not totally rigorous) solution is to add the __hash__ methods required by Python 3.

A caveat here: the processing has to be done in Python 2, so the code below was written for Python 2. The idea is to generate Python 2 code with __hash__ methods that can be automatically translated by 2to3 (no need for monkey patching after 2to3).

First thing, we have to traverse all the files:

import os

def traverse_dir(my_dir):
    # Recursively walk my_dir, calling process_file on every .py file
    for element in os.listdir(my_dir):
        path = os.path.join(my_dir, element)
        if os.path.isdir(path):
            traverse_dir(path)
        elif os.path.isfile(path) and element.endswith('.py'):
            process_file(path)

Nothing special here, just a plain traversal of a directory structure. Now we want to find all classes that have an __eq__ method but not a __hash__ method. First, we collect the classes:

import ast

def get_classes(tree):
    # Only looks at top-level statements: will not work for nested classes
    my_classes = []
    for expr in tree.body:
        if isinstance(expr, ast.ClassDef):
            my_classes.append(expr)
    return my_classes

The function above returns all classes defined directly in the body of tree (typically a module).
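If nested classes had to be found too, ast.walk would do the recursion for us. A sketch of a hypothetical variant (get_all_classes is a made-up name, not part of the original script):

```python
import ast


def get_all_classes(tree):
    """Hypothetical variant of get_classes that also finds nested classes."""
    # ast.walk yields every node in the tree, at any depth
    return [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]


source = '''
class Outer:
    class Inner:
        pass

def factory():
    class Local:
        pass
    return Local
'''
print(sorted(cls.name for cls in get_all_classes(ast.parse(source))))
# ['Inner', 'Local', 'Outer']
```

Finding the classes is only half of the work, though: the patching step below writes __hash__ with a fixed four-space indentation, so a nested class would also need its own indentation (its col_offset is a natural starting point).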

def get_class_methods(tree):
    # Collects the plain "def" methods found directly in the class body
    my_methods = []
    for expr in tree.body:
        if isinstance(expr, ast.FunctionDef):
            my_methods.append(expr)
    return my_methods

The function above will return all function definitions from a
ClassDef object.

import ast
import shutil

def process_file(my_file):
    # Keep a backup: patch() reads the original from it and rewrites my_file
    shutil.copyfile(my_file, my_file + '.bak')
    with open(my_file) as f:
        tree = ast.parse(f.read())
    patches = {}  # line number -> name of the class being patched
    for my_class in get_classes(tree):
        methods = get_class_methods(my_class)
        names = [method.name for method in methods]
        has_eq = '__eq__' in names
        has_hash = '__hash__' in names
        if has_eq and not has_hash:
            lineno = compute_patch(methods)
            patches[lineno] = my_class.name
    patch(my_file, patches)

This is the main loop applied to each file: we get all the available classes; for each class we get all available methods, and if there is an __eq__ method with no __hash__ method, a patch is computed. All the patches for the file are then applied at once. The first thing that we need is to know where to patch the code:

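def compute_patch(methods):
    # Find the method that __hash__ should be inserted before, keeping method
    # names in sorted order (__init__ is left out of the ordering)
    names = [method.name for method in methods]
    names.append('__hash__')
    if '__init__' in names:
        names.remove('__init__')
    names.sort()
    try:
        method_after = names[names.index('__hash__') + 1]
    except IndexError:
        # __hash__ sorts last: insert it before the last method instead
        method_after = names[-2]
    for method in methods:
        if method.name == method_after:
            # Start at the first decorator, if any (on Python 3.8+ lineno is
            # the "def" line, which would split a decorator from its method)
            return min([method.lineno] + [d.lineno for d in method.decorator_list])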

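The main point here is to decide at which line to apply the patch. It has of course to be inside the class, but we want to be slightly more elegant than that: we want to put the method in lexicographical order with the others. __init__ is excluded from the ordering, so in a class with __init__, __eq__ and __repr__, __hash__ would go between __eq__ and __repr__. If __hash__ sorts last, it is simply inserted before the last method. Now we can patch: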

def patch(my_file, patches):
    # Rebuild my_file from its .bak copy, writing a __hash__ method (indented
    # for a top-level class) right before each line number found in patches
    with open(my_file + '.bak') as f, open(my_file, 'w') as w:
        for lineno, line in enumerate(f, start=1):
            if lineno in patches:
                w.write("""    def __hash__(self):
        r'''Hashes my class.

        Required to be explicitly re-defined on Python 3 if __eq__ changes.

        Returns integer.
        '''
        return super(%s, self).__hash__()

""" % patches[lineno])
            w.write(line)
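To try the whole idea without touching any files, here is a condensed, hypothetical in-memory variant (add_hash_methods and HASH_TEMPLATE are made-up names). It is a simplification of the code above, not the code used on Abjad: it only handles top-level classes, skips the docstring and always inserts __hash__ right before __eq__ instead of in sorted order.

```python
import ast
import io

# Hypothetical, simplified template: no docstring, 4-space (top-level) indentation
HASH_TEMPLATE = (
    "    def __hash__(self):\n"
    "        return super(%s, self).__hash__()\n"
    "\n"
)


def add_hash_methods(source):
    """Return source with a __hash__ method added to every top-level class
    that defines __eq__ but not __hash__ (inserted right before __eq__)."""
    tree = ast.parse(source)
    patches = {}  # 1-based line number -> class name
    for node in tree.body:
        if not isinstance(node, ast.ClassDef):
            continue
        methods = {n.name: n for n in node.body if isinstance(n, ast.FunctionDef)}
        if '__eq__' in methods and '__hash__' not in methods:
            eq = methods['__eq__']
            # Insert above any decorators (on Python 3.8+ lineno is the def line)
            first = min([eq.lineno] + [d.lineno for d in eq.decorator_list])
            patches[first] = node.name
    out = []
    # StringIO splits on "\n" only, matching how ast counts lines
    for lineno, line in enumerate(io.StringIO(source), start=1):
        if lineno in patches:
            out.append(HASH_TEMPLATE % patches[lineno])
        out.append(line)
    return ''.join(out)


legacy = '''class Point(object):
    def __init__(self, x):
        self.x = x

    def __eq__(self, other):
        return self.x == other.x
'''
patched = add_hash_methods(legacy)
print(patched)

namespace = {}
exec(patched, namespace)  # run the patched code to check the result
Point = namespace['Point']
print(Point.__hash__ is not None)  # True: instances are hashable again
```

Note that super(...).__hash__() falls back to object's identity-based hash, so two equal Point instances will usually hash differently; that is the "not totally rigorous" part of this pragmatic solution.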

Final notes

This could be done in many other ways. Indeed, this is not even the standard way (that would be coding a 2to3 fixer). The practical solution is not general (for instance, it does not support nested classes). Also, it has problems with comments: since they are not in the AST, a comment sitting right above a method can end up separated from it by the inserted __hash__. But for the practical purpose (patching Abjad) it was good enough (you can see the result here).

For further reading, you might want to have a look at this Stack Overflow question.
