3 people like it.

Type Inference for ML/Haskell

This is just the ML code from http://jozefg.bitbucket.org/posts/2015-02-28-type-inference.html ported to F#.

  1: 
  2: 
  3: 
  4: 
  5: 
  6: 
  7: 
  8: 
  9: 
 10: 
 11: 
 12: 
 13: 
 14: 
 15: 
 16: 
 17: 
 18: 
 19: 
 20: 
 21: 
 22: 
 23: 
 24: 
 25: 
 26: 
 27: 
 28: 
 29: 
 30: 
 31: 
 32: 
 33: 
 34: 
 35: 
 36: 
 37: 
 38: 
 39: 
 40: 
 41: 
 42: 
 43: 
 44: 
 45: 
 46: 
 47: 
 48: 
 49: 
 50: 
 51: 
 52: 
 53: 
 54: 
 55: 
 56: 
 57: 
 58: 
 59: 
 60: 
 61: 
 62: 
 63: 
 64: 
 65: 
 66: 
 67: 
 68: 
 69: 
 70: 
 71: 
 72: 
 73: 
 74: 
 75: 
 76: 
 77: 
 78: 
 79: 
 80: 
 81: 
 82: 
 83: 
 84: 
 85: 
 86: 
 87: 
 88: 
 89: 
 90: 
 91: 
 92: 
 93: 
 94: 
 95: 
 96: 
 97: 
 98: 
 99: 
100: 
101: 
102: 
103: 
104: 
105: 
106: 
107: 
108: 
109: 
110: 
111: 
112: 
113: 
114: 
115: 
116: 
117: 
118: 
119: 
120: 
121: 
122: 
123: 
124: 
125: 
126: 
127: 
128: 
129: 
130: 
131: 
132: 
133: 
134: 
135: 
136: 
137: 
138: 
139: 
140: 
141: 
142: 
143: 
144: 
145: 
146: 
147: 
148: 
149: 
// Type variables are identified by plain integers.
type tvar = int

// Mutable counter holding the next unused type-variable id.
let freshSource = ref 0

// Hands out the counter's current value as a new type variable and then
// advances the counter by one, so successive calls never return the same id.
let fresh () : tvar =
    let next = freshSource.Value
    freshSource.Value <- next + 1
    next


// Monomorphic (quantifier-free) types.
//   TBool        - the type given to the True / False literals
//   TArr (a, r)  - function type; `constrain` builds it as (argument, result)
//   TVar v       - a type variable, identified by its tvar id
type monotype =     | TBool
                    | TArr of monotype * monotype
                    | TVar of tvar

// A type scheme: the int list holds the bound (quantified) type variables of
// the monotype body. `mintNewMonoType` replaces each of them with a fresh
// variable, and `generalizeMonoType` excludes them from a scheme's free variables.
type polytype = PolyType of int list * monotype

// Expression syntax. Binders carry no names: `Var i` is resolved by taking
// the i-th entry of the context, and `Fn` / `Let` push their bound variable
// onto the front of the context, so index 0 is the innermost binding.
type exp = 
        // Boolean literals.
        | True
        | False
        // Variable reference, as an index into the context.
        | Var of int
        // Application: (function, argument).
        | App of exp * exp
        // Let binding: (bound expression, body); the body sees the binding at index 0.
        | Let of exp * exp
        // Lambda: the body sees the parameter at index 0.
        | Fn of exp
        // Conditional: (condition, then-branch, else-branch).
        | If of exp * exp * exp

// What the context records for one variable:
//   PolyTypeVar - a type scheme (pushed by `Let` in `constrain`); instantiated
//                 with fresh variables on every lookup
//   MonoTypeVar - a plain monotype (pushed by `Fn` in `constrain`); returned
//                 unchanged on lookup
type info = | PolyTypeVar of polytype
            | MonoTypeVar of monotype

// The typing context: a list of entries indexed by variable number, where
// position i holds the entry for `Var i`.
type context = info list

// Substitutes the type `ty'` for every occurrence of the type variable `var`
// inside `ty`. Other variables and TBool are left as they are.
let rec subst ty' var ty =
    let go inner = subst ty' var inner
    match ty with
    | TBool -> TBool
    | TVar other when other = var -> ty'
    | TVar _ -> ty
    | TArr (dom, cod) -> TArr (go dom, go cod)

// Lists every type variable occurring in `t`, in left-to-right order.
// A variable that occurs several times is listed once per occurrence.
let freeVars t =
    // `acc` holds the variables found to the right of the part being visited,
    // so visiting the right side of an arrow first keeps left-to-right order.
    let rec collect ty acc =
        match ty with
        | TBool -> acc
        | TVar v -> v :: acc
        | TArr (dom, cod) -> collect dom (collect cod acc)
    collect t []

// Removes duplicate elements from `l`. When an element occurs more than once,
// only its LAST occurrence is kept; the relative order of the kept elements
// is unchanged.
let dedup l =
    // Folding from the right, an element is kept only if nothing equal to it
    // has already been kept from further right in the list.
    let keepIfNew x kept =
        if List.exists (fun y -> x = y) kept then kept else x :: kept
    List.foldBack keepIfNew l []

// Turns the monotype `ty` into a type scheme by quantifying over every
// variable of `ty` that does not occur free in the context `ctx`.
// The bound-variable list is deduplicated with `dedup`.
let generalizeMonoType ctx ty =
    let isAbsentFrom xs x = not (List.exists (fun y -> x = y) xs)

    // Free variables of a single context entry; for a scheme, its own bound
    // variables are not free.
    let freeInEntry entry =
        match entry with
        | PolyTypeVar (PolyType (bound, body)) ->
            freeVars body |> List.filter (isAbsentFrom bound)
        | MonoTypeVar body -> freeVars body

    let ctxVars = List.collect freeInEntry ctx
    let polyVars = freeVars ty |> List.filter (isAbsentFrom ctxVars)
    PolyType (dedup polyVars, ty)

// Instantiates a type scheme: every bound variable in `ls` is replaced in the
// body by a newly minted type variable, giving a monotype.
let mintNewMonoType (PolyType (ls, ty)) =
    // The bound variables are handled last-to-first, so the last one in `ls`
    // receives the lowest-numbered fresh variable.
    List.rev ls
    |> List.fold (fun body bound -> subst (TVar (fresh ())) bound body) ty

// Raised by `lookupVar` when a variable index has no entry in the context.
exception UnboundVar of int

// Returns the type recorded for variable index `var` in the context `ctx`.
// A scheme entry is instantiated with fresh type variables on every lookup;
// a monotype entry is returned unchanged.
// Raises UnboundVar when `var` is negative or not smaller than the length of `ctx`.
let lookupVar var ctx =
    // Total positional lookup. Doing this by hand avoids both the deprecated
    // List.nth and a catch-all exception handler around the instantiation,
    // which would report any unrelated failure as UnboundVar.
    let rec entryAt idx entries =
        match entries with
        | entry :: _ when idx = 0 -> Some entry
        | _ :: rest when idx > 0 -> entryAt (idx - 1) rest
        | _ -> None
    match entryAt var ctx with
    | Some (PolyTypeVar pty) -> mintNewMonoType pty
    | Some (MonoTypeVar mty) -> mty
    | None -> raise (UnboundVar var)

// Rewrites `ty` with the solution `sol`, a list of (variable, type) bindings.
// Bindings are applied one after another starting from the END of the list,
// so the binding at the head of `sol` is applied last.
let applySol sol ty =
    List.rev sol
    |> List.fold (fun acc (v, replacement) -> subst replacement v acc) ty

// Rewrites every entry of the context `cxt` with the solution `sol`.
// For a scheme entry the body is rewritten and the bound-variable list is
// kept as it is.
let applySolCxt sol cxt =
    let rewrite = applySol sol
    cxt
    |> List.map (fun entry ->
        match entry with
        | MonoTypeVar body -> MonoTypeVar (rewrite body)
        | PolyTypeVar (PolyType (bound, body)) -> PolyTypeVar (PolyType (bound, rewrite body)))

// Prepends the binding `v := ty` to the solution `sol`. The bound type is
// first rewritten with `sol`, so the stored type already reflects the
// existing bindings.
let addSol v ty sol =
    let resolved = applySol sol ty
    (v, resolved) :: sol

// True when the type variable `v` appears anywhere inside the type `ty`.
let occursIn v ty =
    freeVars ty |> List.exists ((=) v)

// Substitutes the type `ty` for the variable `var` on both sides of every
// constraint in `cs`.
let substConstrs ty var cs =
    let rewrite side = subst ty var side
    cs |> List.map (fun (lhs, rhs) -> rewrite lhs, rewrite rhs)

// Raised by `unify` with the offending constraint when two types cannot be
// made equal.
exception UnificationError of monotype * monotype

// Solves a list of equality constraints between monotypes and returns a
// solution: a list of (variable, type) bindings.
// Raises UnificationError when a constraint pairs incompatible types, or when
// a variable would have to be bound to a type that contains that same variable.
let rec unify csl =
    match csl with
    | [] -> []
    | (lhs, rhs) :: rest ->
        // Binds `v := ty`: the remaining constraints are solved with `v`
        // already replaced, and the binding is added to their solution.
        let bind v ty =
            if occursIn v ty then
                raise (UnificationError (lhs, rhs))
            else
                addSol v ty (unify (substConstrs ty v rest))
        match lhs, rhs with
        | TBool, TBool -> unify rest
        // A variable equated with itself carries no information.
        | TVar i, TVar j when i = j -> unify rest
        // A variable on either side; when both sides are (distinct) variables
        // the left one is the one that gets bound.
        | (TVar i, ty) | (ty, TVar i) -> bind i ty
        // Arrows unify component-wise.
        | TArr (l, r), TArr (l', r') -> unify ((l, l') :: (r, r') :: rest)
        | _ -> raise (UnificationError (lhs, rhs))

// Combines two solutions. The result lists the bindings of `sol2` first, each
// bound type rewritten with `sol1`, followed by those bindings of `sol1`
// whose variable is not bound in `sol2`.
let (<+>) sol1 sol2 =
    let boundInSol2 v = List.exists (fun (v', _) -> v = v') sol2
    let carried = sol1 |> List.filter (fun (v, _) -> not (boundInSol2 v))
    let updated = sol2 |> List.map (fun (v, ty) -> (v, applySol sol1 ty))
    updated @ carried
    
// Walks the expression `v` under the context `ctx` and returns a pair
// (type, solution), where the solution is a list of (type variable, type)
// bindings. The returned type is not necessarily rewritten with the returned
// solution; `infer` applies the solution to it afterwards.
// Raises UnboundVar (through lookupVar) or UnificationError (through unify).
let rec constrain ctx v =
    match v with 
    // Boolean literals have type TBool and impose no bindings.
    | True -> (TBool, [])
    | False -> (TBool, [])
    // A variable takes whatever type the context yields for its index.
    | Var i -> (lookupVar i ctx, [])
    | Fn body ->
        // The parameter gets a fresh type variable and is pushed onto the
        // front of the context, so the body sees it at index 0.
        let argTy = TVar (fresh ())
        let (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
        // The parameter type is rewritten with what the body's solution says about it.
        (TArr (applySol sol argTy, rTy), sol) 
    | If (i, t, e) ->
        // Each sub-expression is analysed under the context rewritten with
        // the solutions gathered so far.
        let (iTy, sol1) = constrain ctx i
        let (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
        let (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
        let sol = sol1 <+> sol2 <+> sol3
        // The condition must be TBool and both branches must have the same type.
        let sol = sol <+> unify [ (applySol sol iTy, TBool); (applySol sol tTy, applySol sol eTy)]
        // The then-branch type stands for the type of the whole conditional.
        (tTy, sol)
    | App (l, r) ->
        // Fresh variables for the function's argument and result types;
        // domTy is created first and so gets the lower id.
        let (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
        let (funTy, sol1) = constrain ctx l
        let (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
        let sol = sol1 <+> sol2
        // The function must have type domTy -> ranTy and the argument type domTy.
        let sol = sol <+> unify [(applySol sol funTy, applySol sol (TArr (domTy, ranTy)));
                                 (applySol sol argTy, applySol sol domTy)]
        (ranTy, sol)
     | Let (e, body) ->
        let (eTy, sol1) = constrain ctx e
        let ctx' = applySolCxt sol1 ctx
        // The bound expression's type is generalised against the rewritten
        // context and pushed as a scheme, so every use in the body is
        // instantiated separately.
        let eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
        let (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
        (rTy, sol1 <+> sol2)

// Infers a type scheme for the expression `e` in the empty context: the type
// found by `constrain` is rewritten with the accompanying solution and then
// generalised over all of its remaining variables.
// Raises whatever `constrain` raises (UnboundVar, UnificationError).
let infer e =
    let rawTy, solution = constrain [] e
    applySol solution rawTy |> generalizeMonoType []
Multiple items
val int : value:'T -> int (requires member op_Explicit)

Full name: Microsoft.FSharp.Core.Operators.int

--------------------
type int = int32

Full name: Microsoft.FSharp.Core.int

--------------------
type int<'Measure> = int

Full name: Microsoft.FSharp.Core.int<_>
val freshSource : int ref

Full name: Script.freshSource
Multiple items
val ref : value:'T -> 'T ref

Full name: Microsoft.FSharp.Core.Operators.ref

--------------------
type 'T ref = Ref<'T>

Full name: Microsoft.FSharp.Core.ref<_>
val fresh : unit -> tvar

Full name: Script.fresh
type tvar = int

Full name: Script.tvar
val v : int
type monotype =
  | TBool
  | TArr of monotype * monotype
  | TVar of tvar

Full name: Script.monotype
union case monotype.TBool: monotype
union case monotype.TArr: monotype * monotype -> monotype
union case monotype.TVar: tvar -> monotype
type polytype = | PolyType of int list * monotype

Full name: Script.polytype
union case polytype.PolyType: int list * monotype -> polytype
type 'T list = List<'T>

Full name: Microsoft.FSharp.Collections.list<_>
Multiple items
val exp : value:'T -> 'T (requires member Exp)

Full name: Microsoft.FSharp.Core.Operators.exp

--------------------
type exp =
  | True
  | False
  | Var of int
  | App of exp * exp
  | Let of exp * exp
  | Fn of exp
  | If of exp * exp * exp

Full name: Script.exp
union case exp.True: exp
union case exp.False: exp
union case exp.Var: int -> exp
union case exp.App: exp * exp -> exp
union case exp.Let: exp * exp -> exp
union case exp.Fn: exp -> exp
union case exp.If: exp * exp * exp -> exp
type info =
  | PolyTypeVar of polytype
  | MonoTypeVar of monotype

Full name: Script.info
union case info.PolyTypeVar: polytype -> info
union case info.MonoTypeVar: monotype -> info
type context = info list

Full name: Script.context
val subst : ty':monotype -> var:tvar -> ty:monotype -> monotype

Full name: Script.subst
val ty' : monotype
val var : tvar
val ty : monotype
val var' : tvar
val l : monotype
val r : monotype
val freeVars : t:monotype -> tvar list

Full name: Script.freeVars
val t : monotype
val v : tvar
val dedup : l:'a list -> 'a list (requires equality)

Full name: Script.dedup
val l : 'a list (requires equality)
val x : 'a (requires equality)
val xs : 'a list (requires equality)
Multiple items
module List

from Microsoft.FSharp.Collections

--------------------
type List<'T> =
  | ( [] )
  | ( :: ) of Head: 'T * Tail: 'T list
  interface IEnumerable
  interface IEnumerable<'T>
  member Head : 'T
  member IsEmpty : bool
  member Item : index:int -> 'T with get
  member Length : int
  member Tail : 'T list
  static member Cons : head:'T * tail:'T list -> 'T list
  static member Empty : 'T list

Full name: Microsoft.FSharp.Collections.List<_>
val exists : predicate:('T -> bool) -> list:'T list -> bool

Full name: Microsoft.FSharp.Collections.List.exists
val y : 'a (requires equality)
val generalizeMonoType : ctx:info list -> ty:monotype -> polytype

Full name: Script.generalizeMonoType
val ctx : info list
val notMem : ('a list -> 'a -> bool) (requires equality)
val forall : predicate:('T -> bool) -> list:'T list -> bool

Full name: Microsoft.FSharp.Collections.List.forall
val free : (info -> tvar list)
val m : info
val m : monotype
val bs : int list
val filter : predicate:('T -> bool) -> list:'T list -> 'T list

Full name: Microsoft.FSharp.Collections.List.filter
val ctxVars : tvar list
val concat : lists:seq<'T list> -> 'T list

Full name: Microsoft.FSharp.Collections.List.concat
val map : mapping:('T -> 'U) -> list:'T list -> 'U list

Full name: Microsoft.FSharp.Collections.List.map
val polyVars : tvar list
val mintNewMonoType : polytype -> monotype

Full name: Script.mintNewMonoType
val ls : int list
val foldBack : folder:('T -> 'State -> 'State) -> list:'T list -> state:'State -> 'State

Full name: Microsoft.FSharp.Collections.List.foldBack
exception UnboundVar of int

Full name: Script.UnboundVar
val lookupVar : var:int -> ctx:info list -> monotype

Full name: Script.lookupVar
val var : int
val nth : list:'T list -> index:int -> 'T

Full name: Microsoft.FSharp.Collections.List.nth
val pty : polytype
val mty : monotype
val ex : exn
val raise : exn:System.Exception -> 'T

Full name: Microsoft.FSharp.Core.Operators.raise
val applySol : sol:(tvar * monotype) list -> ty:monotype -> monotype

Full name: Script.applySol
val sol : (tvar * monotype) list
val applySolCxt : sol:(tvar * monotype) list -> cxt:info list -> info list

Full name: Script.applySolCxt
val cxt : info list
val applyInfo : (info -> info)
val i : info
val addSol : v:tvar -> ty:monotype -> sol:(tvar * monotype) list -> (tvar * monotype) list

Full name: Script.addSol
val occursIn : v:tvar -> ty:monotype -> bool

Full name: Script.occursIn
val v' : tvar
val substConstrs : ty:monotype -> var:tvar -> cs:(monotype * monotype) list -> (monotype * monotype) list

Full name: Script.substConstrs
val cs : (monotype * monotype) list
exception UnificationError of monotype * monotype

Full name: Script.UnificationError
val unify : csl:(monotype * monotype) list -> (tvar * monotype) list

Full name: Script.unify
val csl : (monotype * monotype) list
val c : monotype * monotype
val constrs : (monotype * monotype) list
val i : tvar
val j : tvar
val l' : monotype
val r' : monotype
val sol1 : (tvar * monotype) list
val sol2 : (tvar * monotype) list
val notInSol2 : (tvar -> bool)
val sol1' : (tvar * monotype) list
val constrain : ctx:info list -> v:exp -> monotype * (tvar * monotype) list

Full name: Script.constrain
val v : exp
val i : int
val body : exp
val argTy : monotype
val rTy : monotype
val i : exp
val t : exp
val e : exp
val iTy : monotype
val tTy : monotype
val eTy : monotype
val sol3 : (tvar * monotype) list
val l : exp
val r : exp
val domTy : monotype
val ranTy : monotype
val funTy : monotype
val ctx' : info list
val eTy' : polytype
val infer : e:exp -> polytype

Full name: Script.infer

More information

Link:http://fssnip.net/q2
Posted:9 years ago
Author:Rick Minerich
Tags: types , inference