Skip to content

Commit

Permalink
WIP: Better abstraction for ecCircuits
Browse files Browse the repository at this point in the history
  • Loading branch information
Gustavo Delerue committed Dec 2, 2024
1 parent 39310b5 commit 5913127
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 106 deletions.
60 changes: 56 additions & 4 deletions src/ecCircuits.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1086,10 +1086,10 @@ type cache = (ident, (cinput * circuit)) Map.t
if not: remove env argument from recursive calls *)
let circuit_of_form
?(pstate : pstate = Map.empty) (* Program variable values *)
?(cache : cache = Map.empty) (* Let-bindings and such *)
(hyps : hyps)
(f_ : EcAst.form)
: circuit =
let cache = Map.empty in

let rec doit (cache: (ident, (cinput * circuit)) Map.t) (hyps: hyps) (f_: form) : hyps * circuit =
let env = toenv hyps in
Expand Down Expand Up @@ -1497,18 +1497,18 @@ let pstate_of_memtype ?pstate (env: env) (mt : memtype) =
) (Option.get lmt).lmt_decl in
pstate_of_variables ?pstate env vars

let process_instr (hyps: hyps) (mem: memory) ?(cache: cache = Map.empty) (pstate: _) (inst: instr) =
let process_instr (hyps: hyps) (mem: memory) (pstate: _) (inst: instr) =
let env = toenv hyps in
(* Format.eprintf "[W]Processing : %a@." (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst; *)
(* let start = Unix.gettimeofday () in *)
try
match inst.i_node with
| Sasgn (LvVar (PVloc v, _ty), e) ->
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) pstate in
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate hyps) pstate in
(* Format.eprintf "[W] Took %f seconds@." (Unix.gettimeofday() -. start); *)
pstate
| Sasgn (LvTuple (vs), e) ->
let tp = (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) in
let tp = (form_of_expr mem e |> circuit_of_form ~pstate hyps) in
assert (is_bwtuple tp.circ);
let comps = circuits_of_circuit tp in
let pstate = List.fold_left2 (fun pstate (pv, _ty) c ->
Expand Down Expand Up @@ -1590,3 +1590,55 @@ let instrs_equiv
let circ2 = { circ2 with inps = inputs @ circ2.inps } in
circ_equiv circ1 circ2 None
)

let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t =
let pstate : (symbol, circuit) Map.t = Map.empty in

let inps = List.map (input_of_variable env) invs in
let inpcs, inps = List.split inps in
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *)
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in

inps, List.fold_left
(fun pstate (inp, v) -> Map.add v inp pstate)
pstate inpcs

(* Generates pstate : (symbol, circuit) Map from program
and inputs associated to the program
Throws: CircError on failure
*)
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: instr list) (invs: variable list) : (symbol, circuit) Map.t =
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in

let pstate =
List.fold_left (process_instr hyps mem) pstate proc
in
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate

(* FIXME: refactor this function *)
let rec circ_simplify_form_bitstring_equality
?(mem = mhr)
?(pstate: (symbol, circuit) Map.t = Map.empty)
?(pcond: circuit option)
(hyps: hyps)
(f: form)
: form =
let env = toenv hyps in

let rec check (f : form) =
match EcFol.sform_of_form f with
| SFeq (f1, f2)
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty)
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty)
->
let c1 = circuit_of_form ~pstate hyps f1 in
let c2 = circuit_of_form ~pstate hyps f2 in
Format.eprintf "[W]Testing circuit equivalence for forms:
%a@.%[email protected] circuits: %s | %s@."
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2
(circuit_to_string c1)
(circuit_to_string c2);
f_bool (circ_equiv c1 c2 pcond)
| _ -> f_map (fun ty -> ty) check f
in check f
28 changes: 16 additions & 12 deletions src/ecCircuits.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,38 @@ open LDecl
module Map = Batteries.Map

(* -------------------------------------------------------------------- *)
type circ
type cinput
type circuit = { circ: circ; inps: cinput list; }
type circuit
type pstate = (symbol, circuit) Map.t
type cache = (EcIdent.t, (cinput * circuit)) Map.t
(*type cache = (EcIdent.t, (cinput * circuit)) Map.t*)

(* -------------------------------------------------------------------- *)
exception CircError of string

(* -------------------------------------------------------------------- *)
val get_specification_by_name : string -> Lospecs.Ast.adef option
val circ_red : hyps -> EcReduction.reduction_info
val cinput_to_string : cinput -> string
val cinput_of_type : ?idn:ident -> env -> ty -> cinput
(*val cinput_to_string : cinput -> string*)
(*val cinput_of_type : ?idn:ident -> env -> ty -> cinput*)
val width_of_type : env -> ty -> int
val size_of_circ : circ -> int
(*val size_of_circ : circ -> int *)
val compute : sign:bool -> circuit -> BI.zint list -> BI.zint
val circuit_to_string : circuit -> string
val circ_ident : cinput -> circuit
(*val circ_ident : cinput -> circuit*)
val circuit_ueq : circuit -> circuit -> circuit
val circuit_aggregate : circuit list -> circuit
val circuit_aggregate_inps : circuit -> circuit
val circuit_flatten : circuit -> circuit
val circuit_permutation : int -> int -> (int -> int) -> circuit
val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list
val circ_equiv : ?strict:bool -> circuit -> circuit -> circuit option -> bool
val circuit_of_form : ?pstate:pstate -> ?cache:cache -> hyps -> form -> circuit
val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list
val input_of_variable : env -> variable -> circuit * cinput
val circuit_of_form : ?pstate:pstate -> hyps -> form -> circuit
(*val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list*)
val pstate_of_prog : hyps -> memory -> instr list -> variable list -> (symbol, circuit) Map.t
(*val input_of_variable : env -> variable -> circuit * cinput*)
val instrs_equiv : hyps -> memenv -> ?keep:EcPV.PV.t -> ?pstate:pstate -> instr list -> instr list -> bool
val process_instr : hyps -> memory -> ?cache:cache -> pstate -> instr -> (symbol, circuit) Map.t
val process_instr : hyps -> memory -> pstate -> instr -> (symbol, circuit) Map.t
val circ_simplify_form_bitstring_equality :
?mem:EcMemory.memory ->
?pstate:(string, circuit) Map.t ->
?pcond:circuit -> hyps -> form -> form

124 changes: 34 additions & 90 deletions src/phl/ecPhlBDep.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,32 +57,6 @@ let circ_of_qsymbol (hyps: hyps) (qs: qsymbol) : circuit =
fc
with CircError err ->
raise (BDepError err)


let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t =
let pstate : (symbol, circuit) Map.t = Map.empty in

let inps = List.map (EcCircuits.input_of_variable env) invs in
let inpcs, inps = List.split inps in
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *)
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in

inps, List.fold_left
(fun pstate (inp, v) -> Map.add v inp pstate)
pstate inpcs

(* Generates pstate : (symbol, circuit) Map from program
Throws: BDepError on failure
*)
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: stmt) (invs: variable list) : (symbol, circuit) Map.t =
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in

let pstate = try
List.fold_left (EcCircuits.process_instr hyps mem) pstate proc.s_node
with CircError err ->
raise (BDepError err)
in
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate


(* -------------------------------------------------------------------- *)
Expand Down Expand Up @@ -117,7 +91,11 @@ let mapreduce

let tm = time tm "Precondition circuit generation done" in

let pstate = pstate_of_prog hyps mem proc invs in
let pstate = try
EcCircuits.pstate_of_prog hyps mem proc.s_node invs
with CircError err ->
raise (BDepError err)
in

let tm = time tm "Program circuit generation done" in

Expand All @@ -126,7 +104,7 @@ let mapreduce
(List.map (fun v -> v.v_name) outvs) in

(* This is required for now as we do not allow mapreduce with multiple arguments *)
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1);
(* assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); *)

let c = try
(circuit_aggregate circs)
Expand Down Expand Up @@ -178,9 +156,17 @@ let prog_equiv_prod
in
let tm = Unix.gettimeofday () in

let pstate_l : (symbol, circuit) Map.t = pstate_of_prog hyps meml proc_l invs_l in
let pstate_l : (symbol, circuit) Map.t = try
EcCircuits.pstate_of_prog hyps meml proc_l.s_node invs_l
with CircError err ->
raise (BDepError err)
in
let tm = time tm "Left program generation done" in
let pstate_r : (symbol, circuit) Map.t = pstate_of_prog hyps memr proc_r invs_l in
let pstate_r : (symbol, circuit) Map.t = try
EcCircuits.pstate_of_prog hyps memr proc_r.s_node invs_l
with CircError err ->
raise (BDepError err)
in
let tm = time tm "Right program generation done" in

begin
Expand All @@ -189,14 +175,8 @@ let prog_equiv_prod
let circs_r = List.map (fun v -> Option.get (Map.find_opt v pstate_r))
(List.map (fun v -> v.v_name) outvs_r) in

(* let () = List.iter2 (fun c v -> Format.eprintf "%s inputs: " v.v_name; *)
(* List.iter (Format.eprintf "%s ") (List.map cinput_to_string c.inps); *)
(* Format.eprintf "@."; ) circs outvs in *)

(* let () = List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) circs in *)
(* Only one input supported for now *)
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1);
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); *)
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);*)
let c_l = try
(circuit_aggregate circs_l)
with CircError _err ->
Expand Down Expand Up @@ -263,37 +243,6 @@ let prog_equiv_prod
if both sides are equivalent as circuits
or false otherwise
*)
let rec circ_simplify_form_bitstring_equality
?(mem = mhr)
?(pstate: (symbol, circuit) Map.t = Map.empty)
?(pcond: circuit option)
?(inps: cinput list option)
(hyps: hyps)
(f: form)
: form =
let env = toenv hyps in

let rec check (f : form) =
match sform_of_form f with
| SFeq (f1, f2)
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty)
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty)
->
let c1 = circuit_of_form ~pstate hyps f1 in
let c2 = circuit_of_form ~pstate hyps f2 in
let c1, c2 = match inps with
| Some inps -> {c1 with inps = inps}, {c2 with inps = inps}
| None -> c1, c2
in
Format.eprintf "[W]Testing circuit equivalence for forms:
%a@.%[email protected] circuits: %s | %s@."
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2
(circuit_to_string c1)
(circuit_to_string c2);
f_bool (circ_equiv c1 c2 pcond)
| _ -> f_map (fun ty -> ty) check f
in check f

let circ_form_eval_plus_equiv
?(mem = mhr)
Expand All @@ -307,8 +256,8 @@ let circ_form_eval_plus_equiv
let env = toenv hyps in
let redmode = circ_red hyps in
let (@@!) = EcTypesafeFol.f_app_safe env in
let inps = List.map (EcCircuits.input_of_variable env) invs in
let inpcs, inps = List.split inps in
(*let inps = List.map (EcCircuits.input_of_variable env) invs in*)
(*let inpcs, inps = List.split inps in*)
let size, of_int = match EcEnv.Circuit.lookup_bitstring env v.v_type with
| Some {size; ofint} -> size, ofint
| None ->
Expand All @@ -322,11 +271,6 @@ let circ_form_eval_plus_equiv
true
else
let cur_val = of_int @@! [f_int cur] in
let pstate : (symbol, circuit) Map.t = Map.empty in
let pstate = List.fold_left2
(fun pstate inp v -> Map.add v inp pstate)
pstate inpcs (invs |> List.map (fun v -> v.v_name))
in
let insts = List.map (fun i ->
match i.i_node with
| Sasgn (lv, e) ->
Expand All @@ -338,12 +282,12 @@ let circ_form_eval_plus_equiv
| _ -> i
) proc.s_node
in
let pstate = try
List.fold_left (EcCircuits.process_instr hyps mem) pstate insts
with CircError err ->
raise (BDepError ("Program circuit generation failed with error:\n" ^ err))
let pstate = try
EcCircuits.pstate_of_prog hyps mem insts invs
with CircError err ->
raise (BDepError err)
in
let pstate = Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate in

let f = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_val f in
let pcond = match Map.find_opt v.v_name pstate with
| Some circ -> begin try
Expand All @@ -353,10 +297,10 @@ let circ_form_eval_plus_equiv
end
| None -> None
in
let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in
(*let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*)
let f = EcCallbyValue.norm_cbv redmode hyps f in
let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in
let f = circ_simplify_form_bitstring_equality ~mem ~pstate ~inps ?pcond hyps f in
(*let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*)
let f = EcCircuits.circ_simplify_form_bitstring_equality ~mem ~pstate ?pcond hyps f in
let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in
if f <> f_true then
(Format.eprintf "Got %a after reduction@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f;
Expand Down Expand Up @@ -387,17 +331,19 @@ let mapreduce_eval

let tm = time tm "Lane function circuit generation done" in

let pstate = pstate_of_prog hyps mem proc invs in
let pstate = try
EcCircuits.pstate_of_prog hyps mem proc.s_node invs
with CircError err ->
raise (BDepError err)
in

let tm = time tm "Program circuit generation done" in

begin
let circs = List.map (fun v -> Option.get (Map.find_opt v pstate)) (List.map (fun v -> v.v_name) outvs) in

assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1);
let cinp = (List.hd circs).inps in
let c = try
{(circuit_aggregate circs) with inps=cinp}
(circuit_aggregate circs)
with CircError _err ->
raise (BDepError "Failed to concatenate program outputs")
in
Expand All @@ -410,8 +356,6 @@ let mapreduce_eval

let tm = time tm "circuit dependecy analysis + splitting done" in

List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) cs;

List.iteri (fun i c ->
if circ_equiv ~strict:true (List.hd cs) c None
then ()
Expand Down

0 comments on commit 5913127

Please sign in to comment.