From 2fbc29e3dba410238b3c142a6e91b381627d91fa Mon Sep 17 00:00:00 2001 From: Guillaume Claret Date: Mon, 2 Dec 2024 20:34:04 +0100 Subject: [PATCH 1/2] proof: more work on base64 --- CoqOfNoir/CoqOfNoir.v | 9 +- CoqOfNoir/base64/monomorphic.v | 40 ++-- CoqOfNoir/base64/polymorphic.v | 22 +- CoqOfNoir/base64/simulation.v | 378 +++++++++++++++++++++++-------- CoqOfNoir/proof/CoqOfNoir.v | 167 +++++++++----- CoqOfNoir/simulation/CoqOfNoir.v | 270 +++++++++++++++------- 6 files changed, 617 insertions(+), 269 deletions(-) diff --git a/CoqOfNoir/CoqOfNoir.v b/CoqOfNoir/CoqOfNoir.v index 9a147e3abbc..e6f96e3fd7b 100644 --- a/CoqOfNoir/CoqOfNoir.v +++ b/CoqOfNoir/CoqOfNoir.v @@ -3,6 +3,7 @@ Require Export Coq.Strings.Ascii. Require Coq.Strings.HexString. Require Export Coq.Strings.String. Require Export Coq.ZArith.ZArith. +Require Coq.micromega.ZifyBool. Require coqutil.Datatypes.List. Require Export RecordUpdate. @@ -122,7 +123,7 @@ Module Value. match value with | Tuple values => match List.listUpdate_error values (Z.to_nat i) update with - | Some new_values => Some (Tuple values) + | Some new_values => Some (Tuple new_values) | None => None end | _ => None @@ -131,7 +132,7 @@ Module Value. match value with | Array values => match List.listUpdate_error values (Z.to_nat i) update with - | Some new_values => Some (Array values) + | Some new_values => Some (Array new_values) | None => None end | _ => None @@ -454,8 +455,6 @@ Module M. | _ => impossible "index: expected a pointer" end. - Parameter assign : Value.t -> Value.t -> M.t. - Definition extract_tuple_field (tuple : Value.t) (field : Z) : M.t := match tuple with | Value.Pointer tuple_pointer => @@ -487,7 +486,7 @@ Module M. Fixpoint for_nat (end_ : Z) (fuel : nat) (body : Z -> M.t) {struct fuel} : M.t := match fuel with - | O => pure (Value.Tuple []) + | O => pure (alloc (Value.Tuple [])) | S fuel' => let* _ := body (end_ - Z.of_nat fuel) in for_nat end_ fuel' body diff --git a/CoqOfNoir/base64/monomorphic.v b/CoqOfNoir/base64/monomorphic.v index 8fed1e096bf..8721d3b4229 100644 --- a/CoqOfNoir/base64/monomorphic.v +++ b/CoqOfNoir/base64/monomorphic.v @@ -295,7 +295,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := M.read (| BYTES_PER_CHUNK |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -304,7 +304,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -341,7 +341,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := M.read (| BASE64_ELEMENTS_PER_CHUNK |), fun (j : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -383,7 +383,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := M.read (| bytes_in_final_chunk |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -392,7 +392,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -423,7 +423,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := M.read (| BYTES_PER_CHUNK |), fun (_ : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -460,7 +460,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := M.read (| num_elements_in_final_chunk |), fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -483,7 +483,7 @@ Definition base64_encode₁ (α : list Value.t) : M.t := |) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| result, M.call_closure (| get_function "base64_encode_elements" 6, @@ -537,7 +537,7 @@ Definition eq₂ (α : list Value.t) : M.t := |), fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| result, Binary.and_ (| M.read (| result |), @@ -760,7 +760,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := M.read (| BASE64_ELEMENTS_PER_CHUNK |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -769,7 +769,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -805,7 +805,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := M.read (| BYTES_PER_CHUNK |), fun (j : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -847,7 +847,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := M.read (| base64_elements_in_final_chunk |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -856,7 +856,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -887,7 +887,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := M.read (| BASE64_ELEMENTS_PER_CHUNK |), fun (_ : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -923,7 +923,7 @@ Definition base64_decode₃ (α : list Value.t) : M.t := M.read (| num_bytes_in_final_chunk |), fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -989,7 +989,7 @@ Definition eq₄ (α : list Value.t) : M.t := |), fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| result, Binary.and_ (| M.read (| result |), @@ -1210,7 +1210,7 @@ Definition base64_encode_elements₆ (α : list Value.t) : M.t := Value.Integer IntegerKind.U32 118, fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, M.read (| i |) @@ -1427,7 +1427,7 @@ Definition base64_decode_elements₈ (α : list Value.t) : M.t := |) |) ]] in do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, M.read (| i |) @@ -1621,7 +1621,7 @@ Definition to_be_bytes₉ (α : list Value.t) : M.t := |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| ok, Value.Bool true |)) diff --git a/CoqOfNoir/base64/polymorphic.v b/CoqOfNoir/base64/polymorphic.v index b0029acc6ae..62f84623df8 100644 --- a/CoqOfNoir/base64/polymorphic.v +++ b/CoqOfNoir/base64/polymorphic.v @@ -231,7 +231,7 @@ Definition base64_encode_elements (InputElements : U32.t) (α : list Value.t) : |) ]] in let~ result := [[ M.copy_mutable (| M.alloc (Value.Array ( - List.repeat (Value.Integer IntegerKind.U8 0) (Integer.to_nat InputElements) + List.repeat (Value.Integer IntegerKind.U8 0) (Z.to_nat (Integer.to_Z InputElements)) )) |) ]] in do~ [[ @@ -240,7 +240,7 @@ Definition base64_encode_elements (InputElements : U32.t) (α : list Value.t) : to_value InputElements, fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, M.read (| i |) @@ -288,7 +288,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) let~ result := [[ M.copy_mutable (| M.alloc (Value.Array (List.repeat (Value.Integer IntegerKind.U8 0) - (Integer.to_nat OutputElements) + (Z.to_nat (Integer.to_Z OutputElements)) )) |) ]] in let~ BASE64_ELEMENTS_PER_CHUNK := [[ M.copy (| @@ -338,7 +338,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) M.read (| BYTES_PER_CHUNK |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -347,7 +347,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -384,7 +384,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) M.read (| BASE64_ELEMENTS_PER_CHUNK |), fun (j : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -426,7 +426,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) M.read (| bytes_in_final_chunk |), fun (j : Value.t) => do~ [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -435,7 +435,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) |)) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.add (| M.read (| slice |), @@ -466,7 +466,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) M.read (| BYTES_PER_CHUNK |), fun (_ : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| slice, Binary.multiply (| M.read (| slice |), @@ -503,7 +503,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) M.read (| num_elements_in_final_chunk |), fun (i : Value.t) => [[ - M.alloc (M.assign (| + M.alloc (M.write (| M.index (| result, Binary.add (| @@ -526,7 +526,7 @@ Definition base64_encode (InputBytes OutputElements : U32.t) (α : list Value.t) |) ]] in [[ - M.alloc (M.assign (| + M.alloc (M.write (| result, M.call_closure (| closure (base64_encode_elements (U32.Build_t 118)), diff --git a/CoqOfNoir/base64/simulation.v b/CoqOfNoir/base64/simulation.v index ed8e982d4dc..e1565fc044f 100644 --- a/CoqOfNoir/base64/simulation.v +++ b/CoqOfNoir/base64/simulation.v @@ -20,6 +20,11 @@ Module Base64EncodeBE. Value.Tuple [to_value x.(table)]; }. + Lemma rewrite_to_value (x : t) : + Value.Tuple [to_value x.(table)] = to_value x. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. + Definition ascii_codes : list U8.t := List.map U8.Build_t [ 65; 66; 67; 68; 69; 70; 71; 72; 73; 74; 75; 76; 77; 78; 79; 80; 81; 82; 83; 84; 85; 86; 87; 88; 89; 90; 97; 98; 99; 100; 101; 102; 103; 104; 105; 106; 107; 108; 109; 110; 111; 112; 113; 114; 115; 116; 117; 118; 119; 120; 121; 122; @@ -52,10 +57,10 @@ Module Base64EncodeBE. Lemma run_new {State Address : Set} `{State.Trait State Address} (p : Z) (state : State) : - {{ p, state | - polymorphic.Base64EncodeBE.new [] ⇓ + {{ p, state ⏩ + polymorphic.Base64EncodeBE.new [] 🔽 Result.Ok (to_value new) - | state }}. + ⏩ state }}. Proof. unfold polymorphic.Base64EncodeBE.new, new. eapply Run.Let. { @@ -63,6 +68,7 @@ Module Base64EncodeBE. } apply Run.Pure. Qed. + Global Opaque new. (* fn get(self, idx: Field) -> u8 { @@ -74,26 +80,17 @@ Module Base64EncodeBE. Lemma run_get {State Address : Set} `{State.Trait State Address} (p : Z) (state : State) - (self : t) (idx : Field.t) (result : U8.t) - (H_result : get self idx = return! result) : - {{ p, state | - polymorphic.Base64EncodeBE.get [to_value self; to_value idx] ⇓ - Result.Ok (to_value result) - | state }}. + (self : t) (idx : Field.t) : + {{ p, state ⏩ + polymorphic.Base64EncodeBE.get [to_value self; to_value idx] 🔽 + Panic.to_result (get self idx) + ⏩ state }}. Proof. - unfold polymorphic.Base64EncodeBE.get. - unfold get, Array.read in H_result. - (* destruct self as [ [table] ], idx as [idx]. *) - cbn in *. + unfold polymorphic.Base64EncodeBE.get, get, Array.read; cbn. rewrite List.nth_error_map. - destruct List.nth_error; cbn. - { inversion_clear H_result. - apply Run.Pure. - } - { exfalso. - discriminate. - } + destruct List.nth_error; cbn; apply Run.Pure. Qed. + Global Opaque get. (* (** How accessing the table of characters is used in practice *) @@ -124,10 +121,10 @@ Module Base64EncodeBE. (p : Z) (state : State) (idx : Z) (H_idx : 0 <= idx < 64) : - {{ p, state | - polymorphic.Base64EncodeBE.get [to_value new; to_value (Field.Build_t idx)] ⇓ + {{ p, state ⏩ + polymorphic.Base64EncodeBE.get [to_value new; to_value (Field.Build_t idx)] 🔽 Result.Ok (to_value (U8.Build_t (get_ascii_table idx))) - | state }}. + ⏩ state }}. Proof. apply run_get. now rewrite get_ascii_table_eq. @@ -135,49 +132,15 @@ Module Base64EncodeBE. *) End Base64EncodeBE. -(* -/** - * @brief Take an array of ASCII values and convert into base64 values - **/ -pub fn base64_encode_elements(input: [u8; InputElements]) -> [u8; InputElements] { - // for some reason, if the lookup table is not defined in a struct, access costs are expensive and ROM tables aren't being used :/ - let mut Base64Encoder = Base64EncodeBE::new(); - - let mut result: [u8; InputElements] = [0; InputElements]; - - for i in 0..InputElements { - result[i] = Base64Encoder.get(input[i] as Field); - } - result -} -*) -Definition base64_encode_elements {InputElements : U32.t} (input : Array.t U8.t InputElements) : - M! (Array.t U8.t InputElements) := - let Base64Encoder := Base64EncodeBE.new in - - let result : Array.t U8.t InputElements := Array.repeat InputElements (U8.Build_t 0) in - - List.fold_left - (fun (result : M! (Array.t U8.t InputElements)) (i : nat) => - let! result := result in - let i : U32.t := U32.Build_t (Z.of_nat i) in - let! input_i := Array.read input i in - let! new_result_i := Base64EncodeBE.get Base64Encoder (Field.Build_t (Integer.to_Z input_i)) in - Array.write result i new_result_i - ) - (List.seq 0 (Z.to_nat (SemiInteger.to_Z InputElements))) - (return! result). - Module base64_encode_elements. Module State. Record t : Set := { - base64_encoder : option Value.t; + Base64Encoder : option Value.t; result : option Value.t; }. - Arguments t : clear implicits. Definition init : t := {| - base64_encoder := None; + Base64Encoder := None; result := None; |}. End State. @@ -185,73 +148,307 @@ Module base64_encode_elements. Module Address. Inductive t : Set := | Base64Encoder - | Result. + | result. End Address. Global Instance Impl_State : State.Trait State.t Address.t := { - read a s := - match a with - | Address.Base64Encoder => s.(State.base64_encoder) - | Address.Result => s.(State.result) + read state address := + match address with + | Address.Base64Encoder => state.(State.Base64Encoder) + | Address.result => state.(State.result) end; - alloc_write a s v := - match a with - | Address.Base64Encoder => Some (s <| State.base64_encoder := Some v |>) - | Address.Result => Some (s <| State.result := Some v |>) + alloc_write state address value := + match address with + | Address.Base64Encoder => Some (state <| State.Base64Encoder := Some value |>) + | Address.result => Some (state <| State.result := Some value |>) end; }. - Lemma Impl_IsStateValid : State.Valid.t Impl_State. + Lemma IsStateValid : State.Valid.t Impl_State. Proof. sauto. Qed. End base64_encode_elements. +Module State. + Record t : Set := { + base64_encode_elements : base64_encode_elements.State.t; + }. + + Definition init : t := {| + base64_encode_elements := base64_encode_elements.State.init; + |}. +End State. + +Module Address. + Inductive t : Set := + | base64_encode_elements (address : base64_encode_elements.Address.t). +End Address. + +Global Instance Impl_State : State.Trait State.t Address.t := { + read state address := + match address with + | Address.base64_encode_elements address => + State.read state.(State.base64_encode_elements) address + end; + alloc_write state address value := + match address with + | Address.base64_encode_elements address => + match State.alloc_write state.(State.base64_encode_elements) address value with + | Some base64_encode_elements => + Some (state <| State.base64_encode_elements := base64_encode_elements |>) + | None => None + end + end; +}. + +Lemma IsStateValid : State.Valid.t Impl_State. +Proof. + sauto lq: on rew: off. +Qed. + +(* +/** + * @brief Take an array of ASCII values and convert into base64 values + **/ +pub fn base64_encode_elements(input: [u8; InputElements]) -> [u8; InputElements] { + // for some reason, if the lookup table is not defined in a struct, access costs are expensive and ROM tables aren't being used :/ + let mut Base64Encoder = Base64EncodeBE::new(); + + let mut result: [u8; InputElements] = [0; InputElements]; + + for i in 0..InputElements { + result[i] = Base64Encoder.get(input[i] as Field); + } + result +} +*) +Definition base64_encode_elements_for_init {InputElements : U32.t} + (input : Array.t U8.t InputElements) : + Array.t U8.t InputElements := + Array.repeat InputElements (U8.Build_t 0). + +Definition base64_encode_elements_for_body (p : Z) {InputElements : U32.t} + (input : Array.t U8.t InputElements) (i : Z) : + MS! (Array.t U8.t InputElements) unit := + let i : U32.t := U32.Build_t i in + letS! input_i := return!toS! (Array.read input i) in + letS! input_i := return!toS! (cast_to_field p input_i) in + letS! new_result_i := + return!toS! (Base64EncodeBE.get Base64EncodeBE.new input_i)in + letS! result := readS! in + letS! result := return!toS! (Array.write result i new_result_i) in + writeS! result. + +Definition base64_encode_elements (p : Z) {InputElements : U32.t} (input : Array.t U8.t InputElements) : + M! (Array.t U8.t InputElements) * Array.t U8.t InputElements := + let Base64Encoder := Base64EncodeBE.new in + + ( + doS! ( + foldS! + tt + (List.map Z.of_nat (List.seq 0 (Z.to_nat (ToZ.to_Z InputElements)))) + (fun result i => base64_encode_elements_for_body p input i) + ) in + letS! result := readS! in + returnS! result + ) (base64_encode_elements_for_init input). + Ltac cbn_goal := match goal with - | |- Run.t _ _ _ _ ?e => + | |- Run.t _ ?result _ _ ?e => + let result' := eval cbn in result in + change result with result'; let e' := eval cbn in e in change e with e' end. -(* +Lemma map_listUpdate_eq {A B : Type} (f : A -> B) (l : list A) (i : nat) (x : A) (y : B) + (H_y : y = f x) : + List.listUpdate (List.map f l) i y = List.map f (List.listUpdate l i x). +Proof. +Admitted. + +Lemma map_listUpdate_error_eq {A B : Type} (f : A -> B) (l : list A) (i : nat) (x : A) (y : B) + (H_y : y = f x) : + List.listUpdate_error (List.map f l) i y = option_map (List.map f) (List.listUpdate_error l i x). +Proof. + unfold List.listUpdate_error. + rewrite List.map_length. + destruct (_ True - end. + let output := base64_encode_elements p input in + let state_end : State.t := + State.init <| + State.base64_encode_elements := {| + base64_encode_elements.State.Base64Encoder := Some (to_value Base64EncodeBE.new); + base64_encode_elements.State.result := Some (to_value (snd output)); + |} + |> in + {{ p, State.init ⏩ + polymorphic.base64_encode_elements InputElements [to_value input] 🔽 + Panic.to_result (fst output) + ⏩ state_end }}. Proof. - Opaque Base64EncodeBE.get M.index. - destruct base64_encode_elements as [output|] eqn:H_base64_encode_elements; - [|trivial]. - unfold polymorphic.base64_encode_elements. + unfold polymorphic.base64_encode_elements, base64_encode_elements. eapply Run.Let. { eapply Run.CallClosure. { apply Base64EncodeBE.run_new. } - eapply CallPrimitiveStateAlloc with (address := base64_encode_elements.Address.Base64Encoder); + eapply CallPrimitiveStateAlloc with (address := + Address.base64_encode_elements (base64_encode_elements.Address.Base64Encoder) + ); try reflexivity. apply Run.Pure. } eapply Run.Let. { - eapply CallPrimitiveStateAlloc with (address := base64_encode_elements.Address.Result); + eapply CallPrimitiveStateAlloc with (address := + Address.base64_encode_elements (base64_encode_elements.Address.result) + ); try reflexivity. apply Run.Pure. } fold @LowM.let_. + eapply Run.Let. { + eapply Run.For with + (inject := + fun state accumulator => + state <| State.base64_encode_elements := + state.(State.base64_encode_elements) <| + base64_encode_elements.State.result := Some (to_value accumulator) + |> + |> + ) + (accumulator_in := base64_encode_elements_for_init input) + (len := Z.to_nat InputElements.(U32.value)) + (body_expression := base64_encode_elements_for_body p input). + 2: { + unfold set. + repeat f_equal. + cbn; f_equal. + now rewrite List.map_repeat. + } + 2: { + reflexivity. + } + 2: { + unfold Integer.Valid.t in H_InputElements; cbn in *. + f_equal. + lia. + } + intros. + eapply Run.CallPrimitiveStateRead; [reflexivity|]. + fold @LowM.let_. + unfold set; cbn. + unfold Array.read; cbn. + rewrite List.nth_error_map. + destruct List.nth_error as [result|]; cbn; [|apply Run.Pure]. + apply Run.CallPrimitiveGetFieldPrime. + unfold cast_to_field; cbn. + destruct (_ && _); cbn; [|apply Run.Pure]. + eapply Run.CallClosure. { + repeat rewrite Array.rewrite_to_value by (intros; now autorewrite with to_value). + autorewrite with to_value. + match goal with + | |- context[Value.Integer IntegerKind.Field ?i] => + change (Value.Integer IntegerKind.Field i) with (to_value (Field.Build_t i)) + end. + apply Base64EncodeBE.run_get. + } + destruct Base64EncodeBE.get; cbn; [|apply Run.Pure]. + eapply Run.CallPrimitiveStateRead; [reflexivity|]. + unfold Array.write; cbn. + rewrite List.nth_error_map. + destruct List.nth_error as [unused|] eqn:H_nth_error; cbn. + { clear H_nth_error unused. + erewrite map_listUpdate_error_eq by reflexivity. + unfold List.listUpdate_error. + destruct (_ + state <| State.base64_encode_elements := + state.(State.base64_encode_elements) <| + base64_encode_elements.State.result := Some accumulator + |> + |> + ) + ). + apply H. + eapply (Run.For (State := State.t)). + } + apply Run.LetUnfold. fold @LowM.let_. + apply Run.LetUnfold. unfold M.for_, M.for_Z. cbn_goal. unfold Integer.to_nat, Integer.to_Z. @@ -288,14 +485,13 @@ Proof. destruct input as [input]. simpl. Qed. -*) (* Lemma run_eq₂ {State Address : Set} `{State.Trait State Address} (state : State) (self other : Array.t U8.t 36) : {{ state | - translation.eq₂ [to_value self; to_value other] ⇓ + translation.eq₂ [to_value self; to_value other] 🔽 Result.Ok (to_value (Eq.eq self other)) - | state }}. + ⏩ state }}. Proof. unfold translation.eq₂. *) diff --git a/CoqOfNoir/proof/CoqOfNoir.v b/CoqOfNoir/proof/CoqOfNoir.v index 57c6ff8190e..5182daeb57a 100644 --- a/CoqOfNoir/proof/CoqOfNoir.v +++ b/CoqOfNoir/proof/CoqOfNoir.v @@ -4,8 +4,8 @@ Require Import CoqOfNoir.simulation.CoqOfNoir. Module State. Class Trait (State Address : Set) : Type := { - read (a : Address) : State -> option Value.t; - alloc_write (a : Address) : State -> Value.t -> option State; + read : State -> Address -> option Value.t; + alloc_write : State -> Address -> Value.t -> option State; }. Module Valid. @@ -15,20 +15,20 @@ Module State. allocated values. *) Record t `(Trait) : Prop := { (* [alloc_write] can only fail on new cells *) - not_allocated (a : Address) (s : State) (v : Value.t) : - match alloc_write a s v with + not_allocated (state : State) (address : Address) (value : Value.t) : + match alloc_write state address value with | Some _ => True - | None => read a s = None + | None => read state address = None end; - same (a : Address) (s : State) (v : Value.t) : - match alloc_write a s v with - | Some s => read a s = Some v + same (state : State) (address : Address) (value : Value.t) : + match alloc_write state address value with + | Some state => read state address = Some value | None => True end; - different (a1 a2 : Address) (s : State) (v2 : Value.t) : - a1 <> a2 -> - match alloc_write a2 s v2 with - | Some s' => read a1 s' = read a1 s + different (state : State) (address1 address2 : Address) (value2 : Value.t) : + address1 <> address2 -> + match alloc_write state address2 value2 with + | Some state' => read state' address1 = read state address1 | None => True end; }. @@ -36,108 +36,108 @@ Module State. End State. Module Run. - Reserved Notation "{{ p , state_in | e ⇓ output | state_out }}". + Reserved Notation "{{ p , state_in ⏩ e 🔽 output ⏩ state_out }}". Inductive t {State Address : Set} `{State.Trait State Address} (p : Z) (output : Result.t) (state_out : State) : State -> M.t -> Prop := | Pure : (* This should be the only case where the input and output states are the same. *) - {{ p, state_out | LowM.Pure output ⇓ output | state_out }} + {{ p, state_out ⏩ LowM.Pure output 🔽 output ⏩ state_out }} | CallPrimitiveStateAlloc (value : Value.t) (address : Address) (k : Value.t -> M.t) (state_in state_in' : State) : let pointer := Pointer.Mutable (Pointer.Mutable.Make address []) in - State.read address state_in = None -> - State.alloc_write address state_in value = Some state_in' -> - {{ p, state_in' | k (Value.Pointer pointer) ⇓ output | state_out }} -> - {{ p, state_in | LowM.CallPrimitive (Primitive.StateAlloc value) k ⇓ output | state_out }} + State.read state_in address = None -> + State.alloc_write state_in address value = Some state_in' -> + {{ p, state_in' ⏩ k (Value.Pointer pointer) 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.CallPrimitive (Primitive.StateAlloc value) k 🔽 output ⏩ state_out }} | CallPrimitiveStateRead (address : Address) (value : Value.t) (k : Value.t -> M.t) (state_in : State) : - State.read address state_in = Some value -> - {{ p, state_in | k value ⇓ output | state_out }} -> - {{ p, state_in | LowM.CallPrimitive (Primitive.StateRead address) k ⇓ output | state_out }} + State.read state_in address = Some value -> + {{ p, state_in ⏩ k value 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.CallPrimitive (Primitive.StateRead address) k 🔽 output ⏩ state_out }} | CallPrimitiveStateWrite (value : Value.t) (address : Address) (k : unit -> M.t) (state_in state_in' : State) : - State.alloc_write address state_in value = Some state_in' -> - {{ p, state_in' | k tt ⇓ output | state_out }} -> - {{ p, state_in | - LowM.CallPrimitive (Primitive.StateWrite address value) k ⇓ output - | state_out }} + State.alloc_write state_in address value = Some state_in' -> + {{ p, state_in' ⏩ k tt 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ + LowM.CallPrimitive (Primitive.StateWrite address value) k 🔽 output + ⏩ state_out }} | CallPrimitiveGetFieldPrime (k : Z -> M.t) (state_in : State) : - {{ p, state_in | k p ⇓ output | state_out }} -> - {{ p, state_in | - LowM.CallPrimitive Primitive.GetFieldPrime k ⇓ output - | state_out }} + {{ p, state_in ⏩ k p 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ + LowM.CallPrimitive Primitive.GetFieldPrime k 🔽 output + ⏩ state_out }} | CallPrimitiveIsEqualTrue (value1 value2 : Value.t) (k : bool -> M.t) (state_in : State) : (* The hypothesis of equality is explicit as this should be more convenient for the proofs *) value1 = value2 -> - {{ p, state_in | k true ⇓ output | state_out }} -> - {{ p, state_in | LowM.CallPrimitive (Primitive.IsEqual value1 value2) k ⇓ output | state_out }} + {{ p, state_in ⏩ k true 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.CallPrimitive (Primitive.IsEqual value1 value2) k 🔽 output ⏩ state_out }} | CallPrimitiveIsEqualFalse (value1 value2 : Value.t) (k : bool -> M.t) (state_in : State) : value1 <> value2 -> - {{ p, state_in | k false ⇓ output | state_out }} -> - {{ p, state_in | LowM.CallPrimitive (Primitive.IsEqual value1 value2) k ⇓ output | state_out }} + {{ p, state_in ⏩ k false 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.CallPrimitive (Primitive.IsEqual value1 value2) k 🔽 output ⏩ state_out }} | CallClosure (f : list Value.t -> M.t) (args : list Value.t) (k : Result.t -> M.t) (output_inter : Result.t) (state_in state_inter : State) : let closure := Value.Closure (existS (_, _) f) in - {{ p, state_in | f args ⇓ output_inter | state_inter }} -> - {{ p, state_inter | k output_inter ⇓ output | state_out }} -> - {{ p, state_in | LowM.CallClosure closure args k ⇓ output | state_out }} + {{ p, state_in ⏩ f args 🔽 output_inter ⏩ state_inter }} -> + {{ p, state_inter ⏩ k output_inter 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.CallClosure closure args k 🔽 output ⏩ state_out }} | Let (e : M.t) (k : Result.t -> M.t) (output_inter : Result.t) (state_in state_inter : State) : - {{ p, state_in | e ⇓ output_inter | state_inter }} -> - {{ p, state_inter | k output_inter ⇓ output | state_out }} -> - {{ p, state_in | LowM.Let e k ⇓ output | state_out }} + {{ p, state_in ⏩ e 🔽 output_inter ⏩ state_inter }} -> + {{ p, state_inter ⏩ k output_inter 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.Let e k 🔽 output ⏩ state_out }} | LetUnfold (e : M.t) (k : Result.t -> M.t) (state_in : State) : - {{ p, state_in | LowM.let_ e k ⇓ output | state_out }} -> - {{ p, state_in | LowM.Let e k ⇓ output | state_out }} + {{ p, state_in ⏩ LowM.let_ e k 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.Let e k 🔽 output ⏩ state_out }} | LetUnUnfold (e : M.t) (k : Result.t -> M.t) (state_in : State) : - {{ p, state_in | LowM.Let e k ⇓ output | state_out }} -> - {{ p, state_in | LowM.let_ e k ⇓ output | state_out }} + {{ p, state_in ⏩ LowM.Let e k 🔽 output ⏩ state_out }} -> + {{ p, state_in ⏩ LowM.let_ e k 🔽 output ⏩ state_out }} - where "{{ p , state_in | e ⇓ output | state_out }}" := + where "{{ p , state_in ⏩ e 🔽 output ⏩ state_out }}" := (t p output state_out state_in e). Lemma PureEq {State Address : Set} `{State.Trait State Address} (p : Z) (output output' : Result.t) (state state' : State) : output = output' -> state = state' -> - {{ p, state | LowM.Pure output ⇓ output' | state' }}. + {{ p, state ⏩ LowM.Pure output 🔽 output' ⏩ state' }}. Proof. intros -> ->. apply Pure. Qed. - Lemma For {State Address : Set} `{State.Trait State Address} + Lemma For_aux {State Address : Set} `{State.Trait State Address} (p : Z) (state_in : State) (integer_kind : IntegerKind.t) (start : Z) (len : nat) (body : Value.t -> M.t) {Accumulator : Set} @@ -146,10 +146,10 @@ Module Run. (body_expression : Z -> MS! Accumulator unit) (H_body : forall (accumulator_in : Accumulator) (i : Z), let output_accumulator_out := body_expression i accumulator_in in - {{ p, inject state_in accumulator_in | - body (M.alloc (Value.Integer integer_kind i)) ⇓ - Panic.to_result (fst output_accumulator_out) - | inject state_in (snd output_accumulator_out) }} + {{ p, inject state_in accumulator_in ⏩ + body (M.alloc (Value.Integer integer_kind i)) 🔽 + Panic.to_result_alloc (fst output_accumulator_out) + ⏩ inject state_in (snd output_accumulator_out) }} ) : let output_accumulator_out := foldS! @@ -157,13 +157,13 @@ Module Run. (List.map (fun offset => start + Z.of_nat offset) (List.seq 0 len)) (fun (_ : unit) => body_expression) accumulator_in in - {{ p, inject state_in accumulator_in | + {{ p, inject state_in accumulator_in ⏩ M.for_ (Value.Integer integer_kind start) (Value.Integer integer_kind (start + Z.of_nat len)) - body ⇓ - Panic.to_result (fst output_accumulator_out) - | inject state_in (snd output_accumulator_out) }}. + body 🔽 + Panic.to_result_alloc (fst output_accumulator_out) + ⏩ inject state_in (snd output_accumulator_out) }}. Proof. revert start accumulator_in. induction len as [| len IHlen]; intros; unfold M.for_, M.for_Z in *; simpl in *. @@ -204,6 +204,38 @@ Module Run. { apply Run.Pure. } } Qed. + + Lemma For {State Address : Set} `{State.Trait State Address} + (p : Z) (state_in : State) + (integer_kind : IntegerKind.t) (start_z : Z) (len : nat) (body : Value.t -> M.t) + (start end_ : Value.t) + {Accumulator : Set} + (inject : State -> Accumulator -> State) + (accumulator_in : Accumulator) + (body_expression : Z -> MS! Accumulator unit) + (H_body : forall (accumulator_in : Accumulator) (i : Z), + let output_accumulator_out := body_expression i accumulator_in in + {{ p, inject state_in accumulator_in ⏩ + body (M.alloc (Value.Integer integer_kind i)) 🔽 + Panic.to_result_alloc (fst output_accumulator_out) + ⏩ inject state_in (snd output_accumulator_out) }} + ) : + let output_accumulator_out := + foldS! + tt + (List.map (fun offset => start_z + Z.of_nat offset) (List.seq 0 len)) + (fun (_ : unit) => body_expression) + accumulator_in in + state_in = inject state_in accumulator_in -> + start = Value.Integer integer_kind start_z -> + end_ = Value.Integer integer_kind (start_z + Z.of_nat len) -> + {{ p, state_in ⏩ + M.for_ start end_ body 🔽 + Panic.to_result_alloc (fst output_accumulator_out) + ⏩ inject state_in (snd output_accumulator_out) }}. + Proof. + hauto q: on use: For_aux. + Qed. End Run. Module Singleton. @@ -216,8 +248,8 @@ Module Singleton. End Address. Global Instance IsState : State.Trait State.t Address.t := { - read _ s := s; - alloc_write _ s v := Some (Some v); + read state _ := state; + alloc_write state _ value := Some (Some value); }. Lemma IsStateValid : State.Valid.t IsState. @@ -225,3 +257,24 @@ Module Singleton. sauto lq: on rew: off. Qed. End Singleton. + +Module Field. + Module Valid. + Definition t (p : Z) (x : Field.t) : Prop := + 0 <= x.(Field.value) < p. + End Valid. +End Field. + +Module Integer. + Module Valid. + Definition t {A : Set} `{Integer.Trait A} (x : A) : Prop := + Integer.min (Self := A) <= Integer.to_Z x <= Integer.max (Self := A). + End Valid. +End Integer. + +Module Array. + Module Valid. + Definition t {A : Set} {size : U32.t} (array : Array.t A size) : Prop := + List.length array.(Array.value) = Z.to_nat (Integer.to_Z size). + End Valid. +End Array. diff --git a/CoqOfNoir/simulation/CoqOfNoir.v b/CoqOfNoir/simulation/CoqOfNoir.v index d4398e8cfa8..28f1e6f74fa 100644 --- a/CoqOfNoir/simulation/CoqOfNoir.v +++ b/CoqOfNoir/simulation/CoqOfNoir.v @@ -1,4 +1,5 @@ Require Import CoqOfNoir.CoqOfNoir. +Require Import Coq.Logic.FunctionalExtensionality. Module ToValue. Class Trait (Self : Set) : Set := { @@ -22,6 +23,14 @@ Module Panic. | Error => Result.Panic end. + (** For some intermediate results, we need to make an allocation to be like in the translated + code *) + Definition to_result_alloc {A : Set} `{ToValue.Trait A} (value : t A) : Result.t := + match value with + | Success value => Result.Ok (M.alloc (to_value value)) + | Error => Result.Panic + end. + Definition return_ {A : Set} (value : A) : t A := Success value. Arguments return_ /. @@ -80,6 +89,15 @@ Module StatePanic. | [] => return_ init | x :: l => bind (f init x) (fun init => fold_left init l f) end. + + Definition lift_from_panic {State A : Set} (value : Panic.t A) : t State A := + fun state => (value, state). + + Definition read {State : Set} : t State State := + fun state => (Panic.return_ state, state). + + Definition write {State : Set} (state : State) : t State unit := + fun _ => (Panic.return_ tt, state). End StatePanic. Module StatePanicNotations. @@ -97,6 +115,12 @@ Module StatePanicNotations. (at level 200, X at level 100, Y at level 200). Notation "foldS!" := StatePanic.fold_left. + + Notation "return!toS!" := StatePanic.lift_from_panic. + + Notation "readS!" := StatePanic.read. + + Notation "writeS!" := StatePanic.write. End StatePanicNotations. Export PanicNotations. @@ -107,11 +131,21 @@ Global Instance Impl_ToValue_for_unit : ToValue.Trait unit := { Value.Tuple []; }. +Lemma rewrite_to_value_unit : + Value.Tuple [] = to_value tt. +Proof. reflexivity. Qed. +Global Hint Rewrite rewrite_to_value_unit : to_value. + Global Instance Impl_ToValue_for_bool : ToValue.Trait bool := { to_value (b : bool) := Value.Bool b; }. +Lemma rewrite_to_value_bool (b : bool) : + Value.Bool b = to_value b. +Proof. reflexivity. Qed. +Global Hint Rewrite rewrite_to_value_bool : to_value. + Module Field. Record t : Set := { value : Z; @@ -121,6 +155,11 @@ Module Field. to_value (i : t) := Value.Integer IntegerKind.Field i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.Field i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End Field. Module U1. @@ -132,6 +171,11 @@ Module U1. to_value (i : t) := Value.Integer IntegerKind.U1 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.U1 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End U1. Module U8. @@ -143,6 +187,11 @@ Module U8. to_value (i : t) := Value.Integer IntegerKind.U8 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.U8 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End U8. Module U16. @@ -154,6 +203,11 @@ Module U16. to_value (i : t) := Value.Integer IntegerKind.U16 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.U16 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End U16. Module U32. @@ -165,6 +219,11 @@ Module U32. to_value (i : t) := Value.Integer IntegerKind.U32 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.U32 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End U32. Module U64. @@ -176,6 +235,11 @@ Module U64. to_value (i : t) := Value.Integer IntegerKind.U64 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.U64 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End U64. Module I1. @@ -187,6 +251,11 @@ Module I1. to_value (i : t) := Value.Integer IntegerKind.I1 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.I1 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End I1. Module I8. @@ -198,6 +267,11 @@ Module I8. to_value (i : t) := Value.Integer IntegerKind.I8 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.I8 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End I8. Module I16. @@ -209,6 +283,11 @@ Module I16. to_value (i : t) := Value.Integer IntegerKind.I16 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.I16 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End I16. Module I32. @@ -220,6 +299,11 @@ Module I32. to_value (i : t) := Value.Integer IntegerKind.I32 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.I32 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End I32. Module I64. @@ -231,147 +315,146 @@ Module I64. to_value (i : t) := Value.Integer IntegerKind.I64 i.(value); }. + + Lemma rewrite_to_value (i : t) : + Value.Integer IntegerKind.I64 i.(value) = to_value i. + Proof. reflexivity. Qed. + Global Hint Rewrite rewrite_to_value : to_value. End I64. -Module SemiInteger. +Module Integer. Class Trait (Self : Set) : Set := { to_Z : Self -> Z; - }. -End SemiInteger. - -Global Instance Impl_SemiInteger_for_Field : SemiInteger.Trait Field.t := { - SemiInteger.to_Z (i : Field.t) := - i.(Field.value); -}. - -Global Instance Impl_SemiInteger_for_U1 : SemiInteger.Trait U1.t := { - SemiInteger.to_Z (i : U1.t) := - i.(U1.value); -}. - -Global Instance Impl_SemiInteger_for_U8 : SemiInteger.Trait U8.t := { - SemiInteger.to_Z (i : U8.t) := - i.(U8.value); -}. - -Global Instance Impl_SemiInteger_for_U16 : SemiInteger.Trait U16.t := { - SemiInteger.to_Z (i : U16.t) := - i.(U16.value); -}. - -Global Instance Impl_SemiInteger_for_U32 : SemiInteger.Trait U32.t := { - SemiInteger.to_Z (i : U32.t) := - i.(U32.value); -}. - -Global Instance Impl_SemiInteger_for_U64 : SemiInteger.Trait U64.t := { - SemiInteger.to_Z (i : U64.t) := - i.(U64.value); -}. - -Global Instance Impl_SemiInteger_for_I1 : SemiInteger.Trait I1.t := { - SemiInteger.to_Z (i : I1.t) := - i.(I1.value); -}. - -Global Instance Impl_SemiInteger_for_I8 : SemiInteger.Trait I8.t := { - SemiInteger.to_Z (i : I8.t) := - i.(I8.value); -}. - -Global Instance Impl_SemiInteger_for_I16 : SemiInteger.Trait I16.t := { - SemiInteger.to_Z (i : I16.t) := - i.(I16.value); -}. - -Global Instance Impl_SemiInteger_for_I32 : SemiInteger.Trait I32.t := { - SemiInteger.to_Z (i : I32.t) := - i.(I32.value); -}. - -Global Instance Impl_SemiInteger_for_I64 : SemiInteger.Trait I64.t := { - SemiInteger.to_Z (i : I64.t) := - i.(I64.value); -}. - -Module Integer. - Class Trait (Self : Set) `{SemiInteger.Trait Self} : Set := { of_Z : Z -> Self; + min : Z; + max : Z; }. - Definition to_Z {Self : Set} `{Trait Self} (self : Self) : Z := - SemiInteger.to_Z self. - - Definition to_nat {Self : Set} `{Trait Self} (self : Self) : nat := - Z.to_nat (to_Z self). - Definition add {Self : Set} `{Trait Self} (self other : Self) : Self := - of_Z (SemiInteger.to_Z self + SemiInteger.to_Z other). + of_Z (Integer.to_Z self + Integer.to_Z other). Definition sub {Self : Set} `{Trait Self} (self other : Self) : Self := - of_Z (SemiInteger.to_Z self - SemiInteger.to_Z other). + of_Z (Integer.to_Z self - Integer.to_Z other). Definition mul {Self : Set} `{Trait Self} (self other : Self) : Self := - of_Z (SemiInteger.to_Z self * SemiInteger.to_Z other). + of_Z (Integer.to_Z self * Integer.to_Z other). Definition div {Self : Set} `{Trait Self} (self other : Self) : Self := - of_Z (SemiInteger.to_Z self / SemiInteger.to_Z other). + of_Z (Integer.to_Z self / Integer.to_Z other). Definition mod_ {Self : Set} `{Trait Self} (self other : Self) : Self := - of_Z (SemiInteger.to_Z self mod SemiInteger.to_Z other). + of_Z (Integer.to_Z self mod Integer.to_Z other). Definition of_bool {Self : Set} `{Trait Self} (b : bool) : Self := of_Z (if b then 1 else 0). End Integer. Global Instance Impl_Integer_for_U1 : Integer.Trait U1.t := { + Integer.to_Z (i : U1.t) := + i.(U1.value); Integer.of_Z (i : Z) := U1.Build_t (i mod (2^1)); + Integer.min := 0; + Integer.max := 1; }. Global Instance Impl_Integer_for_U8 : Integer.Trait U8.t := { + Integer.to_Z (i : U8.t) := + i.(U8.value); Integer.of_Z (i : Z) := U8.Build_t (i mod (2^8)); + Integer.min := 0; + Integer.max := 2^8 - 1; }. Global Instance Impl_Integer_for_U16 : Integer.Trait U16.t := { + Integer.to_Z (i : U16.t) := + i.(U16.value); Integer.of_Z (i : Z) := U16.Build_t (i mod (2^16)); + Integer.min := 0; + Integer.max := 2^16 - 1; }. Global Instance Impl_Integer_for_U32 : Integer.Trait U32.t := { + Integer.to_Z (i : U32.t) := + i.(U32.value); Integer.of_Z (i : Z) := U32.Build_t (i mod (2^32)); + Integer.min := 0; + Integer.max := 2^32 - 1; }. Global Instance Impl_Integer_for_U64 : Integer.Trait U64.t := { + Integer.to_Z (i : U64.t) := + i.(U64.value); Integer.of_Z (i : Z) := U64.Build_t (i mod (2^64)); + Integer.min := 0; + Integer.max := 2^64 - 1; }. Global Instance Impl_Integer_for_I1 : Integer.Trait I1.t := { + Integer.to_Z (i : I1.t) := + i.(I1.value); Integer.of_Z (i : Z) := I1.Build_t (((i + 2^0) mod (2^1)) - 2^0); + Integer.min := -1; + Integer.max := 0; }. Global Instance Impl_Integer_for_I8 : Integer.Trait I8.t := { + Integer.to_Z (i : I8.t) := + i.(I8.value); Integer.of_Z (i : Z) := I8.Build_t (((i + 2^7) mod (2^8)) - 2^7); + Integer.min := -2^7; + Integer.max := 2^7 - 1; }. Global Instance Impl_Integer_for_I16 : Integer.Trait I16.t := { + Integer.to_Z (i : I16.t) := + i.(I16.value); Integer.of_Z (i : Z) := I16.Build_t (((i + 2^15) mod (2^16)) - 2^15); + Integer.min := -2^15; + Integer.max := 2^15 - 1; }. Global Instance Impl_Integer_for_I32 : Integer.Trait I32.t := { + Integer.to_Z (i : I32.t) := + i.(I32.value); Integer.of_Z (i : Z) := I32.Build_t (((i + 2^31) mod (2^32)) - 2^31); + Integer.min := -2^31; + Integer.max := 2^31 - 1; }. Global Instance Impl_Integer_for_I64 : Integer.Trait I64.t := { + Integer.to_Z (i : I64.t) := + i.(I64.value); Integer.of_Z (i : Z) := I64.Build_t (((i + 2^63) mod (2^64)) - 2^63); + Integer.min := -2^63; + Integer.max := 2^63 - 1; +}. + +(** With this trait, we can take into account both standard integers and fields, whose size depends + on a parameter [p]. *) +Module ToZ. + Class Trait (Self : Set) : Set := { + to_Z : Self -> Z; + }. +End ToZ. + +Global Instance Impl_ToZ_for_Field : ToZ.Trait Field.t := { + ToZ.to_Z (i : Field.t) := + i.(Field.value); +}. + +Global Instance Impl_ToZ_for_Integer {A : Set} `{Integer.Trait A} : ToZ.Trait A := { + ToZ.to_Z (i : A) := + Integer.to_Z i; }. Module Array. @@ -383,34 +466,37 @@ Module Array. Arguments t : clear implicits. Arguments Build_t {_ _}. - Module Valid. - Definition t {A : Set} {size : U32.t} (array : t A size) : Prop := - List.length array.(value) = Z.to_nat (SemiInteger.to_Z size). - End Valid. - Global Instance Impl_ToValue {A : Set} `{ToValue.Trait A} {size : U32.t} : ToValue.Trait (t A size) := { to_value (array : t A size) := Value.Array (List.map to_value array.(value)); }. + Lemma rewrite_to_value {A : Set} `{ToValue.Trait A} {size : U32.t} (array : t A size) f : + (forall (x : A), f x = to_value x) -> + Value.Array (List.map f array.(value)) = to_value array. + Proof. + hauto lq: on use: functional_extensionality. + Qed. + Global Hint Rewrite @rewrite_to_value : to_value. + Definition repeat {A : Set} (size : U32.t) (value : A) : t A size := {| - value := List.repeat value (Z.to_nat (SemiInteger.to_Z size)) + value := List.repeat value (Z.to_nat (Integer.to_Z size)) |}. - Definition read {A Index: Set} `{SemiInteger.Trait Index} {size : U32.t} + Definition read {A Index: Set} `{ToZ.Trait Index} {size : U32.t} (array : t A size) (index : Index) : M! A := - match List.nth_error array.(value) (Z.to_nat (SemiInteger.to_Z index)) with + match List.nth_error array.(value) (Z.to_nat (ToZ.to_Z index)) with | Some result => return! result | None => panic! ("Array.get: index out of bounds", array, index) end. - Definition write {A Index: Set} `{SemiInteger.Trait Index} {size : U32.t} + Definition write {A Index: Set} `{ToZ.Trait Index} {size : U32.t} (array : t A size) (index : Index) (update : A) : M! (t A size) := - match List.listUpdate_error array.(value) (Z.to_nat (SemiInteger.to_Z index)) update with + match List.listUpdate_error array.(value) (Z.to_nat (ToZ.to_Z index)) update with | Some array => return! (Build_t array) | None => panic! ("Array.write: index out of bounds", array, index) end. @@ -422,9 +508,9 @@ Module Eq. }. End Eq. -Global Instance Impl_Eq_for_U8 : Eq.Trait U8.t := { - Eq.eq (self other : U8.t) := - self.(U8.value) =? other.(U8.value); +Global Instance Impl_Eq_for_Integer {A : Set} `{Integer.Trait A} : Eq.Trait A := { + Eq.eq (self other : A) := + Integer.to_Z self =? Integer.to_Z other; }. Global Instance Impl_Eq_for_Array {A : Set} `{Eq.Trait A} {size : U32.t} : @@ -432,3 +518,17 @@ Global Instance Impl_Eq_for_Array {A : Set} `{Eq.Trait A} {size : U32.t} : Eq.eq (self other : Array.t A size) := List.fold_left andb (List.zip Eq.eq self.(Array.value) other.(Array.value)) true; }. + +Definition cast_to_integer {A B : Set} `{ToZ.Trait A} `{Integer.Trait B} (value : A) : M! B := + let value := ToZ.to_Z value in + if (Integer.min <=? value) && (value <=? Integer.max) then + return! (Integer.of_Z value) + else + panic! ("cast: out of bounds", value). + +Definition cast_to_field {A : Set} `{ToZ.Trait A} (p : Z) (value : A) : M! Field.t := + let value := ToZ.to_Z value in + if (0 <=? value) && (value Date: Tue, 10 Dec 2024 18:31:16 +0100 Subject: [PATCH 2/2] proof: simplify the proof thanks to autorewrite --- CoqOfNoir/base64/simulation.v | 120 +++---------------------------- CoqOfNoir/simulation/CoqOfNoir.v | 58 +++++++-------- scripts/coq_of_noir.py | 2 +- 3 files changed, 38 insertions(+), 142 deletions(-) diff --git a/CoqOfNoir/base64/simulation.v b/CoqOfNoir/base64/simulation.v index e1565fc044f..5eb8fbf5038 100644 --- a/CoqOfNoir/base64/simulation.v +++ b/CoqOfNoir/base64/simulation.v @@ -240,7 +240,8 @@ Definition base64_encode_elements_for_body (p : Z) {InputElements : U32.t} letS! result := return!toS! (Array.write result i new_result_i) in writeS! result. -Definition base64_encode_elements (p : Z) {InputElements : U32.t} (input : Array.t U8.t InputElements) : +Definition base64_encode_elements (p : Z) {InputElements : U32.t} + (input : Array.t U8.t InputElements) : M! (Array.t U8.t InputElements) * Array.t U8.t InputElements := let Base64Encoder := Base64EncodeBE.new in @@ -255,20 +256,14 @@ Definition base64_encode_elements (p : Z) {InputElements : U32.t} (input : Array returnS! result ) (base64_encode_elements_for_init input). -Ltac cbn_goal := - match goal with - | |- Run.t _ ?result _ _ ?e => - let result' := eval cbn in result in - change result with result'; - let e' := eval cbn in e in - change e with e' - end. - Lemma map_listUpdate_eq {A B : Type} (f : A -> B) (l : list A) (i : nat) (x : A) (y : B) (H_y : y = f x) : List.listUpdate (List.map f l) i y = List.map f (List.listUpdate l i x). Proof. -Admitted. + unfold List.listUpdate. + rewrite List.firstn_map, List.skipn_map, List.map_app. + sfirstorder. +Qed. Lemma map_listUpdate_error_eq {A B : Type} (f : A -> B) (l : list A) (i : nat) (x : A) (y : B) (H_y : y = f x) : @@ -356,7 +351,6 @@ Proof. unfold cast_to_field; cbn. destruct (_ && _); cbn; [|apply Run.Pure]. eapply Run.CallClosure. { - repeat rewrite Array.rewrite_to_value by (intros; now autorewrite with to_value). autorewrite with to_value. match goal with | |- context[Value.Integer IntegerKind.Field ?i] => @@ -385,105 +379,11 @@ Proof. } } fold @LowM.let_. - (* destruct fst; cbn; [|apply Run.Pure]. *) + unfold StatePanic.bind. destruct (foldS! _ _ _) as [status result]. - destruct status; cbn. - { - - } - { exfalso. - set (Z.to_nat i) in *. - - lia. - } - [lia|]. - Search List.listUpdate_error. - { - (* pose proof (List.nth_error_None accumulator_in.(Array.value) (Z.to_nat i)). - best. *) - - } - epose proof (List.nth_error_None _ _). H_nth_error). - } - set (length := Z.of_nat (List.length accumulator_in.(Array.value))). - destruct (i - state <| State.base64_encode_elements := - state.(State.base64_encode_elements) <| - base64_encode_elements.State.result := Some accumulator - |> - |> - ) - ). - apply H. - eapply (Run.For (State := State.t)). - } - - apply Run.LetUnfold. - fold @LowM.let_. - apply Run.LetUnfold. - unfold M.for_, M.for_Z. - cbn_goal. - unfold Integer.to_nat, Integer.to_Z. - repeat match goal with - | |- context[Z.to_nat ?x] => - let n' := eval cbn in (Z.to_nat x) in - change (Z.to_nat x) with n' - end. - match goal with - | |- context[?x - 0] => - replace (x - 0) with x by lia - end. - unfold Array.Valid.t in H_input. - cbn in H_input. - unfold base64_encode_elements, Array.repeat in H_base64_encode_elements. - cbn in H_base64_encode_elements. - induction (Z.to_nat _); cbn_goal. - { eapply Run.CallPrimitiveStateRead; [reflexivity|]. - cbn in H_base64_encode_elements. - inversion_clear H_base64_encode_elements. - apply Run.Pure. - } - { eapply Run.CallPrimitiveStateRead; [reflexivity|]. - fold @LowM.let_. - Transparent M.index. - unfold M.index. - cbn_goal. - } - set (n := Z.to_nat _). - - simpl. - cbn. - (* Entering the loop *) - destruct input as [input]. - simpl. + destruct status; cbn; [|apply Run.Pure]. + eapply Run.CallPrimitiveStateRead; [reflexivity|]. + apply Run.Pure. Qed. (* Lemma run_eq₂ {State Address : Set} `{State.Trait State Address} diff --git a/CoqOfNoir/simulation/CoqOfNoir.v b/CoqOfNoir/simulation/CoqOfNoir.v index 28f1e6f74fa..ff9ed759581 100644 --- a/CoqOfNoir/simulation/CoqOfNoir.v +++ b/CoqOfNoir/simulation/CoqOfNoir.v @@ -1,5 +1,4 @@ Require Import CoqOfNoir.CoqOfNoir. -Require Import Coq.Logic.FunctionalExtensionality. Module ToValue. Class Trait (Self : Set) : Set := { @@ -141,8 +140,8 @@ Global Instance Impl_ToValue_for_bool : ToValue.Trait bool := { Value.Bool b; }. -Lemma rewrite_to_value_bool (b : bool) : - Value.Bool b = to_value b. +Lemma rewrite_to_value_bool : + Value.Bool = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value_bool : to_value. @@ -156,8 +155,8 @@ Module Field. Value.Integer IntegerKind.Field i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.Field i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.Field i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End Field. @@ -172,8 +171,8 @@ Module U1. Value.Integer IntegerKind.U1 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.U1 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.U1 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End U1. @@ -188,8 +187,8 @@ Module U8. Value.Integer IntegerKind.U8 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.U8 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.U8 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End U8. @@ -204,8 +203,8 @@ Module U16. Value.Integer IntegerKind.U16 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.U16 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.U16 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End U16. @@ -220,8 +219,8 @@ Module U32. Value.Integer IntegerKind.U32 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.U32 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.U32 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End U32. @@ -236,8 +235,8 @@ Module U64. Value.Integer IntegerKind.U64 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.U64 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.U64 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End U64. @@ -252,8 +251,8 @@ Module I1. Value.Integer IntegerKind.I1 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.I1 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.I1 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End I1. @@ -268,8 +267,8 @@ Module I8. Value.Integer IntegerKind.I8 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.I8 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.I8 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End I8. @@ -284,8 +283,8 @@ Module I16. Value.Integer IntegerKind.I16 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.I16 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.I16 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End I16. @@ -300,8 +299,8 @@ Module I32. Value.Integer IntegerKind.I32 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.I32 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.I32 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End I32. @@ -316,8 +315,8 @@ Module I64. Value.Integer IntegerKind.I64 i.(value); }. - Lemma rewrite_to_value (i : t) : - Value.Integer IntegerKind.I64 i.(value) = to_value i. + Lemma rewrite_to_value : + (fun i => Value.Integer IntegerKind.I64 i.(value)) = to_value. Proof. reflexivity. Qed. Global Hint Rewrite rewrite_to_value : to_value. End I64. @@ -472,12 +471,9 @@ Module Array. Value.Array (List.map to_value array.(value)); }. - Lemma rewrite_to_value {A : Set} `{ToValue.Trait A} {size : U32.t} (array : t A size) f : - (forall (x : A), f x = to_value x) -> - Value.Array (List.map f array.(value)) = to_value array. - Proof. - hauto lq: on use: functional_extensionality. - Qed. + Lemma rewrite_to_value {A : Set} `{ToValue.Trait A} {size : U32.t} (array : t A size) : + Value.Array (List.map to_value array.(value)) = to_value array. + Proof. reflexivity. Qed. Global Hint Rewrite @rewrite_to_value : to_value. Definition repeat {A : Set} (size : U32.t) (value : A) : t A size := diff --git a/scripts/coq_of_noir.py b/scripts/coq_of_noir.py index b0cffb0ea8f..b5929a55e68 100644 --- a/scripts/coq_of_noir.py +++ b/scripts/coq_of_noir.py @@ -416,7 +416,7 @@ def lvalue_to_coq(node) -> str: ''' def assign_to_coq(node) -> str: return alloc( - "M.assign (|\n" + + "M.write (|\n" + indent( read(lvalue_to_coq(node["lvalue"])) + ",\n" + read(expression_to_coq(node["expression"]))