(* Reverse application [x |> f = f x] is taken from Stdlib (OCaml >= 4.01),
   where it is declared as the compiler primitive "%revapply" and is inlined
   at each use. Redefining it here as an ordinary function would only shadow
   that primitive, so no local definition is made. *)
(* Global counter behind [uid]. It only grows, so successive [mkvar] calls
   yield distinct names as long as nothing else writes to it. *)
let last_uid = ref 0

(** [uid ()] advances the global counter by one and returns its new value. *)
let uid () =
  last_uid := !last_uid + 1;
  !last_uid
;;

(** [mkvar prefix] returns a fresh variable name: [prefix] immediately
    followed by the next [uid ()] (e.g. ["w7"]).
    NOTE(review): a source program that itself uses a name of that shape
    (say ["x3"]) could collide with a generated name. *)
let mkvar prefix = prefix ^ string_of_int (uid ())
 
(* Persistent maps keyed by variable name; used as interpreter environments. *)
module M = Map.Make (String)

(** [setv env var value] returns [env] extended with [var] bound to [value],
    replacing any earlier binding of [var]. [env] itself is not modified. *)
let setv env var value = env |> M.add var value

(** [getv env var] returns the value bound to [var] in [env].
    @raise Failure ["var not found: " ^ var] if [var] is unbound. *)
let getv env var =
  match M.find var env with
  | value -> value
  | exception Not_found -> failwith ("var not found: " ^ var)
 
(* Variable names, shared by the source and the intermediate language. *)
type name = string

(* Run-time values produced by [eval_expr] and consumed by [run_expr]. *)
type rt_val =
  | VInt of int
  | VUnit
  | VPair of rt_val * rt_val
  | VLeft of rt_val
  | VRight of rt_val
  | VCont of name * il_expr * env_t
      (* closure: parameter name, body statement, captured environment *)
  | VPrint of il_expr
      (* print continuation: shows its argument, then runs the stored
         statement. NOTE(review): unlike [VCont] it captures no environment,
         so the statement runs in whatever environment invokes it. *)
  | VStop of string (* final continuation, carrying the stop message *)

(* Environments: persistent map from variable names to run-time values. *)
and env_t = rt_val M.t

(* Intermediate continuation-passing language. [run_expr] executes the
   statement forms ([RunCont], [Case], [Stop], [Error]); [eval_expr]
   computes values. *)
and il_expr =
  | Var of name
  | Const of int
  | Unit
  | Pair of il_expr * il_expr
  | Left of il_expr
  | Right of il_expr
  | Fst of il_expr
  | Snd of il_expr
  | Case of il_expr * il_expr * il_expr
      (* case e1 of Left x -> (x)e2 | Right y -> (y)e3;
         e2 and e3 must be [Lambda]s binding x and y *)
  | Print of il_expr (* continuation: print the argument, then run this *)
  | Lambda of name * il_expr
  | RunCont of il_expr * il_expr (* e1 <= e2: pass e2's value to e1 *)
  | Stop of string
  | Error of string
 
(** [show v] renders a run-time value as text. Closures show only their
    parameter (["f(x)"]); print continuations show as ["print; "] and stop
    continuations as ["stop: " ^ message]. *)
let rec show v =
  match v with
  | VInt n -> string_of_int n
  | VUnit -> "()"
  | VPair (a, b) -> "(" ^ show a ^ ", " ^ show b ^ ")"
  | VLeft inner -> "Left " ^ show inner
  | VRight inner -> "Right " ^ show inner
  | VCont (param, _, _) -> "f(" ^ param ^ ")"
  | VPrint _ -> "print; "
  | VStop msg -> "stop: " ^ msg
 
(** [eval_expr env e] evaluates the IL expression [e] in environment [env]
    and returns its run-time value. [Lambda] closes over [env]; [Print] and
    [Stop] become the continuation values [VPrint] / [VStop]. A [RunCont] is
    a statement and must be executed with [run_expr] instead.
    @raise Failure on an unbound variable, on [Fst]/[Snd] of a non-pair
      (["fst"] / ["snd"]), on a [Case] over a non-sum
      (["not sum type in case"]) or with non-[Lambda] branches
      (["bad case"]), on a [RunCont] (["evaluating runcont"]), and on
      [Error s] (["Error raised: " ^ s]). *)
let rec eval_expr env = function
  | Const n -> VInt n
  | Unit -> VUnit
  | Pair (e1, e2) ->
      (* Constructor arguments have no specified evaluation order; fix it
         left-to-right so the first failing component is the one reported. *)
      let v1 = eval_expr env e1 in
      let v2 = eval_expr env e2 in
      VPair (v1, v2)
  | Left e -> VLeft (eval_expr env e)
  | Right e -> VRight (eval_expr env e)
  | Fst e -> (
      match eval_expr env e with VPair (v1, _) -> v1 | _ -> failwith "fst")
  | Snd e -> (
      match eval_expr env e with VPair (_, v2) -> v2 | _ -> failwith "snd")
  | Var name -> getv env name
  | Lambda (x, e) -> VCont (x, e, env)
  | Case (e, Lambda (arg1, e1), Lambda (arg2, e2)) -> (
      match eval_expr env e with
      | VLeft v -> eval_expr (setv env arg1 v) e1
      | VRight v -> eval_expr (setv env arg2 v) e2
      | _ -> failwith "not sum type in case")
  | Case _ -> failwith "bad case"
  | Print e -> VPrint e
  | RunCont _ -> failwith "evaluating runcont"
  | Stop s -> VStop s
  | Error s -> failwith ("Error raised: " ^ s)
 
(** [run_expr env s] executes the IL statement [s] in environment [env].
    - [RunCont (f, x)] evaluates [f] and then [x], and jumps to the
      continuation: a closure runs its body with the parameter bound, a
      [VStop] prints ["Stopped: ..."], a [VPrint] prints the value followed
      by a space and continues.
    - [Case] picks a branch and runs it with the payload bound.
    - [Stop s] prints ["Stopped: " ^ s].
    - [Error s] raises.
    Every recursive call is in tail position.
    @raise Failure on [Error s] (["Error raised: " ^ s]), on a non-continuation
      in function position, on a [Case] over a non-sum or with non-[Lambda]
      branches (["bad case"]), on any other non-statement, and on whatever
      [eval_expr] raises. *)
let rec run_expr env = function
  | RunCont (ef, ex) -> (
      (* [let ... and ...] leaves the evaluation order unspecified; evaluate
         the continuation first, then its argument. *)
      let vf = eval_expr env ef in
      let vx = eval_expr env ex in
      match vf with
      | VCont (arg, body, closure_env) -> run_expr (setv closure_env arg vx) body
      | VStop s -> print_endline ("Stopped: " ^ s)
      | VPrint nxt ->
          Printf.printf "%s " (show vx);
          (* NOTE(review): [VPrint] stores no environment, so [nxt] runs in
             the environment of this call site, not the one in which the
             [Print] was created. *)
          run_expr env nxt
      | _ -> failwith "not a cont on left side of RunCont")
  | Case (e, Lambda (arg1, e1), Lambda (arg2, e2)) -> (
      match eval_expr env e with
      | VLeft v -> run_expr (setv env arg1 v) e1
      | VRight v -> run_expr (setv env arg2 v) e2
      | _ -> failwith "not sum type in case")
  | Case _ -> failwith "bad case"
  | Stop s -> print_endline ("Stopped: " ^ s)
  | Error s -> failwith ("Error raised: " ^ s)
  | _ -> failwith "bad expr in run";;
                         
(** [subst var sub ex] replaces every free occurrence of [Var var] in [ex]
    with [sub]. A [Lambda] rebinding [var] shadows it, so that lambda is
    returned unchanged.
    The substitution is NOT capture-avoiding: callers must make sure no free
    variable of [sub] is bound by a [Lambda] inside [ex]. *)
let rec subst var sub ex =
  let go e = subst var sub e in
  match ex with
  | Var name -> if name = var then sub else ex
  | Lambda (x, body) -> if x = var then ex else Lambda (x, go body)
  | Pair (a, b) -> Pair (go a, go b)
  | Case (scrut, on_left, on_right) -> Case (go scrut, go on_left, go on_right)
  | RunCont (f, arg) -> RunCont (go f, go arg)
  | Left e -> Left (go e)
  | Right e -> Right (go e)
  | Fst e -> Fst (go e)
  | Snd e -> Snd (go e)
  | Print e -> Print (go e)
  | Const _ | Unit | Stop _ | Error _ -> ex
         
(** [show_expr e] pretty-prints an IL term for debugging. [Lambda (x, b)]
    is shown as ["\\x . b"], [RunCont (f, a)] as ["(f) < (a)"], and [Case]
    spans several lines, ending with a newline. *)
let rec show_expr expr =
  let tagged tag sub = tag ^ show_expr sub in
  match expr with
  | Var name -> name
  | Const n -> string_of_int n
  | Unit -> "()"
  | Pair (a, b) -> Printf.sprintf "(%s, %s)" (show_expr a) (show_expr b)
  | Left sub -> tagged "Left " sub
  | Right sub -> tagged "Right " sub
  | Fst sub -> tagged "Fst " sub
  | Snd sub -> tagged "Snd " sub
  | Print sub -> tagged "Print " sub
  | Lambda (param, body) -> Printf.sprintf "\\%s . %s" param (show_expr body)
  | Case (scrut, on_left, on_right) ->
      Printf.sprintf "case %s of \nLeft  -> %s\nRight -> %s\n"
        (show_expr scrut) (show_expr on_left) (show_expr on_right)
  | RunCont (f, arg) ->
      Printf.sprintf "(%s) < (%s)" (show_expr f) (show_expr arg)
  | Stop msg -> "Stop " ^ msg
  | Error msg -> "Error: " ^ msg
                                                         
(* Source language. [compile_strict] translates it to [il_expr] under
   call-by-value, and [compile_lazy] under call-by-name. *)
type fl_expr =
  | FVar of name
  | FConst of int
  | FLambda of name * fl_expr (* \x. body *)
  | FApply of fl_expr * fl_expr (* f x *)
  | FUnit
  | FPair of fl_expr * fl_expr
  | FCaseP of fl_expr * name * name * fl_expr
      (* case e1 of (x, y) -> (x,y)e2 : destructure a pair *)
  | FLeft of fl_expr
  | FRight of fl_expr
  | FCase of fl_expr * fl_expr * fl_expr
      (* case e1 of Left x -> (x)e2 | Right y -> (y)e3; the compilers
         accept the branches only as [FLambda (x, e2)] / [FLambda (y, e3)] *)
  | FStop of string (* halt; the pending continuation is discarded *)
  | FPrint of fl_expr * fl_expr (* print e1; return e2 *)
  | FError of string (* compiles to an IL [Error], which raises when run *)
 
(** [compile_strict e fe] compiles the source term [fe] to an IL statement
    under call-by-value. Subterms are evaluated to values before use, and
    [fe]'s value is passed to the continuation [e]. [e] is an IL expression
    such as a [Lambda], [Print], [Stop] or [Snd (Var w)]. Compiled functions
    take one pair [(argument, return continuation)].

    [subst] is not capture-avoiding. The compiled code therefore never
    substitutes into [e] and never places [e] under a binder of a
    source-level name. See the [FCaseP] and [FCase] cases, which bind [e]
    to a fresh variable instead.
    NOTE(review): a source variable spelled like a [mkvar] name (e.g. ["w3"])
    could still clash with a generated one.
    @raise Failure ["bad fcase"] if an [FCase] branch is not an [FLambda]. *)
let rec compile_strict e = function
  | FVar name -> RunCont (e, Var name)
  | FConst n -> RunCont (e, Const n)
  | FError s -> Error s
  | FLambda (arg, expr) ->
      (* The function receives [w = (argument, return continuation)]. *)
      let w = mkvar "w" in
      let body = compile_strict (Snd (Var w)) expr in
      RunCont (e, Lambda (w, subst arg (Fst (Var w)) body))
  | FApply (e1, e2) ->
      (* Evaluate the function into [x], the argument into [x'], then call
         [x] with [(x', e)]. *)
      let x = mkvar "x" in
      let x' = mkvar "x'" in
      let call = RunCont (Var x, Pair (Var x', e)) in
      compile_strict (Lambda (x, compile_strict (Lambda (x', call)) e2)) e1
  | FUnit -> RunCont (e, Unit)
  | FPair (e1, e2) ->
      let x = mkvar "x" in
      let x' = mkvar "x'" in
      let build = RunCont (e, Pair (Var x, Var x')) in
      compile_strict (Lambda (x, compile_strict (Lambda (x', build)) e2)) e1
  | FCaseP (pe, lname, rname, expr) ->
      (* Compile the body against a fresh continuation variable [c] so that
         substituting the pattern variables cannot rewrite [e]. [e] may
         mention an outer variable spelled like [lname] or [rname]. [c] is
         bound to [e] at run time. *)
      let w = mkvar "w" in
      let c = mkvar "c" in
      let body =
        compile_strict (Var c) expr
        |> subst lname (Fst (Var w))
        |> subst rname (Snd (Var w))
      in
      compile_strict (Lambda (w, RunCont (Lambda (c, body), e))) pe
  | FLeft expr ->
      let x = mkvar "x" in
      compile_strict (Lambda (x, RunCont (e, Left (Var x)))) expr
  | FRight expr ->
      let x = mkvar "x" in
      compile_strict (Lambda (x, RunCont (e, Right (Var x)))) expr
  | FCase (sume, FLambda (arg1, e1), FLambda (arg2, e2)) ->
      (* The branches bind the source names [arg1]/[arg2]. Putting [e] under
         them would capture any same-named outer variable that [e] mentions,
         so [e] is bound to a fresh [c] outside the [Case]. This also stops
         [e] being duplicated in both branches. *)
      let z = mkvar "z" in
      let c = mkvar "c" in
      let branch arg body = Lambda (arg, compile_strict (Var c) body) in
      let dispatch = Case (Var z, branch arg1 e1, branch arg2 e2) in
      compile_strict (Lambda (z, RunCont (Lambda (c, dispatch), e))) sume
  | FCase _ -> failwith "bad fcase"
  | FStop s -> Stop s
  | FPrint (x, r) ->
      (* Evaluate [x], print it, then continue with [r] under [e]. *)
      compile_strict (Print (compile_strict e r)) x;;
 
(** [lz fe] wraps the source term [fe] into a thunk: an IL [Lambda] taking a
    continuation [k] and delivering [fe]'s value to it. The thunk is not
    memoised, so every use re-runs [fe]. *)
let rec lz e =
  let k = mkvar "k" in
  Lambda (k, compile_lazy (Var k) e)

(** [compile_lazy e fe] compiles the source term [fe] to an IL statement
    under call-by-name. Variables, function arguments, pair components and
    sum payloads are bound to thunks built by [lz], and [FVar] forces its
    thunk by calling it with [e]. Compiled functions take one pair
    [(argument thunk, return continuation)].

    [subst] is not capture-avoiding, and continuations here routinely contain
    [lz] thunks that mention outer source variables. The compiled code
    therefore never substitutes into [e] and never places [e] under a binder
    of a source-level name. See the [FCaseP] and [FCase] cases.
    NOTE(review): a source variable spelled like a [mkvar] name (e.g. ["k3"])
    could still clash with a generated one.
    @raise Failure ["bad fcase"] if an [FCase] branch is not an [FLambda]. *)
and compile_lazy e = function
  | FVar name -> RunCont (Var name, e)
  | FConst n -> RunCont (e, Const n)
  | FError s -> Error s
  | FLambda (arg, expr) ->
      (* The function receives [w = (argument thunk, return continuation)]. *)
      let w = mkvar "w" in
      let body = compile_lazy (Snd (Var w)) expr in
      RunCont (e, Lambda (w, subst arg (Fst (Var w)) body))
  | FApply (e1, e2) ->
      let x = mkvar "x" in
      compile_lazy (Lambda (x, RunCont (Var x, Pair (lz e2, e)))) e1
  | FUnit -> RunCont (e, Unit)
  | FPair (e1, e2) -> RunCont (e, Pair (lz e1, lz e2))
  | FCaseP (pe, lname, rname, expr) ->
      (* Compile the body against a fresh continuation variable [c] so that
         substituting the pattern variables cannot rewrite [e]. [e] may hold
         thunks mentioning an outer variable spelled like [lname] or
         [rname]. [c] is bound to [e] at run time. *)
      let w = mkvar "w" in
      let c = mkvar "c" in
      let body =
        compile_lazy (Var c) expr
        |> subst lname (Fst (Var w))
        |> subst rname (Snd (Var w))
      in
      compile_lazy (Lambda (w, RunCont (Lambda (c, body), e))) pe
  | FLeft expr -> RunCont (e, Left (lz expr))
  | FRight expr -> RunCont (e, Right (lz expr))
  | FCase (sume, FLambda (arg1, e1), FLambda (arg2, e2)) ->
      (* The branches bind the source names [arg1]/[arg2]. Putting [e] under
         them would capture any same-named outer variable that [e] mentions,
         so [e] is bound to a fresh [c] outside the [Case]. This also stops
         [e] being duplicated in both branches. *)
      let z = mkvar "z" in
      let c = mkvar "c" in
      let branch arg body = Lambda (arg, compile_lazy (Var c) body) in
      let dispatch = Case (Var z, branch arg1 e1, branch arg2 e2) in
      compile_lazy (Lambda (z, RunCont (Lambda (c, dispatch), e))) sume
  | FCase _ -> failwith "bad fcase"
  | FStop s -> Stop s
  | FPrint (x, r) ->
      (* Force [x], print its value, then continue with [r] under [e]. *)
      compile_lazy (Print (compile_lazy e r)) x;;
                         
(* Demo: compile [print (first (swap pair2)); ()] with the call-by-name
   compiler, show the IL, then run it. *)
let () =
  (* swap (a, b) = (b, a) *)
  let swap =
    FLambda ("pair", FCaseP (FVar "pair", "a", "b", FPair (FVar "b", FVar "a")))
  in
  (* NOTE(review): unused below; presumably an alternative input whose first
     component raises "Error raised: hi" only if it is actually forced. *)
  let _pair = FPair (FError "hi", FConst 44) in
  (* Each component prints a marker when evaluated, which shows which
     components get forced. *)
  let pair2 =
    FPair (FPrint (FConst 101, FConst 1), FPrint (FConst 102, FConst 2))
  in
  (* first (x, y) = x *)
  let first = FLambda ("p", FCaseP (FVar "p", "fst", "snd", FVar "fst")) in
  let fprg = FPrint (FApply (first, FApply (swap, pair2)), FUnit) in
  let ilprg = compile_lazy (Stop "ok") fprg in
  print_endline (show_expr ilprg);
  run_expr M.empty ilprg
;;