Polymorphic Typechecking in Python by Unification
In the Simple Python typecheker in
Python
I show how to bootstrap a simple typecheker in Python using the ast
module. In this post I continue the saga adding parametric
polymorphism a lá SML.
To attach information to the functions I will use a sig
function. This function does nothing in runtime and receives a
constant string as argument from where we take the typing information.
Here is the code we will be typechecking:
def sig(*args, **kwargs):
return lambda f: f
@sig("inc : float -> A -> A")
def inc(a, b):
return a + 1
@sig("foo : int -> float")
def foo(a):
return float(inc(a, "a"))
The idea is basically the same as in the previous typechecking post. We traverse AST gathering type information and typechecking as we go. This type typechecking is done by unification, this is explained in detail at Wikipedia.
To sum up, unification a function that receives two terms, a
substitution accumulator and return a substitution on success or a
failure. A substitution is a set of mappings in the form {x -> a; y
-> b; ...}
which means replace x
with a
, y
with b
, …. A
term, is usualy a formula involving functions, constants and variables.
For example f(X,a)
is a term where f
is a function, X
is a variable
and a
is constant (constants and variables are terms too). In our
case instead of functions we will have signatures in the form A -> B
-> C
, variables start in upcase and constants start with lowercase,
so A
is a type variable and str
is a type constant.
As I said typecheck goes by unification, you get the arguments you have, the expected the arguments, and instead of comparing the by equality like in the simple typechecker we try to unify them. If the unification succeeds then the typechecking succeeds.
-Talk is cheap, show me the code!
So here is the code, I will explain it next
from typing import *
import ast
from dataclasses import dataclass
class TypeTerm:
@staticmethod
def from_str(input: str):
if "->" in input:
return Arrow.from_str(input)
elif input[0].isupper():
return Var(input)
else:
return Const(input)
@staticmethod
def parse(input):
name, term = (TypeTerm.from_str(x.strip()) for x in input.split(":"))
return name, term
@dataclass(frozen=True)
class Var(TypeTerm):
name: str
@dataclass(frozen=True)
class Arrow(TypeTerm):
args: list[TypeTerm]
@staticmethod
def from_str(input):
return Arrow([x.strip() for x in input.split("->")])
@property
def args_without_return(self):
return Arrow(self.args[:-1])
@dataclass(frozen=True)
class Const(TypeTerm):
name: str
Subst = Set[Tuple[str, TypeTerm]]
T = TypeVar("T")
class Result(Generic[T]):
@overload
def __init__(self, err: str): ...
@overload
def __init__(self, ok: T): ...
def __init__(self, *, ok=None, err=None):
if ok is not None:
self.value = ok
self.is_ok = True
else:
self.err = err
self.is_ok = False
@property
def is_err(self):
return not self.is_ok
class Unify:
@staticmethod
def subst1(term: TypeTerm, subst: Subst) -> TypeTerm:
if type(term) is Var:
for name, replacement in subst:
if name == term.name:
return replacement
else:
return term
elif type(term) is Arrow:
return Arrow([Unify.subst1(arg, subst) for arg in term.args])
elif type(term) is Const:
return term
else:
assert False, "invalid case in subst1"
@staticmethod
def substmult(subst: Subst, replacement: Subst) -> Subst:
return {(name, Unify.subst1(term, replacement)) for (name, term) in subst}
@staticmethod
def unify(t1: TypeTerm, t2: TypeTerm, subst: Subst = set()) -> Result[Subst]:
if t1 == t2:
return Result(ok=subst)
elif type(t1) is Arrow and type(t2) is Arrow:
if len(t1.args) != len(t1.args):
return Result(err=f"unification error {t1} <> {t2}")
else:
for a1, a2 in zip(t1.args, t2.args):
r = Unify.unify(a1, a2, subst)
if r.is_err:
return r
else:
subst |= r.value
return Result(ok=subst)
elif type(t2) is Var and type(t1) is not Var:
return Unify.unify(t2, t1, subst)
elif type(t1) is Var:
newsubst = {(t1.name, t2)}
subst = Unify.substmult(subst, newsubst) | newsubst
return Result(ok=subst)
elif type(t1) is Const and type(t2) is Const:
if t1.name != t2.name:
return Result(err=f"Type error, expected {t1.name}, found {t2.name}")
else:
return Result(ok=subst)
else:
assert False, "invalid case"
class Typechecker(ast.NodeVisitor):
def __init__(self):
super().__init__()
self.typeenv = {}
def visit_FunctionDef(self, node):
# gatther the signature into of the
# function from @sig decorator
for dec in node.decorator_list:
if dec.func.id == "sig":
typ = Arrow([TypeTerm.from_str(x.strip()) for x
in dec.args[0].value.split(":")[1].split(" -> ")])
self.typeenv[node.name] = typ
if not node.name in self.typeenv:
self.generic_visit(node)
return
oldenv = self.typeenv.copy() # save the old environment
self.typeenv.update({ # extend the env with the arguments
arg.arg: typ for arg, typ
in zip(node.args.args, self.typeenv[node.name].args[:-1])})
# visit children
self.generic_visit(node)
# restore the environment
self.typeenv = oldenv
def visit_Call(self, node):
if type(node.func) is ast.Name:
if node.func.id == "sig":
name, typ = TypeTerm.parse(node.args[0].value)
self.typeenv[name] = typ
self.generic_visit(node)
return
elif node.func.id in self.typeenv:
# get the arguments types
actual_args = []
for arg in node.args:
if type(arg) is ast.Constant:
actual_args.append(
TypeTerm.from_str(type(arg.value).__name__))
elif type(arg) is ast.Name:
if arg.id in self.typeenv:
actual_args.append(
self.typeenv[arg.id])
else:
# no type information,
# give up
self.generic_visit(node)
return
if not actual_args:
self.generic_visit(node)
return
actual_args = Arrow(
[arg for arg in actual_args])
expected_args = self.typeenv[node.func.id].args_without_return
result = Unify.unify(expected_args, actual_args)
if result.is_err:
raise TypeError(result.err)
self.generic_visit(node)
def typecheck(text, typeenv={}):
tree = ast.parse(text)
Typechecker().visit(tree)
typecheck(
"""
def sig(*args, **kwargs):
return lambda f: f
@sig("inc : float -> A -> A")
def inc(a, b):
return a + 1
@sig("foo : int -> float")
def foo(a):
return float(inc(a, "a"))
"""
)
print("ok")
In visit_FunctionDef(self, node)
I look for sig
annotations, parse
the type information and save in the typeenv
. In the visit_Call
, I
handle sig
calls so you can use a empty sig("foo : int -> str")
to
introduce types into the type environment. If we’re calling any
another function and we have its type information in the type
environment elif node.func.id in self.typeenv:
we gather the call
arguments, the expected arguemnts and try to unify them. If
unification returns an error we raise a TypeError
with the error
message.
The other bits are the unification. I will not explain unification in
detail here (maybe in another post) but,. the algorithm is implemented
in the Unify
static class, so you can check it, and the terms are
subclasses of TypeTerm
.
If you want to develop a better understanding of this I recomend that you:
- Read the Wikipedia page and other pages about unification
- Download the code https://gist.github.com/dhilst/b5b198af93302ade61ccbfe3b094621a and run it
- Change the signatures inside
sig
, run it again - Set breakpoints and follow the code, extend, break and fix it :)
- Have fun
Cheeeers!