Skip to content

Commit

Permalink
feat: Oracle mocker for nargo test (#2928)
Browse files Browse the repository at this point in the history
Co-authored-by: Maxim Vezenov <[email protected]>
  • Loading branch information
sirasistant and vezenovm authored Oct 2, 2023
1 parent 285aa21 commit 0dd1e77
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 10 deletions.
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/def_map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ impl CrateDefMap {
.value_definitions()
.filter_map(|id| {
id.as_function().map(|function_id| {
let is_entry_point = !interner
.function_attributes(&function_id)
.has_contract_library_method();
let attributes = interner.function_attributes(&function_id);
let is_entry_point = !attributes.has_contract_library_method()
&& !attributes.is_test_function();
ContractFunctionMeta { function_id, is_entry_point }
})
})
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ impl Attributes {
.any(|attribute| attribute == &SecondaryAttribute::ContractLibraryMethod)
}

pub fn is_test_function(&self) -> bool {
matches!(self.function, Some(FunctionAttribute::Test(_)))
}

/// Returns note if a deprecated secondary attribute is found
pub fn get_deprecated_note(&self) -> Option<Option<String>> {
self.secondary.iter().find_map(|attr| match attr {
Expand Down
1 change: 1 addition & 0 deletions noir_stdlib/src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod collections;
mod compat;
mod option;
mod string;
mod test;

// Oracle calls are required to be wrapped in an unconstrained function
// Thus, the only argument to the `println` oracle is expected to always be an ident
Expand Down
45 changes: 45 additions & 0 deletions noir_stdlib/src/test.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[oracle(create_mock)]
unconstrained fn create_mock_oracle<N>(_name: str<N>) -> Field {}

#[oracle(set_mock_params)]
unconstrained fn set_mock_params_oracle<P>(_id: Field, _params: P) {}

#[oracle(set_mock_returns)]
unconstrained fn set_mock_returns_oracle<R>(_id: Field, _returns: R) {}

#[oracle(set_mock_times)]
unconstrained fn set_mock_times_oracle(_id: Field, _times: u64) {}

#[oracle(clear_mock)]
unconstrained fn clear_mock_oracle(_id: Field) {}

struct OracleMock {
id: Field,
}

impl OracleMock {
unconstrained pub fn mock<N>(name: str<N>) -> Self {
Self {
id: create_mock_oracle(name),
}
}

unconstrained pub fn with_params<P>(self, params: P) -> Self {
set_mock_params_oracle(self.id, params);
self
}

unconstrained pub fn returns<R>(self, returns: R) -> Self {
set_mock_returns_oracle(self.id, returns);
self
}

unconstrained pub fn times(self, times: u64) -> Self {
set_mock_times_oracle(self.id, times);
self
}

unconstrained pub fn clear(self) {
clear_mock_oracle(self.id);
}
}
7 changes: 5 additions & 2 deletions tooling/nargo/src/ops/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use acvm::{acir::circuit::Circuit, acir::native_types::WitnessMap};
use crate::errors::ExecutionError;
use crate::NargoError;

use super::foreign_calls::ForeignCall;
use super::foreign_calls::ForeignCallExecutor;

pub fn execute_circuit<B: BlackBoxFunctionSolver>(
blackbox_solver: &B,
Expand All @@ -24,6 +24,8 @@ pub fn execute_circuit<B: BlackBoxFunctionSolver>(
.map(|(_, message)| message.clone())
};

let mut foreign_call_executor = ForeignCallExecutor::default();

loop {
let solver_status = acvm.solve();

Expand Down Expand Up @@ -57,7 +59,8 @@ pub fn execute_circuit<B: BlackBoxFunctionSolver>(
}));
}
ACVMStatus::RequiresForeignCall(foreign_call) => {
let foreign_call_result = ForeignCall::execute(&foreign_call, show_output)?;
let foreign_call_result =
foreign_call_executor.execute(&foreign_call, show_output)?;
acvm.resolve_pending_foreign_call(foreign_call_result);
}
}
Expand Down
150 changes: 145 additions & 5 deletions tooling/nargo/src/ops/foreign_calls.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use acvm::{
acir::brillig::{ForeignCallResult, Value},
brillig_vm::brillig::ForeignCallParam,
acir::brillig::{ForeignCallParam, ForeignCallResult, Value},
pwg::ForeignCallWaitInfo,
};
use iter_extended::vecmap;
use noirc_printable_type::PrintableValueDisplay;
use noirc_printable_type::{decode_string_value, ForeignCallError, PrintableValueDisplay};

use crate::NargoError;

Expand All @@ -14,6 +13,11 @@ pub(crate) enum ForeignCall {
Println,
Sequence,
ReverseSequence,
CreateMock,
SetMockParams,
SetMockReturns,
SetMockTimes,
ClearMock,
}

impl std::fmt::Display for ForeignCall {
Expand All @@ -28,6 +32,11 @@ impl ForeignCall {
ForeignCall::Println => "println",
ForeignCall::Sequence => "get_number_sequence",
ForeignCall::ReverseSequence => "get_reverse_number_sequence",
ForeignCall::CreateMock => "create_mock",
ForeignCall::SetMockParams => "set_mock_params",
ForeignCall::SetMockReturns => "set_mock_returns",
ForeignCall::SetMockTimes => "set_mock_times",
ForeignCall::ClearMock => "clear_mock",
}
}

Expand All @@ -36,16 +45,65 @@ impl ForeignCall {
"println" => Some(ForeignCall::Println),
"get_number_sequence" => Some(ForeignCall::Sequence),
"get_reverse_number_sequence" => Some(ForeignCall::ReverseSequence),
"create_mock" => Some(ForeignCall::CreateMock),
"set_mock_params" => Some(ForeignCall::SetMockParams),
"set_mock_returns" => Some(ForeignCall::SetMockReturns),
"set_mock_times" => Some(ForeignCall::SetMockTimes),
"clear_mock" => Some(ForeignCall::ClearMock),
_ => None,
}
}
}

/// This struct represents an oracle mock. It can be used for testing programs that use oracles.
#[derive(Debug, PartialEq, Eq, Clone)]
struct MockedCall {
/// The id of the mock, used to update or remove it
id: usize,
/// The oracle it's mocking
name: String,
/// Optionally match the parameters
params: Option<Vec<ForeignCallParam>>,
/// The result to return when this mock is called
result: ForeignCallResult,
/// How many times should this mock be called before it is removed
times_left: Option<u64>,
}

impl MockedCall {
fn new(id: usize, name: String) -> Self {
Self {
id,
name,
params: None,
result: ForeignCallResult { values: vec![] },
times_left: None,
}
}
}

impl MockedCall {
fn matches(&self, name: &str, params: &Vec<ForeignCallParam>) -> bool {
self.name == name && (self.params.is_none() || self.params.as_ref() == Some(params))
}
}

#[derive(Debug, Default)]
pub(crate) struct ForeignCallExecutor {
/// Mocks have unique ids used to identify them in Noir, allowing to update or remove them.
last_mock_id: usize,
/// The registered mocks
mocked_responses: Vec<MockedCall>,
}

impl ForeignCallExecutor {
pub(crate) fn execute(
&mut self,
foreign_call: &ForeignCallWaitInfo,
show_output: bool,
) -> Result<ForeignCallResult, NargoError> {
let foreign_call_name = foreign_call.function.as_str();
match Self::lookup(foreign_call_name) {
match ForeignCall::lookup(foreign_call_name) {
Some(ForeignCall::Println) => {
if show_output {
Self::execute_println(&foreign_call.inputs)?;
Expand Down Expand Up @@ -76,10 +134,92 @@ impl ForeignCall {
],
})
}
None => panic!("unexpected foreign call {foreign_call_name:?}"),
Some(ForeignCall::CreateMock) => {
let mock_oracle_name = Self::parse_string(&foreign_call.inputs[0]);
assert!(ForeignCall::lookup(&mock_oracle_name).is_none());
let id = self.last_mock_id;
self.mocked_responses.push(MockedCall::new(id, mock_oracle_name));
self.last_mock_id += 1;

Ok(ForeignCallResult { values: vec![Value::from(id).into()] })
}
Some(ForeignCall::SetMockParams) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
.params = Some(params.to_vec());

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::SetMockReturns) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
.result = ForeignCallResult { values: params.to_vec() };

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::SetMockTimes) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
let times = params[0]
.unwrap_value()
.to_field()
.try_to_u64()
.expect("Invalid bit size of times");

self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
.times_left = Some(times);

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::ClearMock) => {
let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?;
self.mocked_responses.retain(|response| response.id != id);
Ok(ForeignCallResult { values: vec![] })
}
None => {
let response_position = self
.mocked_responses
.iter()
.position(|response| response.matches(foreign_call_name, &foreign_call.inputs))
.unwrap_or_else(|| panic!("Unknown foreign call {}", foreign_call_name));

let mock = self
.mocked_responses
.get_mut(response_position)
.expect("Invalid position of mocked response");
let result = mock.result.values.clone();

if let Some(times_left) = &mut mock.times_left {
*times_left -= 1;
if *times_left == 0 {
self.mocked_responses.remove(response_position);
}
}

Ok(ForeignCallResult { values: result })
}
}
}

fn extract_mock_id(
foreign_call_inputs: &[ForeignCallParam],
) -> Result<(usize, &[ForeignCallParam]), ForeignCallError> {
let (id, params) =
foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?;
Ok((id.unwrap_value().to_usize(), params))
}

fn find_mock_by_id(&mut self, id: usize) -> Option<&mut MockedCall> {
self.mocked_responses.iter_mut().find(|response| response.id == id)
}

fn parse_string(param: &ForeignCallParam) -> String {
let fields: Vec<_> = param.values().into_iter().map(|value| value.to_field()).collect();
decode_string_value(&fields)
}

fn execute_println(foreign_call_inputs: &[ForeignCallParam]) -> Result<(), NargoError> {
let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?;
println!("{display_values}");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "mock_oracle"
type = "bin"
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "10"

30 changes: 30 additions & 0 deletions tooling/nargo_cli/tests/execution_success/mock_oracle/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use dep::std::test::OracleMock;

struct Point {
x: Field,
y: Field,
}

#[oracle(foo)]
unconstrained fn foo_oracle(_point: Point, _array: [Field; 4]) -> Field {}

unconstrained fn main() {
let array = [1,2,3,4];
let another_array = [4,3,2,1];
let point = Point {
x: 14,
y: 27,
};

OracleMock::mock("foo").returns(42).times(1);
let mock = OracleMock::mock("foo").returns(0);
assert_eq(42, foo_oracle(point, array));
assert_eq(0, foo_oracle(point, array));
mock.clear();

OracleMock::mock("foo").with_params((point, array)).returns(10);
OracleMock::mock("foo").with_params((point, another_array)).returns(20);
assert_eq(10, foo_oracle(point, array));
assert_eq(20, foo_oracle(point, another_array));
}

0 comments on commit 0dd1e77

Please sign in to comment.