Hatena::Groupcoders

ラシウラ出張所 このページをアンテナに追加 RSSフィード

2007/11/30

[] 型推論をPythonで実装してみる  型推論をPythonで実装してみる - ラシウラ出張所 を含むブックマーク はてなブックマーク -  型推論をPythonで実装してみる - ラシウラ出張所  型推論をPythonで実装してみる - ラシウラ出張所 のブックマークコメント

プログラミング言語での型チェックというのは、簡単に言うと、

  • プログラム中の式が、それを構成する要素(部分式や、それ自体では式にならない項)の間で型情報の関係が整合性が取れているかどうかをチェックすること。

型情報のつけ方には種類があって、

  • 項そのものを見るだけで型がわかるタイプ: (現時点での)Java, Cなど
    • 関数や変数すべてにそれ自体に具体的な型情報がついている
  • 項だけじゃ型がわからないかもしれないタイプ: OCamlHaskellなど
    • 関数や変数の型は、その関連する式から「推論」する必要がある

型推論の実装とは、後者の型チェックを行うための仕組み。

以下はソースコード。


言語定義

ここでは、パーザーではなく、抽象構文木(AST)を定義する。型推論(や型チェック)は構文木に対して行われるものだからパーザーは省略している。

x = 10
func = \num -> x + num
func(5)

というプログラムは、

expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))

という構文木として表現するものになります。

(以下ずっと上から順につなげて、最後に例題をつなげれば実行できます。)

# syntax tree nodes for the lang
class Expr:
    """absctract expression base class"""
    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)

    pass

class Val(Expr):
    """value, literal"""
    def __init__(self, python_value):
        self.value = python_value
        pass

    pass

class Add(Expr):
    """+"""
    def __init__(self, left, right):
        self.left = left
        self.right = right
        pass
    pass


class Let(Expr):
    """set variable.

    let name = value; body"""
    def __init__(self, name, value, body):
        self.name = name
        self.value = value
        self.body = body
        pass
    pass

class Ref(Expr):
    """variable reference by the name"""
    def __init__(self, name):
        self.name = name
        pass
    pass

class Lambda(Expr):
    """anonymous function.

    \param1, param2, ... -> body
    """
    def __init__(self, params, body):
        self.params = params
        self.body = body
        pass
    pass

class Apply(Expr):
    """apply function and args.

    func(arg1, arg2, ...)
    """
    def __init__(self, func, args):
        self.func = func
        self.args = args
        pass
    pass

class Typed(Expr):
    """ for type declaration

    expr :: type
    """
    def __init__(self, expr, type):
        self.expr = expr
        self.type = type
        pass
    pass

# types for the lang
class Type:
    """abstract type base class"""
    pass

class TAtom(Type):
    """atomic type"""
    def __init__(self, label):
        self.label = label
        pass

    def __repr__(self):
        return self.label
    pass

class TFunc(Type):
    """ function type """
    def __init__(self, params_t, ret_t):
        self.params_t = params_t
        self.ret_t = ret_t
        pass

    def __repr__(self):
        return VarNames().repr(self)

    pass

class TVar(Type):
    """variable type: must instanciate for each point of code
    """
    def __init__(self):
        pass

    def __repr__(self):
        return "<vtype: %s>" % str(id(self))

    pass

class VarNames:
    """Utility for manage human-readable names for TVar"""
    def __init__(self):
        self.names = {}
        self.current = ord("a")
        pass

    def append(self, vtype):
        """register the name for vtype"""
        if self.names.has_key(vtype):
            return
        self.names[vtype] = "$%s" % chr(self.current)
        self.current += 1
        pass

    def __getitem__(self, vtype):
        self.append(vtype)
        return self.names[vtype]

    def repr(self, type):
        """get human-readable string for the type"""
        if type.__class__ is TVar:
            return self[type]
        if type.__class__ is TAtom:
            return repr(type)
        if type.__class__ is TFunc:
            params = []
            for param_t in type.params_t:
                params.append(self.repr(param_t))
                pass
            ret = self.repr(type.ret_t)
            return "(%s)->%s" % (", ".join(params), ret)
        pass

    def __repr__(self):
        return repr(self.names)

    pass

最後のVarNamesはTVarやTVar入りのTFuncを人間が読める型変数名にして文字列化するユーティリティ。

一応文法定義すると、プログラム自体はEXPRで以下のようになる

  • EXPR = VAL | LET | REF | LAMBDA | APPLY | ADD | TYPED
  • VAL = <literal>
  • LET = <identifier> '=' EXPR (';'|<lf>) EXPR
  • REF = <identifier>
  • LAMBDA = '\' (<identifier> ',')* <identifier> '->' EXPR
  • APPLY = EXPR '(' (EXPR ',')* EXPR ')'
  • ADD = EXPR '+' EXPR
  • TYPED = EXPR '::' <type>

型は

  • TYPE = ATOM | FUNC | VAR
  • ATOM = int | string
  • FUNC = '(' (TYPE ',')* TYPE ')' '->' TYPE
  • VAR = '$'<identifier>

文や再帰関数はない。

eval

型推論にはまったく不要だが、簡単だしチェックにも使えるので、evaluatorも用意しておいた。

# for evaluator
class VarNotFound(Exception):
    pass

class Env:
    """ Env table"""
    def __init__(self, parent=None):
        self.parent = parent
        self.table = {}
        pass

    def put(self, name, value):
        self.table[name] = value
        pass

    def get(self, name):
        try:
            return self.table[name]
        except:
            if self.parent is not None:
                return self.parent.get(name)
            else:
                raise VarNotFound(name)
            pass
        pass
    pass

# for runtime value
class Value:
    pass

class VInt(Value):
    def __init__(self, value):
        self.value = value
        pass
    def __repr__(self):
        return repr(self.value)
    pass

class VFunc(Value):
    def __init__(self, code, env):
        self.env = env
        self.code = code
        pass
    def __repr__(self):
        return "func: %s" % repr(self.code)
    pass
def evaluate(expr, env):
    """ (Expr, Env) -> Val
    """
    if expr.__class__ is Val:
        return VInt(expr.value)
    if expr.__class__ is Add:
        return VInt(evaluate(expr.left, env).value + \
                    evaluate(expr.right, env).value)
    if expr.__class__ is Let:
        child_env = Env(env)
        child_env.put(expr.name, evaluate(expr.value, env))
        return evaluate(expr.body, child_env)
    if expr.__class__ is Ref:
        return env.get(expr.name)
    if expr.__class__ is Lambda:
        return VFunc(expr, env)
    if expr.__class__ is Apply:
        func = evaluate(expr.func, env)
        body = func.code.body
        child_env = Env(func.env)
        for index, name in enumerate(func.code.params):
            arg_value = evaluate(expr.args[index], env)
            child_env.put(name, arg_value)
            pass
        return evaluate(body, child_env)
    if expr.__class__ is Typed:
        return evaluate(expr.expr)
    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

一応使い方は、

expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))

print(evaluate(expr0, Env()))

型推論

型推論実装のコードは大きいので、トップレベルでごとに解説をつけます。

inference関数

型推論機構のトップレベル関数。

式の構造をチェックしあっていれば、同一視する型のペアを作っていく。たとえばApplyなら第一引数の型は関数、Applyの引数の型とApplyの型で作った関数とそれを同一視させる。

結果として、引数の式の型と、その型の具体型を導出するための情報が返ってくる。

def inference(expr, table, type_map):
    """type inference: returns type for the expr.

    (Expr, TypeEnv, TypeMap) -> Type, TypeMap
    """
    if expr.__class__ is Val:
        return TAtom(type(expr.value).__name__), type_map

    if expr.__class__ is Add:
        parametric = TVar()
        func_type = TFunc([parametric, parametric], parametric)

        left_type, type_map = inference(expr.left, table, type_map)
        right_type, type_map = inference(expr.right, table, type_map)
        ret_type = parametric
        compare_type = TFunc([left_type, right_type], ret_type)
        try:
            type_map = unify([(func_type, compare_type)], type_map)
            return ret_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass

    if expr.__class__ is Let:
        value_type, type_map = inference(expr.value, table, type_map)
        table = TypeEnv(table)
        value_type = concrete_type(value_type, type_map, {})
        table.put(expr.name, value_type)
        body_type, type_map = inference(expr.body, table, type_map)
        return body_type, type_map

    if expr.__class__ is Ref:
        name_type = table.get(expr.name)
        return name_type, type_map

    if expr.__class__ is Lambda:
        table = TypeEnv(table)
        for key in expr.params:
            table.put(key, TVar())
            pass
        body_type, type_map = inference(expr.body, table, type_map)
        arg_types = [concrete_type(table.get(key), type_map, {}) \
                     for key in expr.params]
        func_type = TFunc(arg_types, body_type)
        # check cycric
        try:
            check_cycric(func_type, type_map)
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        return func_type, type_map


    if expr.__class__ is Apply:
        func_type, type_map = inference(expr.func, table, type_map)
        func_type = concrete_type(func_type, type_map, {})
        arg_types = []
        for arg in expr.args:
            arg_type, type_map = inference(arg, table, type_map)
            arg_types.append(arg_type)
            pass
        ret_type = TVar()
        compare_type = TFunc(arg_types, ret_type)
        try:
            type_map = unify([(func_type, compare_type)], type_map)
            return ret_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass

    if expr.__class__ is Typed:
        expr_type, type_map = inference(expr.expr, table, type_map)
        try:
            type_map = unify([(expr.type, expr_type)], type_map)
            return expr_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass
    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

更新: Applyのとき、func_typeを一旦具体化させてからunifyさせてる。id=\x->x; ((id id) 10)対策。

unify関数

queueに入っている型のペアを同一視させる関数。

  • ペアになってる型同士の構造が不整合がないかチェック
    • たとえば、TFuncとTAtomは絶対に同一視できない
  • ペアがTFuncのような構造型どうしだったら、再帰的に要素の型を同一視させる
  • ペアに変数型があれば、それを推移(もう一方がTVarじゃない場合)か同値(もう一方もTVarの場合)として型対応情報を作る

戻り値は結果としての型対応情報


def unify(queue, type_map):
    """unify type-pairs on the queue.
    - check structures of the pair could be same.
    - recursively unify component types(TFunc) in structred types.
    - make reduction or equation for TVar in the pairs.

    ([(Type, Type)], TypeMap) -> TypeMap"""
    if len(queue) == 0:
        return type_map
    left, right = queue.pop()

    if left.__class__ is TVar:
        type_map = TypeMap(type_map)
        if right.__class__ is TVar:
            type_map.put_eq(left, right)
            pass
        else:
            type_map.put(left, right)
            pass
        return unify(queue, type_map)

    if left.__class__ is TFunc and right.__class__ is TFunc:
        if len(left.params_t) != len(right.params_t):
            raise TypeChechError("mismatched param size: %s <=> %s" %\
                                 (left, right))
        for param_l, param_r in zip(left.params_t, right.params_t):
            queue.append((param_l, param_r))
            pass
        queue.append((left.ret_t, right.ret_t))
        return unify(queue, type_map)

    if left.__class__ is TAtom and right.__class__ is TAtom:
        if left.label != right.label:
            raise TypeChechError("mismatched atomic types: %s <=> %s" %\
                                 (left.label, right.label))
        return unify(queue, type_map)

    if right.__class__ is TVar:
        queue.append((right, left))
        return unify(queue, type_map)

    raise TypeCheckError("mismatched structure of types: %s <=> %s" %\
                         (repr(left), repr(right)))

check_cycric関数

Lambdaの型は、無限型(..->t->tのように無限に続く型)にならないようにする。Lambdaを推論した結果に無限に循環する型がないかチェックする。これは

そのための関数。

無限ループの原因になる。再帰関数を作るためのYコンビネータもこの型になる。再帰関数を可能にする場合、Y相当のものを特別に独自の「構文要素(たとえばLetRecとか)として」用意する。

def check_cycric(func_type, type_map):
    """raise error if find cyclic function type in its params"""
    def check(type, vcache):
        if type.__class__ is TFunc:
            for pt in type.params_t:
                check(pt, vcache)
                pass
            return

        if type.__class__ is TVar:
            if type in vcache:
                raise TypeCheckError("cyclic found: %s" % repr(type))
            concrete = type_map.reduction(type)
            if concrete == type:
                return
            if concrete.__class__ is TFunc:
                vcache.append(type)
            check(concrete, vcache)
            if concrete.__class__ is TFunc:
                vcache.pop()
            return

        if type.__class__ is TAtom:
            return
        pass
    try:
        check(func_type, [])
    except TypeCheckError, ex:
        concrete_tfunc = concrete_type(func_type, type_map, {})
        raise TypeCheckError("cycric type: %s" % repr(concrete_tfunc))
    pass
TypeCheckErrorクラス

推論が失敗したことを表す例外クラス。つまり型エラー時起きる。

この推論の実装では例外を、どの式で型エラーが起きたかまでの、中間脱出としても使っている。

class TypeCheckError(Exception):
    pass

TypeEnvクラス

名前と型を対応付けるクラス。

LetやLambdaごとに追加で用意され、Refで引かれる。外側には波及しない。実行のEnvと似たような形のライフサイクルをとる。

class TypeEnv:
    """ Type of Name Table for inferencing """
    def __init__(self, parent=None):
        self.table = {}
        self.parent = parent
        pass

    def put(self, name, type):
        self.table[name] = type
        pass

    def get(self, name):
        try:
            return self.table[name]
        except:
            if self.parent is not None:
                return self.parent.get(name)
            else:
                raise TypeCheckError("%s not found in type env" % name)
            pass
        pass

    pass
TypeMapクラス

変数型(TVar)同士の同値関係(get/putメソッド)や、TVarから具体型への推移関係(put_eqメソッド)の情報を保持して、問い合わせで同値と推移結果を解決する(reductionメソッド)クラス。この型推論実装でのキモ。

class TypeMap:
    def __init__(self, parent=None):
        self.parent = parent
        self.table = {}
        pass

    def put(self, vtype, type):
        try:
            self.table[vtype].add(type)
        except:
            self.table[vtype] = set([type])
            pass
        pass

    def put_eq(self, vtypea, vtypeb):
        self.put(vtypea, vtypeb)
        self.put(vtypeb, vtypea)
        pass


    def reduction(self, vtype):
        types = list(self.get_types(set([vtype])))
        types.sort()
        #print types
        for type in types:
            if not isinstance(type, TVar):
                return type
            pass
        return types[0]

    def get_types(self, types):
        eq_types = self._collect_types(types)
        if self.parent is not None:
            p_types = self.parent.get_types(eq_types)
            if p_types != eq_types:
                return self.get_types(p_types)
            else:
                return eq_types
            pass
        else:
            return eq_types
        pass

    def _collect_types(self, types):
        eq_types = set().union(types)
        for type in types:
            if not isinstance(type, TVar): continue
            try:
                eq_types = eq_types.union(self.table[type])
            except KeyError:
                pass
            pass
        if types != eq_types:
            return self._collect_types(eq_types)
        else:
            return types
        pass

    def __repr__(self):
        return repr(self.table) + repr(self.parent)

    pass

更新: setを使うようにしました。

concrete_type関数

「表示可能な形で」具体型を解決する関数。つまり無限型だと途中で止めるようキャッシュしてる。


def concrete_type(type, type_map, cache):
    """resolve printable concrete type for the type"""
    if cache.has_key(type):
        ret = cache[type]
        return ret

    if type.__class__ is TAtom:
        cache[type] = type
        return type

    if type.__class__ is TFunc:
        tfunc = TFunc([], None)
        cache[type] = tfunc
        for param_t in type.params_t:
            concrete_p = concrete_type(param_t, type_map, cache)
            tfunc.params_t.append(concrete_p)
            pass
        tfunc.ret_t = concrete_type(type.ret_t, type_map, cache)
        return tfunc

    if type.__class__ is TVar:
        concrete_t = type_map.reduction(type)
        cache[type] = concrete_t
        return concrete_t

    raise TypeCheckError("mismatched structure of types: %s <=> %s" %\
                         (repr(left), repr(right)))

型推論実行例

実行と表示のための便利関数

def print_type(expr, label):
    print "== %s ==" % label
    try:
        print expr
        print
        expr_type, type_map = inference(expr, TypeEnv(), TypeMap())
        print "[raw expression type]"
        print "  " + repr(expr_type)
        print "[type map]"
        print "  " + repr(type_map)
        print "[concrete expression type]"
        print "  " + repr(concrete_type(expr_type, type_map, {}))
    except:
        import sys, traceback
        extype, exvalue, trback = sys.exc_info()
        traceback.print_exception(extype,exvalue, trback, file=sys.stdout)
        print "<<%s is INVALID>>" % label
        pass
    print
    pass

で、例いろいろ。うしろのeとついた式が型エラーを「起こす」式。ついてないのは型エラーにならない式。

例コード

# examples

# valid: <type=int>
#  x = 10
#  func = \num -> x + num
#  func(5)
expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))
print_type(expr0, "expr0")

# structral error:
#  x = 0
#  func = \num -> x + num
#  func(\x -> 5)
expr1e = Let("x", Val(10),
             Let("func", Lambda(["num"],
                                Add(Ref("x"), Ref("num"))),
                 Apply(Ref("func"), [Lambda(["x"], Val(5))])))

print_type(expr1e, "expr1e")

# valid:
#  x = 0
#  func = y = 6
#         \num -> x + num
#  func(5)
expr1_1 = Let("x", Val(10),
              Let("func", Let("y", Val(6),
                              Lambda(["num"],
                                     Add(Ref("x"), Ref("num")))),
                  Apply(Ref("func"), [Val(5)])))

print_type(expr1_1, "expr1_1")

# structral error:
#  x = 0
#  func = y = 6
#         \num -> x + num
#  func(\x -> 5)
expr1_1e = Let("x", Val(10),
               Let("func", Let("y", Val(6),
                               Lambda(["num"],
                                      Add(Ref("x"), Ref("num")))),
                   Apply(Ref("func"), [Lambda(["x"], Val(5))])))

print_type(expr1_1e, "expr1_1e")

# valid
#  x = 0
#  func = \num -> y = 6
#                 x + num
#  func(z = 5
#       z)
expr1_2 = Let("x", Val(10),
              Let("func", Lambda(["num"],
                                 Let("y", Val(6),
                                     Add(Ref("x"), Ref("num")))),
                  Apply(Ref("func"),
                        [Let("z", Val(5), Ref("z"))])))
print_type(expr1_2, "expr1_2")

# structral error:
#  x = 0
#  func = \num -> y = 6
#                 x + num
#  func(z = 5
#       \x -> z)
expr1_2e = Let("x", Val(10),
               Let("func", Lambda(["num"],
                                  Let("y", Val(6),
                                      Add(Ref("x"), Ref("num")))),
                   Apply(Ref("func"),
                         [Let("z", Val(5), Lambda(["x"], Ref("z")))])))

print_type(expr1_2e, "expr1_2e")


# invalid?
# \x->x x
expr1_3e = Lambda(["x"], Apply(Ref("x"), [Ref("x")]))
print_type(expr1_3e, "expr1_3e")

# valid
# (\x->x) 10
expr1_4 = Apply(Lambda(["x"], Ref("x")), [Val(10)])
print_type(expr1_4, "expr1_4")


# valid: <var> -> <var>
#  id = \x -> x
#  id
expr2_0 = Let("id", Lambda(["x"], Ref("x")),
              Ref("id"))
print_type(expr2_0, "expr2_0")

# valid: int
#  id = \x -> x
#  id 10
expr2_1 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Val(10)]))
print_type(expr2_1, "expr2_1")



# valid: <var> -> <var>
#  id: 'a->'a
#  id(id): 'a->'a
expr2_2 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Ref("id")]))
print_type(expr2_2, "expr2_2")


# invalid:
#  idid = \x -> x(x)
#  idid: (('a->'b)->'b)->'b)
expr2_3e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
               Ref("idid"))
print_type(expr2_3e, "expr2_3e")

expr2_3_1e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Let("id", Lambda(["x"], Ref("x")),
                     Apply(Ref("idid"), [Ref("id")])))
print_type(expr2_3_1e, "expr2_3_1e")

# invalid:
#  idid = \x -> x(x)
#  id = \x -> x
#  idid(id)(10)
print "expr2_3_2e"
expr2_3_2e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Let("id", Lambda(["x"], Ref("x")),
                     Apply(Apply(Ref("idid"), [Ref("id")]), [Val(10)])))
print_type(expr2_3_2e, "expr2_3_2e")
# print(evaluate(expr2_3_2e, Env())): but could evaluate

# invalid:
#  idid = \x -> x(x)
#  idid(idid)=>idid(idid)=>...
expr2_4e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Ref("idid"), [Ref("idid")]))
print_type(expr2_4e, "expr2_4e")
# print(evaluate(expr2_4, Env())): infinite loop

# invalid:
#  idid = \x -> x(x)
#  idid(10)
expr2_5e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
               Apply(Ref("idid"), [Val(10)]))
print_type(expr2_5e, "expr2_5e")

# invalid:
#  idid = \x -> x(x)
#  idid(10)
expr2_5_1e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Apply(Apply(Ref("idid"), [Ref("idid")]), [Val(10)]))
print_type(expr2_5_1e, "expr2_5_1e")

# invalid:
#  idid = \x -> x(x)
#  idid(\x->x)(10)
expr2_5_2e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Apply(Ref("idid"), [Lambda(["x"], Ref("x"))]), [Val(10)]))
print_type(expr2_5_2e, "expr2_5_2e")
# print(evaluate(expr2_5_2e, Env())): but could evaluate

まとめとか

とりあえず、きちんと動き、型推論できるようにはなっている。言語としてLetRecは入れたいところ。

式要素が7つで、これだけなので、要素が数十ある普通の言語だとどうなるんだろうか。構造チェックと同一視を行う推論部分は型構造に対してのべた書きだけど、汎用の推論エンジン(ってある?)をつかえば減らせるのかもしれない。そうなったらそうなったで推論エンジンの使い方に戸惑いそうだけど。

リンク

動機とか紆余曲折とか、あとアルゴリズムの情報源も