dhilst

Simple Python typechecker in Python

Here is how to bootstrap a simple Python typechecker in Python.

The idea is to use Python's ast module to parse some code into an AST that carries the type annotations, and then traverse that AST checking the types.

In this example I will typecheck function calls only. For this to work, the functions being called need to be annotated.

Here is the code that we will typecheck:

def inc(a: int, b: int) -> int:
    return a + 1

def foo(a: float) -> float:
    return inc(a, "a") # type error here

The first thing we need to do is parse the code into an AST and then traverse it. Parsing is easy: just call ast.parse and you’re done. To traverse the tree we extend ast.NodeVisitor, implementing visit methods for the node types we are interested in. In this case I will visit function definitions (to gather typing information) and function calls (to actually do the typechecking).
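To get a feel for what the parser gives us, here is a small sketch that parses an annotated function and reads the annotations back out of the AST. The source string and variable names are only for illustration, and it assumes every annotation is a plain name such as int (a more complex annotation like list[int] is not an ast.Name and has no id attribute).

```python
import ast

source = """
def inc(a: int, b: int) -> int:
    return a + 1
"""

tree = ast.parse(source)
# The module body holds a single statement: the function definition
func = tree.body[0]

# Each parameter is an ast.arg node; its annotation is an expression node.
# For a simple annotation such as `int` that expression is an ast.Name.
param_types = {arg.arg: arg.annotation.id for arg in func.args.args}
return_type = func.returns.id

print(func.name, param_types, return_type)
# inc {'a': 'int', 'b': 'int'} int
```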

Here is a skeleton:

import ast
class Typechecker(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self.typeenv = {}

    def visit_FunctionDef(self, node):
        self.generic_visit(node)

    def visit_Call(self, node):
        self.generic_visit(node)

def typecheck(text):
    tree = ast.parse(text)
    Typechecker().visit(tree)

In our Typechecker class we initialize a typeenv member that holds the type information against which we will typecheck. The generic_visit method recurses into the child nodes of the AST.
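The role of generic_visit is easy to miss, so here is a small sketch of what happens with and without it. CallLister and its recurse flag are made up for this demonstration and are not part of the typechecker.

```python
import ast

class CallLister(ast.NodeVisitor):
    """Collects the names of called functions (illustrative helper)."""

    def __init__(self, recurse):
        super().__init__()
        self.recurse = recurse
        self.calls = []

    def visit_FunctionDef(self, node):
        # Without this call the visitor never looks inside the function body
        if self.recurse:
            self.generic_visit(node)

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            self.calls.append(node.func.id)
        self.generic_visit(node)

tree = ast.parse("def foo(a):\n    return inc(a, 1)\n")

with_recursion = CallLister(recurse=True)
with_recursion.visit(tree)

without_recursion = CallLister(recurse=False)
without_recursion.visit(tree)

print(with_recursion.calls)     # ['inc']
print(without_recursion.calls)  # []
```

Once you define a visit_X method, recursing into the children of X nodes becomes your responsibility, which is why every visit method in this post calls generic_visit.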

When visiting a FunctionDef we want to add the parameters to the type environment, then visit the child nodes (particularly the body of the function, to do the typechecking), and then remove the parameters from the type environment. We do this by saving the type environment, extending it with the parameters, and restoring it after visiting the body.

import ast
from dataclasses import dataclass

@dataclass(frozen=True)
class FuncSig:
    name : str
    args : list[str]
    ret: str

    def __repr__(self):
        args = " -> ".join(self.args + [self.ret]) 
        return f"{self.name} : {args}"

class Typechecker(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        # this is our type environment
        self.typeenv = {}

    def visit_FunctionDef(self, node):
        # save the type environment
        oldenv = self.typeenv.copy()
        # update the type environment with the types of the arguments
        self.typeenv.update({arg.arg: arg.annotation.id for arg in node.args.args})
        # visit the body
        self.generic_visit(node)
        # take the signature of this function being defined
        signature = FuncSig(node.name, [arg.annotation.id for arg in node.args.args], 
                            node.returns.id)
        # restore the old type environment, without the type of the arguments
        self.typeenv = oldenv
        # extend it with the now defined function
        self.typeenv[node.name] = signature

Here FuncSig is just a dataclass that I use to hold the function signature, with some pretty printing. The important thing to note is that we put the arguments in the type environment, recurse into the child nodes of the AST, and then remove the arguments from the type environment. This is because the scope of the arguments is the function body, so we need to remove them once we leave the function. Note also that the function's own signature is added only after its body has been visited, so recursive calls are not checked.
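To make the scoping concrete, here is the same save, extend and restore dance done by hand on a plain dict, without any AST involved. The inside_body variable is only there to snapshot the environment for printing; list[str] as an annotation assumes Python 3.9 or newer.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FuncSig:
    name: str
    args: list[str]
    ret: str

    def __repr__(self):
        args = " -> ".join(self.args + [self.ret])
        return f"{self.name} : {args}"

typeenv = {}

# Entering `inc`: save the environment, then bring the parameters into scope
oldenv = typeenv.copy()
typeenv.update({"a": "int", "b": "int"})
inside_body = dict(typeenv)

# Leaving `inc`: drop the parameters, keep only the function's signature
typeenv = oldenv
typeenv["inc"] = FuncSig("inc", ["int", "int"], "int")

print(inside_body)  # {'a': 'int', 'b': 'int'}
print(typeenv)      # {'inc': inc : int -> int -> int}
```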

Next we visit the Call. To keep things simple we check only calls to named functions for which we have type information, so calls to lambdas, for example, are not checked. The first thing we do is check whether the function name is in the type environment. If it is, we gather the types of the call arguments and compare them to the expected argument types (from the type environment). If they differ, we raise a TypeError.

    def visit_Call(self, node):
        if type(node.func) is ast.Name and node.func.id in self.typeenv:
            actual_args = []
            for arg in node.args:
                if type(arg) is ast.Constant:
                    actual_args.append(type(arg.value).__name__)
                elif type(arg) is ast.Name and arg.id in self.typeenv:
                    actual_args.append(self.typeenv[arg.id])
                else:
                    # Cannot typecheck this call (unknown variable or an
                    # unsupported kind of argument), but still visit children
                    return self.generic_visit(node)
            expected_args = self.typeenv[node.func.id].args
            # dumb typechecking 
            if actual_args != expected_args:
                raise TypeError(f"Type error in call for {node.func.id}, "
                                f"expected : {expected_args}, found : {actual_args}")

        self.generic_visit(node)

Here we consider two kinds of call arguments. For a constant, we take the type of its value. For a variable, we look its name up in the type environment; if no type is found there, we silently give up. If we could gather all the type information we need and the types differ, we raise a TypeError, and that’s it!

Putting it all together, here is the full code: https://gist.github.com/dhilst/24bfd7904ccefb542abf7fa099e7e516
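If you want to try it without leaving the page, here is a condensed sketch of the same checker, run against the example from the top of the post. It is my restatement rather than a copy of the gist: the check helper and the good/bad sample strings are made up for the demo, and this version gives up on any argument that is neither a constant nor a known variable.

```python
import ast
from dataclasses import dataclass

@dataclass(frozen=True)
class FuncSig:
    name: str
    args: list[str]
    ret: str

class Typechecker(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self.typeenv = {}

    def visit_FunctionDef(self, node):
        oldenv = self.typeenv.copy()
        params = [arg.annotation.id for arg in node.args.args]
        # parameters are in scope only while we visit the body
        self.typeenv.update(
            {arg.arg: arg.annotation.id for arg in node.args.args})
        self.generic_visit(node)
        self.typeenv = oldenv
        self.typeenv[node.name] = FuncSig(node.name, params, node.returns.id)

    def visit_Call(self, node):
        sig = None
        if isinstance(node.func, ast.Name):
            sig = self.typeenv.get(node.func.id)
        if isinstance(sig, FuncSig):
            actual_args = []
            for arg in node.args:
                if isinstance(arg, ast.Constant):
                    actual_args.append(type(arg.value).__name__)
                elif isinstance(arg, ast.Name) and arg.id in self.typeenv:
                    actual_args.append(self.typeenv[arg.id])
                else:
                    # no type information for this argument: skip the check
                    actual_args = None
                    break
            if actual_args is not None and actual_args != sig.args:
                raise TypeError(
                    f"Type error in call for {node.func.id}, "
                    f"expected : {sig.args}, found : {actual_args}")
        self.generic_visit(node)

def check(source):
    """Hypothetical helper: return the error message, or None if all is fine."""
    try:
        Typechecker().visit(ast.parse(source))
    except TypeError as exc:
        return str(exc)
    return None

bad = '''
def inc(a: int, b: int) -> int:
    return a + 1

def foo(a: float) -> float:
    return inc(a, "a")
'''

good = '''
def inc(a: int, b: int) -> int:
    return a + 1

def bar(x: int) -> int:
    return inc(x, 1)
'''

print(check(bad))
print(check(good))
```

The first call should print a message saying that inc expected ['int', 'int'] but found ['float', 'str'], and the second should print None.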

With this in hand you should be able to tweak and extend it to play with typechecking ideas (like typechecking with generics, or type inference) without the need to bootstrap a whole grammar. You can also use this kind of technique to lint your Python code: forbidding some functions that you consider unsafe, for example, or ensuring immutability, linearity, etc.
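As a taste of the linting idea, here is a minimal sketch of a visitor that reports calls to functions you have decided to forbid. The ForbiddenCallLinter class and the FORBIDDEN set are made up for this example, and it only catches direct calls by name (not aliases or attribute access like builtins.eval).

```python
import ast

# Names chosen for illustration; adjust to whatever you consider unsafe
FORBIDDEN = {"eval", "exec"}

class ForbiddenCallLinter(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self.problems = []

    def visit_Call(self, node):
        # Only direct calls such as eval(...) are detected here
        if isinstance(node.func, ast.Name) and node.func.id in FORBIDDEN:
            self.problems.append((node.lineno, node.func.id))
        self.generic_visit(node)

source = """
x = eval("1 + 1")
print(x)
"""

linter = ForbiddenCallLinter()
linter.visit(ast.parse(source))
print(linter.problems)  # [(2, 'eval')]
```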

Cheers