Hatena::Groupcoders

ラシウラ出張所 このページをアンテナに追加 RSSフィード

2007/11/30

[] 型推論をPythonで実装してみる  型推論をPythonで実装してみる - ラシウラ出張所 を含むブックマーク はてなブックマーク -  型推論をPythonで実装してみる - ラシウラ出張所  型推論をPythonで実装してみる - ラシウラ出張所 のブックマークコメント

プログラミング言語での型チェックというのは、簡単に言うと、

  • プログラム中の式が、それを構成する要素(部分式や、それ自体では式にならない項)の間で型情報の関係が整合性が取れているかどうかをチェックすること。

型情報のつけ方には種類があって、

  • 項そのものを見るだけで型がわかるタイプ: (現時点での)Java, Cなど
    • 関数や変数すべてにそれ自体に具体的な型情報がついている
  • 項だけじゃ型がわからないかもしれないタイプ: OCamlHaskellなど
    • 関数や変数の型は、その関連する式から「推論」する必要がある

型推論の実装とは、後者の型チェックを行うための仕組み。

以下はソースコード。


言語定義

ここでは、パーザーではなく、抽象構文木(AST)を定義する。型推論(や型チェック)は構文木に対して行われるものだからパーザーは省略している。

x = 10
func = \num -> x + num
func(5)

というプログラムは、

expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))

という構文木として表現するものになります。

(以下ずっと上から順につなげて、最後に例題をつなげれば実行できます。)

# syntax tree nodes for the lang
class Expr:
    """absctract expression base class"""
    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)

    pass

class Val(Expr):
    """value, literal"""
    def __init__(self, python_value):
        self.value = python_value
        pass

    pass

class Add(Expr):
    """+"""
    def __init__(self, left, right):
        self.left = left
        self.right = right
        pass
    pass


class Let(Expr):
    """set variable.

    let name = value; body"""
    def __init__(self, name, value, body):
        self.name = name
        self.value = value
        self.body = body
        pass
    pass

class Ref(Expr):
    """variable reference by the name"""
    def __init__(self, name):
        self.name = name
        pass
    pass

class Lambda(Expr):
    """anonymous function.

    \param1, param2, ... -> body
    """
    def __init__(self, params, body):
        self.params = params
        self.body = body
        pass
    pass

class Apply(Expr):
    """apply function and args.

    func(arg1, arg2, ...)
    """
    def __init__(self, func, args):
        self.func = func
        self.args = args
        pass
    pass

class Typed(Expr):
    """ for type declaration

    expr :: type
    """
    def __init__(self, expr, type):
        self.expr = expr
        self.type = type
        pass
    pass

# types for the lang
class Type:
    """abstract type base class"""
    pass

class TAtom(Type):
    """atomic type"""
    def __init__(self, label):
        self.label = label
        pass

    def __repr__(self):
        return self.label
    pass

class TFunc(Type):
    """ function type """
    def __init__(self, params_t, ret_t):
        self.params_t = params_t
        self.ret_t = ret_t
        pass

    def __repr__(self):
        return VarNames().repr(self)

    pass

class TVar(Type):
    """variable type: must instanciate for each point of code
    """
    def __init__(self):
        pass

    def __repr__(self):
        return "<vtype: %s>" % str(id(self))

    pass

class VarNames:
    """Utility for manage human-readable names for TVar"""
    def __init__(self):
        self.names = {}
        self.current = ord("a")
        pass

    def append(self, vtype):
        """register the name for vtype"""
        if self.names.has_key(vtype):
            return
        self.names[vtype] = "$%s" % chr(self.current)
        self.current += 1
        pass

    def __getitem__(self, vtype):
        self.append(vtype)
        return self.names[vtype]

    def repr(self, type):
        """get human-readable string for the type"""
        if type.__class__ is TVar:
            return self[type]
        if type.__class__ is TAtom:
            return repr(type)
        if type.__class__ is TFunc:
            params = []
            for param_t in type.params_t:
                params.append(self.repr(param_t))
                pass
            ret = self.repr(type.ret_t)
            return "(%s)->%s" % (", ".join(params), ret)
        pass

    def __repr__(self):
        return repr(self.names)

    pass

最後のVarNamesはTVarやTVar入りのTFuncを人間が読める型変数名にして文字列化するユーティリティ。

一応文法定義すると、プログラム自体はEXPRで以下のようになる

  • EXPR = VAL | LET | REF | LAMBDA | APPLY | ADD | TYPED
  • VAL = <literal>
  • LET = <identifier> '=' EXPR (';'|<lf>) EXPR
  • REF = <identifier>
  • LAMBDA = '\' (<identifier> ',')* <identifier> '->' EXPR
  • APPLY = EXPR '(' (EXPR ',')* EXPR ')'
  • ADD = EXPR '+' EXPR
  • TYPED = EXPR '::' <type>

型は

  • TYPE = ATOM | FUNC | VAR
  • ATOM = int | string
  • FUNC = '(' (TYPE ',')* TYPE ')' '->' TYPE
  • VAR = '$'<identifier>

文や再帰関数はない。

eval

型推論にはまったく不要だが、簡単だしチェックにも使えるので、evaluatorも用意しておいた。

# for evaluator
class VarNotFound(Exception):
    pass

class Env:
    """ Env table"""
    def __init__(self, parent=None):
        self.parent = parent
        self.table = {}
        pass

    def put(self, name, value):
        self.table[name] = value
        pass

    def get(self, name):
        try:
            return self.table[name]
        except:
            if self.parent is not None:
                return self.parent.get(name)
            else:
                raise VarNotFound(name)
            pass
        pass
    pass

# for runtime value
class Value:
    pass

class VInt(Value):
    def __init__(self, value):
        self.value = value
        pass
    def __repr__(self):
        return repr(self.value)
    pass

class VFunc(Value):
    def __init__(self, code, env):
        self.env = env
        self.code = code
        pass
    def __repr__(self):
        return "func: %s" % repr(self.code)
    pass
def evaluate(expr, env):
    """ (Expr, Env) -> Val
    """
    if expr.__class__ is Val:
        return VInt(expr.value)
    if expr.__class__ is Add:
        return VInt(evaluate(expr.left, env).value + \
                    evaluate(expr.right, env).value)
    if expr.__class__ is Let:
        child_env = Env(env)
        child_env.put(expr.name, evaluate(expr.value, env))
        return evaluate(expr.body, child_env)
    if expr.__class__ is Ref:
        return env.get(expr.name)
    if expr.__class__ is Lambda:
        return VFunc(expr, env)
    if expr.__class__ is Apply:
        func = evaluate(expr.func, env)
        body = func.code.body
        child_env = Env(func.env)
        for index, name in enumerate(func.code.params):
            arg_value = evaluate(expr.args[index], env)
            child_env.put(name, arg_value)
            pass
        return evaluate(body, child_env)
    if expr.__class__ is Typed:
        return evaluate(expr.expr)
    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

一応使い方は、

expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))

print(evaluate(expr0, Env()))

型推論

型推論実装のコードは大きいので、トップレベルでごとに解説をつけます。

inference関数

型推論機構のトップレベル関数。

式の構造をチェックしあっていれば、同一視する型のペアを作っていく。たとえばApplyなら第一引数の型は関数、Applyの引数の型とApplyの型で作った関数とそれを同一視させる。

結果として、引数の式の型と、その型の具体型を導出するための情報が返ってくる。

def inference(expr, table, type_map):
    """type inference: returns type for the expr.

    (Expr, TypeEnv, TypeMap) -> Type, TypeMap
    """
    if expr.__class__ is Val:
        return TAtom(type(expr.value).__name__), type_map

    if expr.__class__ is Add:
        parametric = TVar()
        func_type = TFunc([parametric, parametric], parametric)

        left_type, type_map = inference(expr.left, table, type_map)
        right_type, type_map = inference(expr.right, table, type_map)
        ret_type = parametric
        compare_type = TFunc([left_type, right_type], ret_type)
        try:
            type_map = unify([(func_type, compare_type)], type_map)
            return ret_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass

    if expr.__class__ is Let:
        value_type, type_map = inference(expr.value, table, type_map)
        table = TypeEnv(table)
        value_type = concrete_type(value_type, type_map, {})
        table.put(expr.name, value_type)
        body_type, type_map = inference(expr.body, table, type_map)
        return body_type, type_map

    if expr.__class__ is Ref:
        name_type = table.get(expr.name)
        return name_type, type_map

    if expr.__class__ is Lambda:
        table = TypeEnv(table)
        for key in expr.params:
            table.put(key, TVar())
            pass
        body_type, type_map = inference(expr.body, table, type_map)
        arg_types = [concrete_type(table.get(key), type_map, {}) \
                     for key in expr.params]
        func_type = TFunc(arg_types, body_type)
        # check cycric
        try:
            check_cycric(func_type, type_map)
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        return func_type, type_map


    if expr.__class__ is Apply:
        func_type, type_map = inference(expr.func, table, type_map)
        func_type = concrete_type(func_type, type_map, {})
        arg_types = []
        for arg in expr.args:
            arg_type, type_map = inference(arg, table, type_map)
            arg_types.append(arg_type)
            pass
        ret_type = TVar()
        compare_type = TFunc(arg_types, ret_type)
        try:
            type_map = unify([(func_type, compare_type)], type_map)
            return ret_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass

    if expr.__class__ is Typed:
        expr_type, type_map = inference(expr.expr, table, type_map)
        try:
            type_map = unify([(expr.type, expr_type)], type_map)
            return expr_type, type_map
        except:
            print "Type Error at: %s" % repr(expr)
            raise
        pass
    raise Exception("Not supported Expr: %s" % expr.__class__.__name__)

更新: Applyのとき、func_typeを一旦具体化させてからunifyさせてる。id=\x->x; ((id id) 10)対策。

unify関数

queueに入っている型のペアを同一視させる関数。

  • ペアになってる型同士の構造が不整合がないかチェック
    • たとえば、TFuncとTAtomは絶対に同一視できない
  • ペアがTFuncのような構造型どうしだったら、再帰的に要素の型を同一視させる
  • ペアに変数型があれば、それを推移(もう一方がTVarじゃない場合)か同値(もう一方もTVarの場合)として型対応情報を作る

戻り値は結果としての型対応情報


def unify(queue, type_map):
    """unify type-pairs on the queue.
    - check structures of the pair could be same.
    - recursively unify component types(TFunc) in structred types.
    - make reduction or equation for TVar in the pairs.

    ([(Type, Type)], TypeMap) -> TypeMap"""
    if len(queue) == 0:
        return type_map
    left, right = queue.pop()

    if left.__class__ is TVar:
        type_map = TypeMap(type_map)
        if right.__class__ is TVar:
            type_map.put_eq(left, right)
            pass
        else:
            type_map.put(left, right)
            pass
        return unify(queue, type_map)

    if left.__class__ is TFunc and right.__class__ is TFunc:
        if len(left.params_t) != len(right.params_t):
            raise TypeChechError("mismatched param size: %s <=> %s" %\
                                 (left, right))
        for param_l, param_r in zip(left.params_t, right.params_t):
            queue.append((param_l, param_r))
            pass
        queue.append((left.ret_t, right.ret_t))
        return unify(queue, type_map)

    if left.__class__ is TAtom and right.__class__ is TAtom:
        if left.label != right.label:
            raise TypeChechError("mismatched atomic types: %s <=> %s" %\
                                 (left.label, right.label))
        return unify(queue, type_map)

    if right.__class__ is TVar:
        queue.append((right, left))
        return unify(queue, type_map)

    raise TypeCheckError("mismatched structure of types: %s <=> %s" %\
                         (repr(left), repr(right)))

check_cycric関数

Lambdaの型は、無限型(..->t->tのように無限に続く型)にならないようにする。Lambdaを推論した結果に無限に循環する型がないかチェックする。これは

そのための関数。

無限ループの原因になる。再帰関数を作るためのYコンビネータもこの型になる。再帰関数を可能にする場合、Y相当のものを特別に独自の「構文要素(たとえばLetRecとか)として」用意する。

def check_cycric(func_type, type_map):
    """raise error if find cyclic function type in its params"""
    def check(type, vcache):
        if type.__class__ is TFunc:
            for pt in type.params_t:
                check(pt, vcache)
                pass
            return

        if type.__class__ is TVar:
            if type in vcache:
                raise TypeCheckError("cyclic found: %s" % repr(type))
            concrete = type_map.reduction(type)
            if concrete == type:
                return
            if concrete.__class__ is TFunc:
                vcache.append(type)
            check(concrete, vcache)
            if concrete.__class__ is TFunc:
                vcache.pop()
            return

        if type.__class__ is TAtom:
            return
        pass
    try:
        check(func_type, [])
    except TypeCheckError, ex:
        concrete_tfunc = concrete_type(func_type, type_map, {})
        raise TypeCheckError("cycric type: %s" % repr(concrete_tfunc))
    pass
TypeCheckErrorクラス

推論が失敗したことを表す例外クラス。つまり型エラー時起きる。

この推論の実装では例外を、どの式で型エラーが起きたかまでの、中間脱出としても使っている。

class TypeCheckError(Exception):
    pass

TypeEnvクラス

名前と型を対応付けるクラス。

LetやLambdaごとに追加で用意され、Refで引かれる。外側には波及しない。実行のEnvと似たような形のライフサイクルをとる。

class TypeEnv:
    """ Type of Name Table for inferencing """
    def __init__(self, parent=None):
        self.table = {}
        self.parent = parent
        pass

    def put(self, name, type):
        self.table[name] = type
        pass

    def get(self, name):
        try:
            return self.table[name]
        except:
            if self.parent is not None:
                return self.parent.get(name)
            else:
                raise TypeCheckError("%s not found in type env" % name)
            pass
        pass

    pass
TypeMapクラス

変数型(TVar)同士の同値関係(get/putメソッド)や、TVarから具体型への推移関係(put_eqメソッド)の情報を保持して、問い合わせで同値と推移結果を解決する(reductionメソッド)クラス。この型推論実装でのキモ。

class TypeMap:
    def __init__(self, parent=None):
        self.parent = parent
        self.table = {}
        pass

    def put(self, vtype, type):
        try:
            self.table[vtype].add(type)
        except:
            self.table[vtype] = set([type])
            pass
        pass

    def put_eq(self, vtypea, vtypeb):
        self.put(vtypea, vtypeb)
        self.put(vtypeb, vtypea)
        pass


    def reduction(self, vtype):
        types = list(self.get_types(set([vtype])))
        types.sort()
        #print types
        for type in types:
            if not isinstance(type, TVar):
                return type
            pass
        return types[0]

    def get_types(self, types):
        eq_types = self._collect_types(types)
        if self.parent is not None:
            p_types = self.parent.get_types(eq_types)
            if p_types != eq_types:
                return self.get_types(p_types)
            else:
                return eq_types
            pass
        else:
            return eq_types
        pass

    def _collect_types(self, types):
        eq_types = set().union(types)
        for type in types:
            if not isinstance(type, TVar): continue
            try:
                eq_types = eq_types.union(self.table[type])
            except KeyError:
                pass
            pass
        if types != eq_types:
            return self._collect_types(eq_types)
        else:
            return types
        pass

    def __repr__(self):
        return repr(self.table) + repr(self.parent)

    pass

更新: setを使うようにしました。

concrete_type関数

「表示可能な形で」具体型を解決する関数。つまり無限型だと途中で止めるようキャッシュしてる。


def concrete_type(type, type_map, cache):
    """resolve printable concrete type for the type"""
    if cache.has_key(type):
        ret = cache[type]
        return ret

    if type.__class__ is TAtom:
        cache[type] = type
        return type

    if type.__class__ is TFunc:
        tfunc = TFunc([], None)
        cache[type] = tfunc
        for param_t in type.params_t:
            concrete_p = concrete_type(param_t, type_map, cache)
            tfunc.params_t.append(concrete_p)
            pass
        tfunc.ret_t = concrete_type(type.ret_t, type_map, cache)
        return tfunc

    if type.__class__ is TVar:
        concrete_t = type_map.reduction(type)
        cache[type] = concrete_t
        return concrete_t

    raise TypeCheckError("mismatched structure of types: %s <=> %s" %\
                         (repr(left), repr(right)))

型推論実行例

実行と表示のための便利関数

def print_type(expr, label):
    print "== %s ==" % label
    try:
        print expr
        print
        expr_type, type_map = inference(expr, TypeEnv(), TypeMap())
        print "[raw expression type]"
        print "  " + repr(expr_type)
        print "[type map]"
        print "  " + repr(type_map)
        print "[concrete expression type]"
        print "  " + repr(concrete_type(expr_type, type_map, {}))
    except:
        import sys, traceback
        extype, exvalue, trback = sys.exc_info()
        traceback.print_exception(extype,exvalue, trback, file=sys.stdout)
        print "<<%s is INVALID>>" % label
        pass
    print
    pass

で、例いろいろ。うしろのeとついた式が型エラーを「起こす」式。ついてないのは型エラーにならない式。

例コード

# examples

# valid: <type=int>
#  x = 10
#  func = \num -> x + num
#  func(5)
expr0 = Let("x", Val(10),
            Let("func", Lambda(["num"],
                               Add(Ref("x"), Ref("num"))),
                Apply(Ref("func"), [Val(5)])))
print_type(expr0, "expr0")

# structral error:
#  x = 0
#  func = \num -> x + num
#  func(\x -> 5)
expr1e = Let("x", Val(10),
             Let("func", Lambda(["num"],
                                Add(Ref("x"), Ref("num"))),
                 Apply(Ref("func"), [Lambda(["x"], Val(5))])))

print_type(expr1e, "expr1e")

# valid:
#  x = 0
#  func = y = 6
#         \num -> x + num
#  func(5)
expr1_1 = Let("x", Val(10),
              Let("func", Let("y", Val(6),
                              Lambda(["num"],
                                     Add(Ref("x"), Ref("num")))),
                  Apply(Ref("func"), [Val(5)])))

print_type(expr1_1, "expr1_1")

# structral error:
#  x = 0
#  func = y = 6
#         \num -> x + num
#  func(\x -> 5)
expr1_1e = Let("x", Val(10),
               Let("func", Let("y", Val(6),
                               Lambda(["num"],
                                      Add(Ref("x"), Ref("num")))),
                   Apply(Ref("func"), [Lambda(["x"], Val(5))])))

print_type(expr1_1e, "expr1_1e")

# valid
#  x = 0
#  func = \num -> y = 6
#                 x + num
#  func(z = 5
#       z)
expr1_2 = Let("x", Val(10),
              Let("func", Lambda(["num"],
                                 Let("y", Val(6),
                                     Add(Ref("x"), Ref("num")))),
                  Apply(Ref("func"),
                        [Let("z", Val(5), Ref("z"))])))
print_type(expr1_2, "expr1_2")

# structral error:
#  x = 0
#  func = \num -> y = 6
#                 x + num
#  func(z = 5
#       \x -> z)
expr1_2e = Let("x", Val(10),
               Let("func", Lambda(["num"],
                                  Let("y", Val(6),
                                      Add(Ref("x"), Ref("num")))),
                   Apply(Ref("func"),
                         [Let("z", Val(5), Lambda(["x"], Ref("z")))])))

print_type(expr1_2e, "expr1_2e")


# invalid?
# \x->x x
expr1_3e = Lambda(["x"], Apply(Ref("x"), [Ref("x")]))
print_type(expr1_3e, "expr1_3e")

# valid
# (\x->x) 10
expr1_4 = Apply(Lambda(["x"], Ref("x")), [Val(10)])
print_type(expr1_4, "expr1_4")


# valid: <var> -> <var>
#  id = \x -> x
#  id
expr2_0 = Let("id", Lambda(["x"], Ref("x")),
              Ref("id"))
print_type(expr2_0, "expr2_0")

# valid: int
#  id = \x -> x
#  id 10
expr2_1 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Val(10)]))
print_type(expr2_1, "expr2_1")



# valid: <var> -> <var>
#  id: 'a->'a
#  id(id): 'a->'a
expr2_2 = Let("id", Lambda(["x"], Ref("x")),
              Apply(Ref("id"), [Ref("id")]))
print_type(expr2_2, "expr2_2")


# invalid:
#  idid = \x -> x(x)
#  idid: (('a->'b)->'b)->'b)
expr2_3e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
               Ref("idid"))
print_type(expr2_3e, "expr2_3e")

expr2_3_1e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Let("id", Lambda(["x"], Ref("x")),
                     Apply(Ref("idid"), [Ref("id")])))
print_type(expr2_3_1e, "expr2_3_1e")

# invalid:
#  idid = \x -> x(x)
#  id = \x -> x
#  idid(id)(10)
print "expr2_3_2e"
expr2_3_2e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Let("id", Lambda(["x"], Ref("x")),
                     Apply(Apply(Ref("idid"), [Ref("id")]), [Val(10)])))
print_type(expr2_3_2e, "expr2_3_2e")
# print(evaluate(expr2_3_2e, Env())): but could evaluate

# invalid:
#  idid = \x -> x(x)
#  idid(idid)=>idid(idid)=>...
expr2_4e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Ref("idid"), [Ref("idid")]))
print_type(expr2_4e, "expr2_4e")
# print(evaluate(expr2_4, Env())): infinite loop

# invalid:
#  idid = \x -> x(x)
#  idid(10)
expr2_5e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
               Apply(Ref("idid"), [Val(10)]))
print_type(expr2_5e, "expr2_5e")

# invalid:
#  idid = \x -> x(x)
#  idid(10)
expr2_5_1e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
                 Apply(Apply(Ref("idid"), [Ref("idid")]), [Val(10)]))
print_type(expr2_5_1e, "expr2_5_1e")

# invalid:
#  idid = \x -> x(x)
#  idid(\x->x)(10)
expr2_5_2e = Let("idid", Lambda(["x"], Apply(Ref("x"), [Ref("x")])),
              Apply(Apply(Ref("idid"), [Lambda(["x"], Ref("x"))]), [Val(10)]))
print_type(expr2_5_2e, "expr2_5_2e")
# print(evaluate(expr2_5_2e, Env())): but could evaluate

まとめとか

とりあえず、きちんと動き、型推論できるようにはなっている。言語としてLetRecは入れたいところ。

式要素が7つで、これだけなので、要素が数十ある普通の言語だとどうなるんだろうか。構造チェックと同一視を行う推論部分は型構造に対してのべた書きだけど、汎用の推論エンジン(ってある?)をつかえば減らせるのかもしれない。そうなったらそうなったで推論エンジンの使い方に戸惑いそうだけど。

リンク

動機とか紆余曲折とか、あとアルゴリズムの情報源も

2007/10/22

[][] すごく遅まきながらOpenCVを使ってみた  すごく遅まきながらOpenCVを使ってみた - ラシウラ出張所 を含むブックマーク はてなブックマーク -  すごく遅まきながらOpenCVを使ってみた - ラシウラ出張所  すごく遅まきながらOpenCVを使ってみた - ラシウラ出張所 のブックマークコメント

しかも動くが(OpenCVの使い方が)完璧理解できてないという状態。

デフォルトPythonインタフェースが入ってるのはうれしいですが、使い方がよくわからないのが難しかった。とくにイメージ間操作が。

OpenCVが出たとき作られてた顔にイメージを当てるcgiを書いてみた。とりあえず、sample/python/facedetect.pyを元に、image2次元配列っぽく操作できるところまではできた。


#!/usr/bin/python

from wsgiref.handlers import CGIHandler
import os.path as path
import urllib2
import hashlib
import opencv.cv as cv
import opencv.highgui as highgui

class OpenCVGateway:
    def __call__(self, env, res):
        url = env["QUERY_STRING"]
        remote = urllib2.urlopen(url)
        content_type = remote.info().get("content-type")

        if not content_type.startswith("image/"):
            #return self._transfer_data(res, remote, content_type)
            return self._redirect(res, url)

        imagename = self._image_name(url, content_type)
        cachename = self._cache_name(imagename)
        resultname = self._converted_name(imagename)

        if not path.exists(resultname):
            if not path.exists(cachename):
                self._make_cache(cachename, remote)
                pass
            self._convert(cachename, resultname)
            pass

        #return self._transfer_image(res, resultname, content_type)
        redirect_url = self._redirect_url(env, resultname)
        return self._redirect(res, redirect_url)

   def _convert(self, cachename, resultname):
        # from sample/python/facedetect.py
        image_scale = 1.3
        storage = cv.cvCreateMemStorage(0)
        cascade_name = "haarcascade_frontalface_alt.xml"
        cascade = cv.cvLoadHaarClassifierCascade(cascade_name, cv.cvSize(1, 1))
        min_size = cv.cvSize(20, 20)
        haar_scale = 1.2
        min_neighbors = 2
        haar_flags = 0

        image = highgui.cvLoadImage(cachename)
        stampimage = highgui.cvLoadImage("stamp.png")
        maskimage = highgui.cvLoadImage("stamp.png", 0)
        modified = cv.cvCreateImage(cv.cvSize(image.width, image.height),
                                    image.depth, image.nChannels)
        cv.cvCopy(image, modified)

        gray = cv.cvCreateImage(cv.cvSize(image.width, image.height), 8, 1)
        small = cv.cvCreateImage(cv.cvSize(
            cv.cvRound(image.width / image_scale),
            cv.cvRound(image.height / image_scale)), 8, 1)

        cv.cvCvtColor(image, gray, cv.CV_BGR2GRAY)
        cv.cvResize(gray, small, cv.CV_INTER_LINEAR)
        cv.cvEqualizeHist(small, small)
        cv.cvClearMemStorage(storage)

        faces = cv.cvHaarDetectObjects(small, cascade, storage,
                                       haar_scale, min_neighbors, haar_flags,
                                       min_size)

        if not faces:
            highgui.cvSaveImage(resultname, image)
            return
        for rect in faces:
            top_left = cv.cvPoint(int(rect.x * image_scale),
                                  int(rect.y * image_scale))
            bottom_right = cv.cvPoint(int((rect.x + rect.width) * image_scale),
                                      int((rect.y + rect.height) * image_scale))


            # draw red rect around the face
            #cv.cvRectangle(modified, top_left, bottom_right,
            #               cv.CV_RGB(255, 0, 0), 3, 8, 0)

            # draw smaller image over the face
            stamp = cv.cvCreateImage(cv.cvSize(int(rect.width * image_scale),
                                               int(rect.height * image_scale)),
                                     stampimage.depth, stampimage.nChannels)
            cv.cvResize(stampimage, stamp, cv.CV_INTER_LINEAR)
            mask = cv.cvCreateImage(cv.cvSize(int(rect.width * image_scale),
                                              int(rect.height * image_scale)),
                                     maskimage.depth, maskimage.nChannels)
            cv.cvResize(maskimage, mask, cv.CV_INTER_LINEAR)

            # how to use cvCopy?
            for i in xrange(stamp.width):
                for j in xrange(stamp.height):
                    if mask[i, j] == 0:
                        continue
                    modified[top_left.x + i, top_left.y + j] = stamp[i, j]
                    pass
                pass
            pass

        highgui.cvSaveImage(resultname, modified)
        pass
    
    def _image_name(self, url, content_type):
        format = content_type[len("image/"):]
        filename = hashlib.sha1(url).hexdigest()
        return "%s.%s" % (filename, format)

    def _cache_name(self, image_name):
        return "cache/%s" % (image_name,)

    def _converted_name(self, image_name):
        return "converted/%s" % (image_name,)

    def _make_cache(self, filename, remote):
        cache = open(filename, "w")
        for line in remote.readlines():
            cache.write(line)
            pass
        cache.close()
        pass

    def _redirect_url(self, env, file):
        scheme = env["wsgi.url_scheme"]
        host = env["HTTP_HOST"]
        port = env["SERVER_PORT"]
        dirname = path.dirname(env["SCRIPT_NAME"])
        return "%s://%s:%s%s/%s" % (scheme, host, port, dirname, file)

    def _redirect(self, res, url):
        res("307 temporary redirect", [("Location", url)])
        return []

    # obsoleted
    def _transfer_data(self, res, fp, content_type):
        res("200 OK", [("Content-Type", content_type)])
        for line in fp.readlines():
            yield line
            pass
        pass

    # obsoleted
    def _transfer_image(self, res, filename, content_type):
        fp = open(filename)
        return self._transfer_data(res, fp, content_type)
    pass

CGIHandler().run(OpenCVGateway())

このcgiのほかに、cacheとconvertedのディレクトリ、スタンプ用のstamp.pngと、OpenCV付属のhaarcascade_frontalface_alt.xmlも必要。

長いけど、OpenCV部分は_convertメソッド内で閉じてます。ほかはWSGIの処理で、その中でURLの?以降の文字列をそのままリソースをとってきて、画像でなければそこにリダイレクト画像なら保存して変換、変換したファイルリダイレクトしてます。

_convert部分は、facedetectほぼそのまま、なぜscale変換をしてるのかもよくわかってない。検出できた顔部分のrectから、stampをそのサイズにリサイズするまではいいが、それを埋め込むときどうするかで悩む。もしかして行列操作なのかな。

2007/05/19

[] 左結合な中置演算子の作り方  左結合な中置演算子の作り方 - ラシウラ出張所 を含むブックマーク はてなブックマーク -  左結合な中置演算子の作り方 - ラシウラ出張所  左結合な中置演算子の作り方 - ラシウラ出張所 のブックマークコメント

より、無駄を減らして

class infix:
    def __init__(self, func):
        self.func = func
        pass
    def __call__(self, left, right):
        return self.func(left, right)
    def __ror__(self, left):
        return boundleft(lambda right: self.func(left, right))
    pass

class boundleft:
    def __init__(self, func):
        self.func = func
        pass
    def __call__(self, right):
        return self.func(right)
    def __or__(self, right):
        return self.func(right)
    pass

Python-2.5では使い方は以下のようになる。infixが使えるのは二引数で呼び出しできるものでなくてはいけない:

@infix
def mul(a, b):
    return a * b

@infix
def add(a, b):
    return a + b

print 2 |add| 3 |mul| 4 #=> (2 + 3) * 4 = 20

結果から左結合の二項演算子になっていることを確認できる。

旧来の使い方も可能

isa = infix(lambda a, b: a.__class__ == b.__class__)

print [1,2,3] |isa| [] #=> True

右結合は?

Pythonの右結合演算子は、代入関係と**だけになる。まず、代入演算はa += b += cのような連続はパーズエラーになるため、使えない。

唯一**が__pow__や__rpow__で上書きできる。ただし、それを有効にするためにはinstance型は、__coerce__で引数側の型を変換しなくてはいけない。

class infixr:
    def __init__(self, func):
        self.func = func
        pass
    def __call__(self, left, right):
        return self.func(left, right)
    def __pow__(self, right):
        return boundright(lambda left: self.func(left, right.obj))
    def __coerce__(self, other):
        return (self, wrap(other))
    pass

class wrap:
    def __init__(self, obj):
        self.obj = obj

class boundright:
    def __init__(self, func):
        self.func = func
        pass
    def __call__(self, left):
        return self.func(left)
    def __rpow__(self, left):
        return self.func(left.obj)
    def __coerce__(self, other):
        return (self, wrap(other))
    pass

使い方例

@infixr
def comp(f, g):
    return lambda a: f(g(a))

add2 = lambda a : a + 2
mul3 = lambda a : a * 3
add4 = lambda a : a + 4

print (add4 **comp**  mul3 **comp** add2)(1) #=> (1 +2) *3) +4 = 13

実際のところ、右結合演算子が必要な状況はほとんどない。代入、pow演算子、(Rubyのように括弧のいらない)関数適用くらいだが、それは組み込まれている。関数型言語だと、例でも使った関数合成は右結合になる。 foo.bar.buzz はfoo.(bar.buzz)である。関数適用ではfoo (bar (buzz x)))であってほしいからだ。これはpythonにもほしくなるかもしれない。

あと(遅延型)関数型言語リストの結合も右結合になる。これはリスト[a,b,c,...]が意味的に(a,(b,(c,(...,()...))))のように、(","を中置演算子としてみれば)右結合になっているからだろうか。っと思ったが、このリストPythonでいうところのgenerator/iteratorであるため、右結合が自然なのだろう。右結合のappend(alist, append(blist, append(clist, dlist))))は、自然に頭のリストをたどってそれが尽きたら次のappendに入れる。一方、左結合のappend(append(append(append(alist, blist), clist), dlist)は先頭の要素を取り出すにも一番奥のappendにいかなくてはいかなくなる。

pythonリストは左結合の+を使って連結するが、右左どちらでもかまわない。

2007/04/20

[][] tracのPluginの書き方  tracのPluginの書き方 - ラシウラ出張所 を含むブックマーク はてなブックマーク -  tracのPluginの書き方 - ラシウラ出張所  tracのPluginの書き方 - ラシウラ出張所 のブックマークコメント

tracは0.9以降、pluginシステムを持つようになりました。pluginをつくり組み込むことで、機能拡張だけでなく、システムの振る舞いを変更するようなことまでできるようになります。

tracの外部リンクを変えるようにしたい機会があったので、それができるような簡単なpluginを作りました。tracサイトを作るとき適用したpluginがあるコミュニティTrac Hacksに登録しておきました:

このExtLLinkRewriterPluginを例にします。

setuptoolsによるeggパッケージ形式

tracではsetuptoolsを使ったeggパッケージでpluginを管理します。

setuptoolsは、.NETのアセンブリ、もしくはJavajarファイルバージョン管理をつけたもの、らと同じ位置づけにあるツールライブラリです。パッケージ作成は、Python標準のdistutilsと同様、setup.pyを使うようになっています。実際distutils用のsetup.pyのほとんどは、importするモジュールをsetuptoolsに変えるだけで使えるようになります。

setuptools用のsetup.pyはbdist_eggというコマンドを提供しeggパッケージを作成できます。作成されるeggパッケージは、PKZIP形式で、中にはメタデータ情報ファイルと対象のPythonコードが入っています(このあたりはjarに似ています)。これはunzipで確認できます。

trac pluginとextension point

tracはextension pointで機能のフックを管理できるようになっており、pluginは主にextension pointに対して機能を提供していくことになります。

たとえば、ExtLinkRewriterPluginでは、trac.wiki.api.IWikiSyntaxProviderを提供しています。

ExtLinkRewriterPluginの構成

READMEやサンプルredirectorを除くと、以下の構成です:

  • setup.py
  • ExtLinkRewriter/__init__.py
  • ExtLinkRewriter/provider.py

このうち__init__.pyはExtLinkRewriterモジュール用で、中は空です。

ExtLinkRewriter/provider.pyにはExtLinkRewriter.provider.ExtLinkRewriterProviderクラスがあり、それが前述のIWikiSyntaxProvider extension pointにむけた機能を実装しています。

setup.pyにはtrac pluginのためのメタデータを設定しています。

setup.py

from setuptools import setup

setup(
    name="ExtLinkRewriter",
    version="0.4",
    packages=['ExtLinkRewriter'],
    entry_points = {'trac.plugins':
                    ['ExtLinkRewriter.provider = ExtLinkRewriter.provider',],},
    license = "BSD")

ほぼdistutilsのsetup.pyと同じです

[trac.plugins]
ExtLinkRewriter.provider = ExtLinkRewriter.provider

trac.pluginカテゴリに、左辺はプラグインID、右辺は後述するComponentのサブクラスを取り出せるモジュール名を書くようです。

ExtLinkRewriter/provider.py

このモジュールでは、Extension point機能を提供するクラス記述します。このクラスtrac.core.Componentのサブクラスである必要があります。


from trac.core import *
from trac.wiki import IWikiSyntaxProvider
from trac.util.html import html


class ExtLinkRewriterProvider(Component):
    """Rewrite External Link URL
    """
    implements(IWikiSyntaxProvider)

    _rewrite_format = "http://del.icio.us/url?url=%s"
    _rewrite_namespaces = "http,https,ftp"
    _rewrite_target = ""

    def get_wiki_syntax(self):
        """IWikiSyntaxProvider#get_wiki_syntax
        """
        return []

    def get_link_resolvers(self):
        """IWikiSyntaxProvider#get_link_resolvers
        """
        self._load_config()
        return [(ns.strip(), self._link_formatter)
                for ns in self._rewrite_namespaces.split(",")]

    def _link_formatter(self, formatter, ns, target, label):
        try:
            newtarget = self._rewrite_format % (ns + ":" + target,)
        except:
            newtarget = ns + ":" + target
            msg = "ExtLinkRewriter Plugin format error: %s"
            msg %= (self._rewrite_format,)
            self.log.error(msg)
            pass
        return self._make_ext_link(formatter, newtarget, label,
                                   self._rewrite_target)

    def _make_ext_link(self, formatter, url, text, target=""):
        """Formatter._make_ext_link with target attr
        """
        if not url.startswith(formatter._local):
            return html.A(html.SPAN(text, class_="icon"),
                          class_="ext-link", href=url, target=target or None)
        else:
            return html.A(text, href=url, target=target or None)
        pass

    def _load_config(self):
        self._update_config("format")
        self._update_config("namespaces")
        self._update_config("target")
        pass

    def _update_config(self, key):
        attrname = "_rewrite_" + key
        oldval = getattr(self, attrname)
        newval = self.config.get("extlinkrewriter", key, oldval)
        setattr(self, attrname, newval)
        pass
    pass
implements(IWikiSyntaxProvider)

implements()関数引数はExtension pointのクラスを列挙します。それによって、システムが対応するextension pointを使うときに、このコンポーネントを使ってくれるようになります。

IWikiSyntaxProviderは以下のメソッドを提供しなくてはいけません

  • get_wiki_syntax(): 今回は何もしない
  • get_link_resolvers(): 今回のメイン機能

以下のソースにはそれらメソッドの説明が書いてあります(IWikiSyntaxProviderは96行目くらい)

get_link_resolvers

このメソッドの仕様は、

    def get_link_resolvers():
         """Return an iterable over (namespace, formatter) tuples.
 
         Each formatter should be a function of the form
         fmt(formatter, ns, target, label), and should
         return some HTML fragment.
         The `label` is already HTML escaped, whereas the `target` is not.
         """

返すのは[(namespace,formatter),...](もしくはgenerator)であり、formatterは、formatter(formatter, namespace, target, label)という引数関数になります。

    def get_link_resolvers(self):
        """IWikiSyntaxProvider#get_link_resolvers
        """
        self._load_config()
        return [(ns.strip(), self._link_formatter)
                for ns in self._rewrite_namespaces.split(",")]

で、最初のself._load_config()は、trac.iniのデータを取り込む。うしろは、pluginで処理するnamespace(http,https,ftpなど)とフォーマッターself._rewrite_namespaceのペアのタプルを返しています。

_link_formatter, _make_ext_link

これはプライベートメソッドです。

このフォーマッタはリンクに関する情報を受け取り、処理した結果であるHTML文字列を返すメソッドです。

    def _link_formatter(self, formatter, ns, target, label):
        try:
            newtarget = self._rewrite_format % (ns + ":" + target,)
        except:
            newtarget = ns + ":" + target
            msg = "ExtLinkRewriter Plugin format error: %s"
            msg %= (self._rewrite_format,)
            self.log.error(msg)
            pass
        return self._make_ext_link(formatter, newtarget, label,
                                   self._rewrite_target)

    def _make_ext_link(self, formatter, url, text, target=""):
        """Formatter._make_ext_link with target attr
        """
        if not url.startswith(formatter._local):
            return html.A(html.SPAN(text, class_="icon"),
                          class_="ext-link", href=url, target=target or None)
        else:
            return html.A(text, href=url, target=target or None)
        pass

この中でURL書き換えと、リンク部分だけのHTML生成を行っています。

HTMLレンダリング部分は、Trac標準のFormatterを参考にしています:

_load_config, _update_config

これもプライベートメソッドです。

trac.iniの情報を読み込んで、メンバーフィールドの上書きしていっています。

self.configはComponentのフィールドで、getメソッド等で、trac.iniから文字列やその他形式でデータを取り出すことができます。

たとえば、

self.config.get("extlinkrewriter", "format", "")

trac.iniの

[extlinkrewriter]
format = ...

の右辺を文字列として(strip()された状態で)取り出します(項目がない場合は第三引数の値が渡されます)。

ちなみにiniの右辺をダブルクオートでくくったりすると、ダブルクオート入りのstringが入りますので注意します。

パッケージ化とインストール

ソースが出来上がったらsetup.pyを使ってeggパッケージを作成します。

python setup.py bdist_egg

すると、dist/ExtLinkRewriter-0.4-py2.5.eggのような形式でeggパッケージが作られます。

tracで使うにはこのeggファイルをtrachomeのpluginsディレクトリコピーします。

プラグイン有効化

実際にプラグインを使うにはtrac.iniのcomponentsカテゴリモジュールをenableにするような記述をする必要があります

[components]
ExtLinkRewriter.* = enabled

つぎにtracアクセスしたら、pluginが有効になっているはずです(mod_pythonだと再起動が必要かも)。

まとめ

ExtLinkRewriterは単純なプラグインですが、plugin開発で何をすればいいかを一通りたどっています。

あとExtLinkRewriterの詳細な仕様は、以下のページに書いてあります。

リソース