staload "libc/SATS/stdio.sats"
staload _ = "prelude/DATS/list.dats"
staload _ = "prelude/DATS/array.dats"
#define :: list_cons
#define nil list_nil

// Reads a dotted sequence of decimal numbers from `s`, starting at index
// `start_pos`.  Each '.' folds the number read so far into the result as
// the next base-256 digit.  Scanning stops at the end of the string or at
// the first character that is neither a decimal digit nor '.'.
// Returns the accumulated value paired with the index where scanning
// stopped.  The individual numbers are not range-checked.
fn parseIP {n:nat}{p0:nat | p0 <= n} 
  (s : string n, start_pos : size_t p0): [endp:nat | endp <= n] @(uint, size_t endp) = let
  // cur: number currently being read; acc: value folded so far; i: index
  fun scan {i : nat | i <= n} .<n-i>.
    (cur : uint, acc : uint, i : size_t i):<cloref0> [j:nat | j <= n] @(uint, size_t j) =
    if string_is_at_end(s, i) then @(acc * 256U + cur, i)
    else let
      val ch = s[i]
    in
      // '.' closes the current number and starts a new one
      if ch = '.' then scan(0U, acc * 256U + cur, i+1)
      // 48 is the character code of '0'
      else if ch >= '0' && ch <= '9' then scan(cur * 10U + uint_of_char ch - 48U, acc, i+1)
      // any other character ends the scan without being consumed
      else @(acc * 256U + cur, i)
    end
in
  scan(0U, 0U, start_pos)
end

// Two unsigned values as produced by parseIPPair: the first and the second
// address parsed from one line.  toSplitPoints treats the components as the
// start and the end of a range.
typedef IPpair = @(uint, uint)
  
// Parses up to two addresses from `s` with parseIP.
// The first address is read from index 0.  If that consumes the whole
// string the second component is 0U.  Otherwise the single character at the
// stopping position is skipped, whatever it is, and the second address is
// read from the following index.
fn parseIPPair{n:nat}(s : string n): IPpair = let
  val @(first, sep) = parseIP(s, 0)
in
  if string_is_at_end(s, sep) then @(first, 0U)
  else let
    val @(second, _) = parseIP(s, sep+1)
  in
    @(first, second)
  end
end

// Orders two pairs lexicographically: the first components decide unless
// compare_uint_uint reports them equal (0), in which case the result is
// the comparison of the second components.
fn cmp_pair(p1 : IPpair, p2 : IPpair):<0> int = let
  val by_first = compare_uint_uint(p1.0, p2.0)
in
  if by_first = 0 then compare_uint_uint(p1.1, p2.1) else by_first
end

// Reads `file` line by line with input_line and converts every line to an
// IPpair via parseIPPair, stopping as soon as input_line yields no line.
// Each new pair is put at the front of the accumulator, so the returned
// list is in reverse reading order.
fn load_pairs(file : FILEref): [n:nat] list(IPpair, n) = let
  fun collect {k:nat}(inp : FILEref, acc : list(IPpair, k)): [m : nat | m >= k] list(IPpair, m) = let
    val line = input_line inp
  in
    if stropt_is_some line then let
      val pair = parseIPPair(stropt_unsome line)
    in
      collect(inp, pair :: acc)
    end else acc
  end
in
  collect(file, nil)
end

// Flattens a list of (start, end) pairs into a list of boundary values.
// `o` carries the end of the range that is currently open (its start has
// already been emitted), or None when no range is open.
// A pair whose start is greater than the open end closes the open range
// and begins a new one; otherwise the pair is merged into the open range
// by keeping the larger end.  main sorts the pairs with cmp_pair before
// calling this function.
fun toSplitPoints {n:nat} (o : Option uint, ps : list(IPpair,n) ): [m:nat] list(uint,m) = 
  case+ ps of
  | nil () => (
      // input exhausted: emit the pending end, if any
      case+ o of
      | Some pending => pending :: nil
      | None () => nil
    )
  | @(lo, hi) :: rest => (
      case+ o of
      | None () => lo :: toSplitPoints(Some hi, rest)
      | Some pending =>
          if lo > pending then pending :: lo :: toSplitPoints(Some hi, rest)
          else toSplitPoints(Some(max(pending, hi)), rest)
    )

// Binary search in arr[a..b] (both bounds inclusive) for an index whose
// element is strictly greater than `x`.  When the elements are in
// non-decreasing order this is the smallest such index.
// Returns Some(index) on success and None when the search ends on an
// element that is not greater than `x`.
fun bsearch {n,a,b:nat | a < n; b < n; a <= b} .<b-a>.
  (arr: array(uint,n), x:uint, a:int a, b:int b): Option uint = 
  if a < b then let
    // a <= mid < b, so both recursive calls work on a smaller range
    val mid = a + (b-a)/2
  in
    if arr[mid] <= x then bsearch(arr, x, mid+1, b)
    else bsearch(arr, x, a, mid)
  end else
    // a single candidate is left
    (if arr[a] > x then Some(uint_of_int a) else None)

// Decides whether `ip` is accepted with respect to the boundary array `ps`
// of length `n`.  bsearch is run over the whole array; the address is
// accepted exactly when an index is found and that index is odd (lowest
// bit set).  When bsearch finds nothing the address is rejected.
fn ipgood {n:nat | n > 0} (ps: array(uint,n), n:int n, ip:uint): bool = let
  val hit = bsearch(ps, ip, 0, n-1)
in
  case+ hit of
  | Some idx => land_uint_uint(idx, 1U) > 0U
  | None () => false
end
  
// Size in bytes of the line buffer that filterIPs allocates and hands to
// fgets_err.
#define BUFSZ 128  
  
// Reads `file` chunk by chunk with fgets_err into a BUFSZ-byte buffer and
// prints every chunk whose leading address is accepted by ipgood.
//   file : input stream opened for reading
//   ps   : boundary array passed on to ipgood
//   n    : number of elements in `ps` (must be positive)
// The buffer is obtained from malloc_gc and released with free_gc once
// fgets_err reports failure by returning null.
fn filterIPs {n:nat | n > 0} (file: &FILE r, ps: array(uint,n), n:int n): void = let
  val (pf_gc, pf_buf | buf) = malloc_gc (BUFSZ)
  // One iteration per successful fgets_err call.  The proof `pf` of the
  // buffer's bytes view is consumed by fgets_err and must be restored on
  // every path before the function returns or recurses.
  fun loop {sz:nat | sz==BUFSZ}{l:addr}
  (pf: !b0ytes sz @ l | buf: ptr l, file: &FILE r):<cloref1> void = let
    val (pf_res | res) = fgets_err (file_mode_lte_r_r, pf | buf, BUFSZ, file)
    in
      if res <> null then let
        // success: the result proof holds a string-buffer view
        prval fgets_v_succ (pf_strbuf) = pf_res
        // unchecked cast of the returned pointer to a String so that it
        // can be parsed and printed
        val s = __cast res where { extern castfn __cast (x: ptr): String }    
        // only the address at the start of the chunk is used; the stop
        // position is ignored
        val (ip, _) = parseIP(s, 0)
      in 
        // turn the string-buffer view back into a bytes view for reuse
        pf := bytes_v_of_strbuf_v pf_strbuf;
        // the chunk is printed exactly as it was read
        if ipgood(ps, n, ip) then print s;
        loop (pf | buf, file)
      end else let 
        // failure: recover the bytes view and stop
        prval fgets_v_fail pf_bytes = pf_res in
        pf := bytes_v_of_b0ytes_v pf_bytes 
      end
    end
in
  loop (pf_buf | buf, file);
  free_gc(pf_gc, pf_buf | buf)
end
  
// Entry point.  Loads address ranges from the file named by the first
// command-line argument (default "ipranges.txt"), sorts and flattens them
// into a boundary array, then filters the file named by the second
// argument (default "ips.txt") through filterIPs.  When the range file
// yields no boundaries nothing is filtered.
implement main (argc, argv) = let
    val ranges_path = if argc > 1 then argv[1] else "ipranges.txt"
    val ips_path = if argc > 2 then argv[2] else "ips.txt"
    // read all ranges, then release the range file right away
    val ranges_in = open_file_exn (ranges_path, file_mode_r)
    val loaded = load_pairs ranges_in
    val () = close_file_exn ranges_in
    // list_mergesort takes an environment argument; it is not used here
    var unused_env:ptr = null
    val sorted = list_mergesort(loaded, lam (l,r,e)=> cmp_pair(l,r), unused_env)
    val points = toSplitPoints(None, sorted)
    val count = length points
    val bounds = array_make_lst(size1_of_int1 count, points)
    val (pf_ips | ips_in) = fopen_exn (ips_path, file_mode_r)
  in    
    // filterIPs requires a non-empty boundary array
    if count > 0 then filterIPs(!ips_in, bounds, count);
    fclose_exn (pf_ips | ips_in);
  end