(* Reverse application: [v |> fn] is [fn v]; lets pipelines read left to right. *)
let ( |> ) v fn = fn v;;
(* Left-to-right composition: [(first >> second) v] is [second (first v)]. *)
let ( >> ) first second v = second (first v);;
(* [flip fn a b] calls [fn] with its two arguments swapped. *)
let flip fn a b = fn b a;;
open ExtArray;;

(* Zero-comparison selector of a [Cmpz] instruction.  [cmpop_of_int] maps
   the encoded field values 0..4 to LTZ, LEZ, EQZ, GEZ, GTZ in that order;
   [eval_instr] sets the status flag to v < 0, v <= 0, v = 0, v >= 0 and
   v > 0 respectively. *)
type cmp_op = LTZ | LEZ | EQZ | GEZ | GTZ;;
(* One decoded VM instruction.  Operands are memory addresses (port
   numbers for [Out]'s first operand and for [Inp]); [ip] below is the
   address of the instruction itself.  Semantics as implemented by
   [eval_instr]: *)
type instr =  Add of int * int        (* mem[ip] <- mem[r1] +. mem[r2] *)
            | Sub of int * int        (* mem[ip] <- mem[r1] -. mem[r2] *)
            | Mul of int * int        (* mem[ip] <- mem[r1] *. mem[r2] *)
            | Div of int * int        (* mem[ip] <- mem[r1] /. mem[r2], or 0.0 when mem[r2] = 0.0 *)
            | Out of int * int        (* outports[r1] <- mem[r2] *)
            | Phi of int * int        (* mem[ip] <- if status then mem[r1] else mem[r2] *)
            | Nop                     (* no effect *)
            | Cmpz of cmp_op * int    (* status <- comparison of mem[r1] with 0.0 *)
            | Sqrt of int             (* mem[ip] <- sqrt mem[r1], or 0.0 when mem[r1] < 0.0 *)
            | Copy of int             (* mem[ip] <- mem[r1] *)
            | Inp of int;;            (* mem[ip] <- inports[r1] *)

(* Decode the 3-bit comparison field of a [Cmpz] instruction.
   @raise Invalid_argument for any value outside 0..4. *)
let cmpop_of_int code =
  match code with
  | 0 -> LTZ
  | 1 -> LEZ
  | 2 -> EQZ
  | 3 -> GEZ
  | 4 -> GTZ
  | _ -> invalid_arg "bad cmp op";;
(* Infix symbol used when pretty-printing a [Cmpz] comparison. *)
let show_cmpop op =
  match op with
  | LTZ -> "<"
  | LEZ -> "<="
  | EQZ -> "="
  | GEZ -> ">="
  | GTZ -> ">";;

(* [word >>> shift]: logical (zero-filling) right shift of a 32-bit word,
   returned as a native int. *)
let (>>>) word shift =
  let shifted = Int32.shift_right_logical word shift in
  Int32.to_int shifted;;

(* Decode one 32-bit instruction word.
   Bits 28-31 select a two-operand instruction (1..6), with operands in
   bits 14-27 and 0-13.  When they are 0, the word is a one-operand
   instruction: bits 24-27 select it (0..4), bits 21-23 hold the [Cmpz]
   comparison and bits 0-13 the operand.
   @raise Invalid_argument on an unknown opcode or comparison. *)
let instr_of_int32 i =
  let opcode = i >>> 28 in
  let hi_reg = (i >>> 14) land 0x3FFF in
  let lo_reg = Int32.to_int i land 0x3FFF in
  let decode_single () =
    let sop = i >>> 24 and cmp = (i >>> 21) land 7 in
    match sop with
    | 0 -> Nop
    | 1 -> Cmpz (cmpop_of_int cmp, lo_reg)
    | 2 -> Sqrt lo_reg
    | 3 -> Copy lo_reg
    | 4 -> Inp lo_reg
    | _ -> invalid_arg "bad sop"
  in
  match opcode with
  | 0 -> decode_single ()
  | 1 -> Add (hi_reg, lo_reg)
  | 2 -> Sub (hi_reg, lo_reg)
  | 3 -> Mul (hi_reg, lo_reg)
  | 4 -> Div (hi_reg, lo_reg)
  | 5 -> Out (hi_reg, lo_reg)
  | 6 -> Phi (hi_reg, lo_reg)
  | _ -> invalid_arg "bad dop";;
       
(* Human-readable rendering of one instruction, e.g. "r3 + r7". *)
let show_instr ins =
  match ins with
  | Nop -> "nop"
  | Add (a, b) -> Printf.sprintf "r%d + r%d" a b
  | Sub (a, b) -> Printf.sprintf "r%d - r%d" a b
  | Mul (a, b) -> Printf.sprintf "r%d * r%d" a b
  | Div (a, b) -> Printf.sprintf "r%d / r%d" a b
  | Phi (a, b) -> Printf.sprintf "r%d else r%d" a b
  | Out (port, src) -> Printf.sprintf "Output %d <- r%d" port src
  | Cmpz (op, a) -> Printf.sprintf "if r%d %s 0.0 then" a (show_cmpop op)
  | Sqrt a -> Printf.sprintf "sqrt r%d" a
  | Copy a -> Printf.sprintf "r%d" a
  | Inp port -> Printf.sprintf "Inport %d" port;;

(* Load a VM image from [fname].
   The file is read as consecutive frames of one 4-byte instruction word
   and one 8-byte double; frame [idx] fills address [idx].  Odd frames
   store the instruction first, even frames store the double first.
   Returns [(program, memory)], both 16384 entries long; addresses not
   present in the file stay [Nop] / [0.0].
   End of input stops the load normally.  Any other exception (e.g. an
   undecodable instruction, or more than 16384 frames) is printed to
   stdout and re-raised.  The input is closed on every path. *)
let load_program fname =
  let inp = open_in_bin fname |> IO.input_channel in
  let readpair idx =
    if idx land 1 > 0 then
      let ins = IO.read_real_i32 inp in
      let data = IO.read_double inp in
      (ins, data)
    else
      let data = IO.read_double inp in
      let ins = IO.read_real_i32 inp in
      (ins, data)   in
  let program = Array.make 16384 Nop and memory = Array.make 16384 0.0 in
  let rec readloop idx =
    let (ins, data) = readpair idx in
    program.(idx) <- instr_of_int32 ins; memory.(idx) <- data;
    readloop (idx+1)  in
  (try readloop 0 with
   | IO.No_more_input -> IO.close_in inp
   | ex ->
       (* close before re-raising so a bad image does not leak the descriptor *)
       IO.close_in inp;
       Printexc.to_string ex |> print_string;
       raise ex);
  (program, memory);;
   
(* Length of [prg] once trailing [Nop]s are dropped: one past the
   highest address holding a non-[Nop] instruction.  Returns 0 for an
   empty or all-[Nop] program (previously this indexed [prg.(-1)] and
   raised [Invalid_argument]). *)
let prg_length prg =
  let rec loop idx =
    if idx < 0 then 0
    else match prg.(idx) with
      | Nop -> loop (idx - 1)
      | _ -> idx + 1 in
  loop (Array.length prg - 1);;
  
(* Listing of the program up to [prg_length prg], one line per address:
   "r<idx> = <instruction>  <TAB><initial data value>".
   Built in a [Buffer] so the cost is linear in the listing size (the
   previous fold with (^) copied the accumulated string on every line). *)
let show_program prg mem =
  let len = prg_length prg in
  let buf = Buffer.create (64 * (len + 1)) in
  for idx = 0 to len - 1 do
    Printf.bprintf buf "r%d = %s  \t%f\n" idx (show_instr prg.(idx)) mem.(idx)
  done;
  Buffer.contents buf;;
  
(* Complete state of one VM instance. *)
type vm_state = { prg : instr array;             (* decoded program *)
                  prglen : int;                  (* addresses 0..prglen-1 are executed each step *)
                   mem : float array;            (* data memory, one cell per address *)
                  inports : float array;         (* input ports; [init_vm] stores the scenario at 16000 *)
                  outports : float array;        (* output ports, written by [Out] *)
                  mutable status : bool;         (* result of the last [Cmpz], read by [Phi] *)
                  mutable time : int;            (* step counter, incremented once per executed step *)
                  scenario : int;                (* scenario number given to [init_vm] *)
                  mutable scoretime : int;       (* first step at which outports.(0) > 0.0; 0 until then *)
                  mutable log : (int * float * float) list  };;   (* (time, inport 2, inport 3) per [set_thrust], newest first *)

(* Build a fresh VM from the image [fname] for scenario [conf].
   The scenario number is written to input port 16000 and kept in the
   [scenario] field; every other port starts at 0.0, the clock and
   score time at 0, the status flag false and the thrust log empty. *)
let init_vm fname conf =
  let (program, memory) = load_program fname in
  let inports = Array.make 16384 0.0 in
  inports.(16000) <- float_of_int conf;
  let outports = Array.make 16384 0.0 in
  { prg = program;
    prglen = prg_length program;
    mem = memory;
    inports = inports;
    outports = outports;
    status = false;
    time = 0;
    scenario = conf;
    scoretime = 0;
    log = [] };;
  
(* Snapshot of the parts of [vm] that [restore_state] rolls back:
   copies of memory and both port arrays, plus time, score time and
   thrust log.  [status] is not included. *)
let save_state vm =
  let mem_copy = Array.copy vm.mem
  and in_copy = Array.copy vm.inports
  and out_copy = Array.copy vm.outports in
  (mem_copy, in_copy, out_copy, vm.time, vm.scoretime, vm.log);;
(* Roll [vm] back to a snapshot taken by [save_state].
   Only the first [vm.prglen] cells of memory and of both port arrays
   are copied back, plus output port 16000; time, score time and log
   are restored as well.
   NOTE(review): output ports at addresses >= prglen are not rolled
   back, and input port 16000 (the scenario) is not restored while
   output port 16000 is — confirm this is intended. *)
let restore_state vm (m, i, o, t, sct, log) =
  let n = vm.prglen in
  Array.blit m 0 vm.mem 0 n;
  Array.blit i 0 vm.inports 0 n;
  Array.blit o 0 vm.outports 0 n;
  vm.outports.(16000) <- o.(16000);
  vm.time <- t;
  vm.scoretime <- sct;
  vm.log <- log;;
  
(* Set input ports 2 and 3 to the negation of [dvx] and [dvy], and
   prepend (current time, port 2 value, port 3 value) to [vm.log].
   [0.0 -. x] is used rather than [-. x]; for x = 0.0 it yields +0.0
   instead of -0.0, which is what ends up in the log. *)
let set_thrust vm dvx dvy =
  let in_x = 0.0 -. dvx and in_y = 0.0 -. dvy in
  vm.inports.(2) <- in_x;
  vm.inports.(3) <- in_y;
  vm.log <- (vm.time, in_x, in_y) :: vm.log;;
  
(* Write the thrust log of [vm] to [fname] as a binary trace:
   header (0xCAFEBABE, 273, scenario), then frames of
   (time, count, count x (port address, double value)).
   The frame at time 0 always carries port 16000 = scenario, merged with
   a thrust entry when the first log entry is at time 0.  Each other log
   entry becomes a 2-port frame (ports 2 and 3).  The trace ends with a
   frame (scoretime, 0). *)
let save_trace fname vm =
  let out = IO.output_channel (open_out_bin fname) in
  let frame_header t count = IO.write_i32 out t; IO.write_i32 out count in
  let port addr v = IO.write_i32 out addr; IO.write_double out v in
  let thrust_frame (t, vx, vy) = frame_header t 2; port 2 vx; port 3 vy in
  let scen = float_of_int vm.scenario in
  IO.write_real_i32 out 0xCAFEBABEl;
  IO.write_i32 out 273;
  IO.write_i32 out vm.scenario;
  begin match List.rev vm.log with
  | (0, vx, vy) :: rest ->
      frame_header 0 3;
      port 2 vx;
      port 3 vy;
      port 16000 scen;
      List.iter thrust_frame rest
  | entries ->
      frame_header 0 1;
      port 16000 scen;
      List.iter thrust_frame entries
  end;
  frame_header vm.scoretime 0;
  IO.close_out out;;
  
(* Execute instruction [ins] located at address [ip] against [vm].
   Result-producing instructions write [vm.mem.(ip)]; [Out] writes an
   output port; [Cmpz] only updates [vm.status], which [Phi] reads. *)
let eval_instr ip vm ins =
  let m = vm.mem in
  match ins with
  | Nop -> ()
  | Add (a, b) -> m.(ip) <- m.(a) +. m.(b)
  | Sub (a, b) -> m.(ip) <- m.(a) -. m.(b)
  | Mul (a, b) -> m.(ip) <- m.(a) *. m.(b)
  | Div (a, b) ->
      let divisor = m.(b) in
      (* division by zero yields 0.0 instead of inf/nan *)
      m.(ip) <- (if divisor = 0.0 then 0.0 else m.(a) /. divisor)
  | Out (port, src) -> vm.outports.(port) <- m.(src)
  | Phi (a, b) -> m.(ip) <- (if vm.status then m.(a) else m.(b))
  | Cmpz (op, a) ->
      let v = m.(a) in
      vm.status <- (match op with
                    | LTZ -> v < 0.0
                    | LEZ -> v <= 0.0
                    | EQZ -> v = 0.0
                    | GEZ -> v >= 0.0
                    | GTZ -> v > 0.0)
  | Sqrt a -> let v = m.(a) in m.(ip) <- (if v >= 0.0 then sqrt v else 0.0)
  | Copy a -> m.(ip) <- m.(a)
  | Inp port -> m.(ip) <- vm.inports.(port);;

(* Constant-fold the program of [vm] and return a new, specialised VM.

   The addresses known up front are the [Nop] slots and the [Inp] slots.
   Each [propagate] pass walks the program and replaces every
   instruction whose operands are all known with [Nop], storing its
   result in a private copy of memory.  A [Phi] directly after a folded
   [Cmpz] becomes a [Copy] of the selected operand.  Passes repeat until
   neither the known-set nor the program changes.

   Operand addresses >= prglen are treated as known: no instruction
   writes them at run time.  Previously such an operand indexed past the
   end of the known-set array and raised [Invalid_argument].

   NOTE(review): [Inp] instructions are folded to the *current* input
   values (including ports 2/3 written by [set_thrust]), so later
   [set_thrust] calls have no effect on the returned VM — confirm this
   specialisation is intended.
   NOTE(review): a folded [Cmpz] stores 1.0/0.0 into its own memory cell,
   and only a [Phi] immediately after it is resolved; a [Phi] further
   away would read a stale [status] at run time.

   [vm] itself is not modified; the result holds copies of its arrays. *)
let optimize vm =
  let n = vm.prglen in
  let prg = Array.sub vm.prg 0 n in
  let propagate cons mem p =
    let con = Array.copy cons in
    (* cells past the program are never written, hence constant *)
    let known r = r >= n || con.(r) in
    (* record a folded result and replace the instruction by Nop *)
    let fold ip v = mem.(ip) <- v; con.(ip) <- true; Nop in
    (* [con] is updated in place, so within a pass an instruction can use
       results folded at lower addresses; the next pass catches the rest *)
    let p' = p |> Array.mapi (fun ip ins ->
      match ins with
      | Add(r1,r2) -> if known r1 && known r2 then fold ip (mem.(r1) +. mem.(r2)) else ins
      | Sub(r1,r2) -> if known r1 && known r2 then fold ip (mem.(r1) -. mem.(r2)) else ins
      | Mul(r1,r2) -> if known r1 && known r2 then fold ip (mem.(r1) *. mem.(r2)) else ins
      | Div(r1,r2) ->
          if known r1 && known r2 then
            fold ip (if mem.(r2) = 0.0 then 0.0 else mem.(r1) /. mem.(r2))
          else ins
      | Out _ -> ins
      | Phi(r1,r2) ->
          if ip > 0 then
            (match prg.(ip-1) with
             | Cmpz _ ->
                 (* a folded Cmpz left 1.0 (true) or 0.0 (false) in its cell *)
                 if con.(ip-1) then
                   (if mem.(ip-1) > 0.0 then Copy r1 else Copy r2)
                 else ins
             | _ -> ins)
          else ins
      | Nop -> ins
      | Cmpz(op, r1) ->
          if known r1 then begin
            let v = mem.(r1) in
            let holds = match op with
              | LTZ -> v < 0.0
              | LEZ -> v <= 0.0
              | EQZ -> v = 0.0
              | GEZ -> v >= 0.0
              | GTZ -> v > 0.0 in
            fold ip (if holds then 1.0 else 0.0)
          end else ins
      | Sqrt r1 ->
          if known r1 then fold ip (if mem.(r1) >= 0.0 then sqrt mem.(r1) else 0.0)
          else ins
      | Copy r1 -> if known r1 then fold ip mem.(r1) else ins
      | Inp r1 -> fold ip vm.inports.(r1)) in
    con, p'   in
  let mem = Array.copy vm.mem in
  let rec loop p con =
    let con', p' = propagate con mem p in
    if con' = con && p = p' then p else loop p' con'  in
  let const = prg |> Array.map (function Nop | Inp _ -> true | _ -> false) in
  let p = loop prg const in
  { prg = p; prglen = vm.prglen; mem = mem;
    inports = Array.copy vm.inports; outports = Array.copy vm.outports; status = vm.status;
    time = vm.time; scenario = vm.scenario; log = vm.log; scoretime = vm.scoretime  };;
            
(* Run one time step: execute addresses 0 .. prglen-1 in order, advance
   the clock, and record the first step at which output port 0 becomes
   positive in [scoretime]. *)
let interpret vm =
  let rec step ip =
    if ip < vm.prglen then begin
      eval_instr ip vm vm.prg.(ip);
      step (ip + 1)
    end in
  step 0;
  vm.time <- succ vm.time;
  if vm.outports.(0) > 0.0 && vm.scoretime = 0 then
    vm.scoretime <- vm.time;;
  
(* Result of compiling a suffix of the program with [comp_instr]:
   - [Cont k]: [k] runs the compiled suffix.
   - [Phiargs (r1, r2, ip, k)]: the suffix starts with [Phi (r1, r2)] at
     address [ip], not yet emitted so that it can be fused with the
     instruction before it; [k] runs everything after that [Phi]. *)
type next_step = Cont of (unit -> unit) | Phiargs of int * int * int * (unit -> unit);;   
  
(* Compile instruction [ins] at address [ip] in front of the already
   compiled remainder [nxt] of the program (compilation runs from the
   last address backwards).

   A [Phi] is not emitted on its own: it returns [Phiargs] so that a
   directly preceding [Cmpz] can be fused with it into one closure that
   tests the condition and selects the operand.  The fused closure still
   records the comparison in [vm.status], as [eval_instr] does.  A
   [Cmpz] not followed by a [Phi] just updates [vm.status].  A pending
   [Phi] preceded by anything other than [Cmpz] is emitted as a closure
   that reads [vm.status] at run time.  Both of these last two cases used
   to abort with [Failure "comp_instr"].

   The closures read and write [vm.mem], [vm.inports], [vm.outports] and
   [vm.status] when run; each calls its continuation as its last action. *)
let rec comp_instr ip vm ins nxt =
  (* body of a fused Cmpz+Phi: remember the condition, then select *)
  let select c p1 p2 pip =
    vm.status <- c;
    vm.mem.(pip) <- (if c then vm.mem.(p1) else vm.mem.(p2)) in
  match ins, nxt with
  | Add(r1,r2), Cont k -> Cont(fun () -> vm.mem.(ip) <- vm.mem.(r1) +. vm.mem.(r2); k ())
  | Sub(r1,r2), Cont k -> Cont(fun () -> vm.mem.(ip) <- vm.mem.(r1) -. vm.mem.(r2); k ())
  | Mul(r1,r2), Cont k -> Cont(fun () -> vm.mem.(ip) <- vm.mem.(r1) *. vm.mem.(r2); k ())
  | Div(r1,r2), Cont k -> Cont(fun () ->
      let v2 = vm.mem.(r2) in
      vm.mem.(ip) <- (if v2 = 0.0 then 0.0 else vm.mem.(r1) /. v2);
      k ())
  | Out(r1,r2), Cont k -> Cont(fun () -> vm.outports.(r1) <- vm.mem.(r2); k ())
  | Phi(r1,r2), Cont k -> Phiargs(r1, r2, ip, k)
  | Nop, Cont k -> Cont k
  (* fused fast path: comparison inlined per operator *)
  | Cmpz(LTZ, r1), Phiargs(p1, p2, pip, k) ->
      Cont(fun () -> select (vm.mem.(r1) < 0.0) p1 p2 pip; k ())
  | Cmpz(LEZ, r1), Phiargs(p1, p2, pip, k) ->
      Cont(fun () -> select (vm.mem.(r1) <= 0.0) p1 p2 pip; k ())
  | Cmpz(EQZ, r1), Phiargs(p1, p2, pip, k) ->
      Cont(fun () -> select (vm.mem.(r1) = 0.0) p1 p2 pip; k ())
  | Cmpz(GEZ, r1), Phiargs(p1, p2, pip, k) ->
      Cont(fun () -> select (vm.mem.(r1) >= 0.0) p1 p2 pip; k ())
  | Cmpz(GTZ, r1), Phiargs(p1, p2, pip, k) ->
      Cont(fun () -> select (vm.mem.(r1) > 0.0) p1 p2 pip; k ())
  | Cmpz(op, r1), Cont k -> Cont(fun () ->
      let v = vm.mem.(r1) in
      vm.status <- (match op with
                    | LTZ -> v < 0.0
                    | LEZ -> v <= 0.0
                    | EQZ -> v = 0.0
                    | GEZ -> v >= 0.0
                    | GTZ -> v > 0.0);
      k ())
  | Sqrt r1, Cont k -> Cont(fun () ->
      vm.mem.(ip) <- (if vm.mem.(r1) >= 0.0 then sqrt vm.mem.(r1) else 0.0); k ())
  | Copy r1, Cont k -> Cont(fun () -> vm.mem.(ip) <- vm.mem.(r1); k ())
  | Inp r1,  Cont k -> Cont(fun () -> vm.mem.(ip) <- vm.inports.(r1); k ())
  | _, Phiargs(p1, p2, pip, k) ->
      (* the pending Phi has no Cmpz right before it: emit it as a
         status-reading closure, then compile [ins] in front of it *)
      comp_instr ip vm ins (Cont(fun () ->
        vm.mem.(pip) <- (if vm.status then vm.mem.(p1) else vm.mem.(p2));
        k ()));;
  
(* Compile the whole program of [vm] into one closure that runs a single
   time step.  The closure executes addresses 0 .. prglen-1 in order,
   then increments [vm.time].  The first time output port 0 is positive,
   it records that step in [vm.scoretime]; [interpret] does the same.
   If the program's first emitted instruction is a [Phi] (there is nothing
   before it to fuse with), that [Phi] is run as a closure reading
   [vm.status].  This used to fail with
   [Failure "comp_prg: Phiargs returned"]. *)
let comp_prg vm =
  let finish () =
    vm.time <- vm.time + 1;
    if vm.outports.(0) > 0.0 && vm.scoretime = 0 then vm.scoretime <- vm.time in
  let rec loop ip nxt =
    if ip < 0 then nxt
    else loop (ip - 1) (comp_instr ip vm vm.prg.(ip) nxt) in
  match loop (vm.prglen - 1) (Cont finish) with
  | Cont k -> k
  | Phiargs (p1, p2, pip, k) ->
      (fun () ->
         vm.mem.(pip) <- (if vm.status then vm.mem.(p1) else vm.mem.(p2));
         k ());;