-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Better abstraction for ecCircuits
- Loading branch information
Gustavo Delerue
committed
Dec 2, 2024
1 parent
39310b5
commit 5913127
Showing
3 changed files
with
106 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1086,10 +1086,10 @@ type cache = (ident, (cinput * circuit)) Map.t | |
if not: remove env argument from recursive calls *) | ||
let circuit_of_form | ||
?(pstate : pstate = Map.empty) (* Program variable values *) | ||
?(cache : cache = Map.empty) (* Let-bindings and such *) | ||
(hyps : hyps) | ||
(f_ : EcAst.form) | ||
: circuit = | ||
let cache = Map.empty in | ||
|
||
let rec doit (cache: (ident, (cinput * circuit)) Map.t) (hyps: hyps) (f_: form) : hyps * circuit = | ||
let env = toenv hyps in | ||
|
@@ -1497,18 +1497,18 @@ let pstate_of_memtype ?pstate (env: env) (mt : memtype) = | |
) (Option.get lmt).lmt_decl in | ||
pstate_of_variables ?pstate env vars | ||
|
||
let process_instr (hyps: hyps) (mem: memory) ?(cache: cache = Map.empty) (pstate: _) (inst: instr) = | ||
let process_instr (hyps: hyps) (mem: memory) (pstate: _) (inst: instr) = | ||
let env = toenv hyps in | ||
(* Format.eprintf "[W]Processing : %a@." (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst; *) | ||
(* let start = Unix.gettimeofday () in *) | ||
try | ||
match inst.i_node with | ||
| Sasgn (LvVar (PVloc v, _ty), e) -> | ||
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) pstate in | ||
let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate hyps) pstate in | ||
(* Format.eprintf "[W] Took %f seconds@." (Unix.gettimeofday() -. start); *) | ||
pstate | ||
| Sasgn (LvTuple (vs), e) -> | ||
let tp = (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) in | ||
let tp = (form_of_expr mem e |> circuit_of_form ~pstate hyps) in | ||
assert (is_bwtuple tp.circ); | ||
let comps = circuits_of_circuit tp in | ||
let pstate = List.fold_left2 (fun pstate (pv, _ty) c -> | ||
|
@@ -1590,3 +1590,55 @@ let instrs_equiv | |
let circ2 = { circ2 with inps = inputs @ circ2.inps } in | ||
circ_equiv circ1 circ2 None | ||
) | ||
|
||
let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t = | ||
let pstate : (symbol, circuit) Map.t = Map.empty in | ||
|
||
let inps = List.map (input_of_variable env) invs in | ||
let inpcs, inps = List.split inps in | ||
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *) | ||
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in | ||
|
||
inps, List.fold_left | ||
(fun pstate (inp, v) -> Map.add v inp pstate) | ||
pstate inpcs | ||
|
||
(* Generates pstate : (symbol, circuit) Map from program | ||
and inputs associated to the program | ||
Throws: CircError on failure | ||
*) | ||
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: instr list) (invs: variable list) : (symbol, circuit) Map.t = | ||
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in | ||
|
||
let pstate = | ||
List.fold_left (process_instr hyps mem) pstate proc | ||
in | ||
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate | ||
|
||
(* FIXME: refactor this function *) | ||
let rec circ_simplify_form_bitstring_equality | ||
?(mem = mhr) | ||
?(pstate: (symbol, circuit) Map.t = Map.empty) | ||
?(pcond: circuit option) | ||
(hyps: hyps) | ||
(f: form) | ||
: form = | ||
let env = toenv hyps in | ||
|
||
let rec check (f : form) = | ||
match EcFol.sform_of_form f with | ||
| SFeq (f1, f2) | ||
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty) | ||
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty) | ||
-> | ||
let c1 = circuit_of_form ~pstate hyps f1 in | ||
let c2 = circuit_of_form ~pstate hyps f2 in | ||
Format.eprintf "[W]Testing circuit equivalence for forms: | ||
%a@.%[email protected] circuits: %s | %s@." | ||
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1 | ||
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2 | ||
(circuit_to_string c1) | ||
(circuit_to_string c2); | ||
f_bool (circ_equiv c1 c2 pcond) | ||
| _ -> f_map (fun ty -> ty) check f | ||
in check f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,32 +57,6 @@ let circ_of_qsymbol (hyps: hyps) (qs: qsymbol) : circuit = | |
fc | ||
with CircError err -> | ||
raise (BDepError err) | ||
|
||
|
||
let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t = | ||
let pstate : (symbol, circuit) Map.t = Map.empty in | ||
|
||
let inps = List.map (EcCircuits.input_of_variable env) invs in | ||
let inpcs, inps = List.split inps in | ||
(* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *) | ||
let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in | ||
|
||
inps, List.fold_left | ||
(fun pstate (inp, v) -> Map.add v inp pstate) | ||
pstate inpcs | ||
|
||
(* Generates pstate : (symbol, circuit) Map from program | ||
Throws: BDepError on failure | ||
*) | ||
let pstate_of_prog (hyps: hyps) (mem: memory) (proc: stmt) (invs: variable list) : (symbol, circuit) Map.t = | ||
let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in | ||
|
||
let pstate = try | ||
List.fold_left (EcCircuits.process_instr hyps mem) pstate proc.s_node | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate | ||
|
||
|
||
(* -------------------------------------------------------------------- *) | ||
|
@@ -117,7 +91,11 @@ let mapreduce | |
|
||
let tm = time tm "Precondition circuit generation done" in | ||
|
||
let pstate = pstate_of_prog hyps mem proc invs in | ||
let pstate = try | ||
EcCircuits.pstate_of_prog hyps mem proc.s_node invs | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
|
||
let tm = time tm "Program circuit generation done" in | ||
|
||
|
@@ -126,7 +104,7 @@ let mapreduce | |
(List.map (fun v -> v.v_name) outvs) in | ||
|
||
(* This is required for now as we do not allow mapreduce with multiple arguments *) | ||
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); | ||
(* assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); *) | ||
|
||
let c = try | ||
(circuit_aggregate circs) | ||
|
@@ -178,9 +156,17 @@ let prog_equiv_prod | |
in | ||
let tm = Unix.gettimeofday () in | ||
|
||
let pstate_l : (symbol, circuit) Map.t = pstate_of_prog hyps meml proc_l invs_l in | ||
let pstate_l : (symbol, circuit) Map.t = try | ||
EcCircuits.pstate_of_prog hyps meml proc_l.s_node invs_l | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
let tm = time tm "Left program generation done" in | ||
let pstate_r : (symbol, circuit) Map.t = pstate_of_prog hyps memr proc_r invs_l in | ||
let pstate_r : (symbol, circuit) Map.t = try | ||
EcCircuits.pstate_of_prog hyps memr proc_r.s_node invs_l | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
let tm = time tm "Right program generation done" in | ||
|
||
begin | ||
|
@@ -189,14 +175,8 @@ let prog_equiv_prod | |
let circs_r = List.map (fun v -> Option.get (Map.find_opt v pstate_r)) | ||
(List.map (fun v -> v.v_name) outvs_r) in | ||
|
||
(* let () = List.iter2 (fun c v -> Format.eprintf "%s inputs: " v.v_name; *) | ||
(* List.iter (Format.eprintf "%s ") (List.map cinput_to_string c.inps); *) | ||
(* Format.eprintf "@."; ) circs outvs in *) | ||
|
||
(* let () = List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) circs in *) | ||
(* Only one input supported for now *) | ||
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); | ||
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1); | ||
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); *) | ||
(*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);*) | ||
let c_l = try | ||
(circuit_aggregate circs_l) | ||
with CircError _err -> | ||
|
@@ -263,37 +243,6 @@ let prog_equiv_prod | |
if both sides are equivalent as circuits | ||
or false otherwise | ||
*) | ||
let rec circ_simplify_form_bitstring_equality | ||
?(mem = mhr) | ||
?(pstate: (symbol, circuit) Map.t = Map.empty) | ||
?(pcond: circuit option) | ||
?(inps: cinput list option) | ||
(hyps: hyps) | ||
(f: form) | ||
: form = | ||
let env = toenv hyps in | ||
|
||
let rec check (f : form) = | ||
match sform_of_form f with | ||
| SFeq (f1, f2) | ||
when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty) | ||
|| (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty) | ||
-> | ||
let c1 = circuit_of_form ~pstate hyps f1 in | ||
let c2 = circuit_of_form ~pstate hyps f2 in | ||
let c1, c2 = match inps with | ||
| Some inps -> {c1 with inps = inps}, {c2 with inps = inps} | ||
| None -> c1, c2 | ||
in | ||
Format.eprintf "[W]Testing circuit equivalence for forms: | ||
%a@.%[email protected] circuits: %s | %s@." | ||
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1 | ||
(EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2 | ||
(circuit_to_string c1) | ||
(circuit_to_string c2); | ||
f_bool (circ_equiv c1 c2 pcond) | ||
| _ -> f_map (fun ty -> ty) check f | ||
in check f | ||
|
||
let circ_form_eval_plus_equiv | ||
?(mem = mhr) | ||
|
@@ -307,8 +256,8 @@ let circ_form_eval_plus_equiv | |
let env = toenv hyps in | ||
let redmode = circ_red hyps in | ||
let (@@!) = EcTypesafeFol.f_app_safe env in | ||
let inps = List.map (EcCircuits.input_of_variable env) invs in | ||
let inpcs, inps = List.split inps in | ||
(*let inps = List.map (EcCircuits.input_of_variable env) invs in*) | ||
(*let inpcs, inps = List.split inps in*) | ||
let size, of_int = match EcEnv.Circuit.lookup_bitstring env v.v_type with | ||
| Some {size; ofint} -> size, ofint | ||
| None -> | ||
|
@@ -322,11 +271,6 @@ let circ_form_eval_plus_equiv | |
true | ||
else | ||
let cur_val = of_int @@! [f_int cur] in | ||
let pstate : (symbol, circuit) Map.t = Map.empty in | ||
let pstate = List.fold_left2 | ||
(fun pstate inp v -> Map.add v inp pstate) | ||
pstate inpcs (invs |> List.map (fun v -> v.v_name)) | ||
in | ||
let insts = List.map (fun i -> | ||
match i.i_node with | ||
| Sasgn (lv, e) -> | ||
|
@@ -338,12 +282,12 @@ let circ_form_eval_plus_equiv | |
| _ -> i | ||
) proc.s_node | ||
in | ||
let pstate = try | ||
List.fold_left (EcCircuits.process_instr hyps mem) pstate insts | ||
with CircError err -> | ||
raise (BDepError ("Program circuit generation failed with error:\n" ^ err)) | ||
let pstate = try | ||
EcCircuits.pstate_of_prog hyps mem insts invs | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
let pstate = Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate in | ||
|
||
let f = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_val f in | ||
let pcond = match Map.find_opt v.v_name pstate with | ||
| Some circ -> begin try | ||
|
@@ -353,10 +297,10 @@ let circ_form_eval_plus_equiv | |
end | ||
| None -> None | ||
in | ||
let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in | ||
(*let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*) | ||
let f = EcCallbyValue.norm_cbv redmode hyps f in | ||
let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in | ||
let f = circ_simplify_form_bitstring_equality ~mem ~pstate ~inps ?pcond hyps f in | ||
(*let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*) | ||
let f = EcCircuits.circ_simplify_form_bitstring_equality ~mem ~pstate ?pcond hyps f in | ||
let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in | ||
if f <> f_true then | ||
(Format.eprintf "Got %a after reduction@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f; | ||
|
@@ -387,17 +331,19 @@ let mapreduce_eval | |
|
||
let tm = time tm "Lane function circuit generation done" in | ||
|
||
let pstate = pstate_of_prog hyps mem proc invs in | ||
let pstate = try | ||
EcCircuits.pstate_of_prog hyps mem proc.s_node invs | ||
with CircError err -> | ||
raise (BDepError err) | ||
in | ||
|
||
let tm = time tm "Program circuit generation done" in | ||
|
||
begin | ||
let circs = List.map (fun v -> Option.get (Map.find_opt v pstate)) (List.map (fun v -> v.v_name) outvs) in | ||
|
||
assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); | ||
let cinp = (List.hd circs).inps in | ||
let c = try | ||
{(circuit_aggregate circs) with inps=cinp} | ||
(circuit_aggregate circs) | ||
with CircError _err -> | ||
raise (BDepError "Failed to concatenate program outputs") | ||
in | ||
|
@@ -410,8 +356,6 @@ let mapreduce_eval | |
|
||
let tm = time tm "circuit dependecy analysis + splitting done" in | ||
|
||
List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) cs; | ||
|
||
List.iteri (fun i c -> | ||
if circ_equiv ~strict:true (List.hd cs) c None | ||
then () | ||
|