Skip to content

Commit

Permalink
feat(avm/brillig)!: revert/rethrow oracle (#9408)
Browse files Browse the repository at this point in the history
This PR introduces a revert oracle to be used when (and only when) rethrowing revertdata in public. The major difference with just doing `assert(false, data)` is that the latter will also add an error selector to the revertdata, which is not something we want when rethrowing.

* Creates a revert oracle to be used for rethrowing.
* Changes TRAP/REVERT to have a runtime size.
  • Loading branch information
fcarreiro authored Oct 25, 2024
1 parent 4c4974f commit 1bbd724
Show file tree
Hide file tree
Showing 17 changed files with 217 additions and 64 deletions.
78 changes: 54 additions & 24 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,29 +316,11 @@ pub fn brillig_to_avm(
});
}
BrilligOpcode::Trap { revert_data } => {
let bits_needed =
*[bits_needed_for(&revert_data.pointer), bits_needed_for(&revert_data.size)]
.iter()
.max()
.unwrap();
let avm_opcode = match bits_needed {
8 => AvmOpcode::REVERT_8,
16 => AvmOpcode::REVERT_16,
_ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed),
};
avm_instrs.push(AvmInstruction {
opcode: avm_opcode,
indirect: Some(
AddressingModeBuilder::default()
.indirect_operand(&revert_data.pointer)
.build(),
),
operands: vec![
make_operand(bits_needed, &revert_data.pointer.to_usize()),
make_operand(bits_needed, &revert_data.size),
],
..Default::default()
});
generate_revert_instruction(
&mut avm_instrs,
&revert_data.pointer,
&revert_data.size,
);
}
BrilligOpcode::Cast { destination, source, bit_size } => {
handle_cast(&mut avm_instrs, source, destination, *bit_size);
Expand Down Expand Up @@ -418,6 +400,7 @@ fn handle_foreign_call(
}
"avmOpcodeCalldataCopy" => handle_calldata_copy(avm_instrs, destinations, inputs),
"avmOpcodeReturn" => handle_return(avm_instrs, destinations, inputs),
"avmOpcodeRevert" => handle_revert(avm_instrs, destinations, inputs),
"avmOpcodeStorageRead" => handle_storage_read(avm_instrs, destinations, inputs),
"avmOpcodeStorageWrite" => handle_storage_write(avm_instrs, destinations, inputs),
"debugLog" => handle_debug_log(avm_instrs, destinations, inputs),
Expand Down Expand Up @@ -929,6 +912,35 @@ fn generate_cast_instruction(
}
}

/// Generates an AVM REVERT instruction.
fn generate_revert_instruction(
avm_instrs: &mut Vec<AvmInstruction>,
revert_data_pointer: &MemoryAddress,
revert_data_size_offset: &MemoryAddress,
) {
let bits_needed =
*[revert_data_pointer, revert_data_size_offset].map(bits_needed_for).iter().max().unwrap();
let avm_opcode = match bits_needed {
8 => AvmOpcode::REVERT_8,
16 => AvmOpcode::REVERT_16,
_ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed),
};
avm_instrs.push(AvmInstruction {
opcode: avm_opcode,
indirect: Some(
AddressingModeBuilder::default()
.indirect_operand(revert_data_pointer)
.direct_operand(revert_data_size_offset)
.build(),
),
operands: vec![
make_operand(bits_needed, &revert_data_pointer.to_usize()),
make_operand(bits_needed, &revert_data_size_offset.to_usize()),
],
..Default::default()
});
}

/// Generates an AVM MOV instruction.
fn generate_mov_instruction(
indirect: Option<AvmOperand>,
Expand Down Expand Up @@ -1214,7 +1226,6 @@ fn handle_return(
assert!(inputs.len() == 1);
assert!(destinations.len() == 0);

// First arg is the size, which is ignored because it's redundant.
let (return_data_offset, return_data_size) = match inputs[0] {
ValueOrArray::HeapArray(HeapArray { pointer, size }) => (pointer, size as u32),
_ => panic!("Return instruction's args input should be a HeapArray"),
Expand All @@ -1233,6 +1244,25 @@ fn handle_return(
});
}

// #[oracle(avmOpcodeRevert)]
// unconstrained fn revert_opcode(revertdata: [Field]) {}
fn handle_revert(
avm_instrs: &mut Vec<AvmInstruction>,
destinations: &Vec<ValueOrArray>,
inputs: &Vec<ValueOrArray>,
) {
assert!(inputs.len() == 2);
assert!(destinations.len() == 0);

// First arg is the size, which is ignored because it's redundant.
let (revert_data_offset, revert_data_size_offset) = match inputs[1] {
ValueOrArray::HeapVector(HeapVector { pointer, size }) => (pointer, size),
_ => panic!("Revert instruction's args input should be a HeapVector"),
};

generate_revert_instruction(avm_instrs, &revert_data_offset, &revert_data_size_offset);
}

/// Emit a storage write opcode
/// The current implementation writes an array of values into storage ( contiguous slots in memory )
fn handle_storage_write(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "barretenberg/common/throw_or_abort.hpp"
#include "bincode.hpp"
#include "serde.hpp"

Expand Down Expand Up @@ -712,7 +713,7 @@ struct BrilligOpcode {
};

struct Trap {
Program::HeapArray revert_data;
Program::HeapVector revert_data;

friend bool operator==(const Trap&, const Trap&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down
53 changes: 51 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2738,10 +2738,59 @@ std::vector<FF> AvmTraceBuilder::op_return(uint8_t indirect, uint32_t ret_offset
return returndata;
}

std::vector<FF> AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size)
std::vector<FF> AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size_offset)
{
// TODO: This opcode is still masquerading as RETURN.
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// This boolean will not be a trivial constant once we re-enable constraining address resolution
bool tag_match = true;

auto [resolved_ret_offset, resolved_ret_size_offset] =
Addressing<2>::fromWire(indirect, call_ptr).resolve({ ret_offset, ret_size_offset }, mem_trace_builder);
const auto ret_size = static_cast<uint32_t>(unconstrained_read_from_memory(resolved_ret_size_offset));

gas_trace_builder.constrain_gas(clk, OpCode::RETURN, ret_size);

// TODO: fix and set sel_op_revert
return op_return(indirect, ret_offset, ret_size);
if (ret_size == 0) {
main_trace.push_back(Row{
.main_clk = clk,
.main_call_ptr = call_ptr,
.main_ib = ret_size,
.main_internal_return_ptr = FF(internal_return_ptr),
.main_pc = pc,
.main_sel_op_external_return = 1,
});

pc = UINT32_MAX; // This ensures that no subsequent opcode will be executed.
return {};
}

// The only memory operation performed from the main trace is a possible indirect load for resolving the
// direct destination offset stored in main_mem_addr_c.
// All the other memory operations are triggered by the slice gadget.
if (tag_match) {
returndata = mem_trace_builder.read_return_opcode(clk, call_ptr, resolved_ret_offset, ret_size);
slice_trace_builder.create_return_slice(returndata, clk, call_ptr, resolved_ret_offset, ret_size);
}

main_trace.push_back(Row{
.main_clk = clk,
.main_call_ptr = call_ptr,
.main_ib = ret_size,
.main_internal_return_ptr = FF(internal_return_ptr),
.main_mem_addr_c = resolved_ret_offset,
.main_pc = pc,
.main_r_in_tag = static_cast<uint32_t>(AvmMemoryTag::FF),
.main_sel_op_external_return = 1,
.main_sel_slice_gadget = static_cast<uint32_t>(tag_match),
.main_tag_err = static_cast<uint32_t>(!tag_match),
.main_w_in_tag = static_cast<uint32_t>(AvmMemoryTag::FF),
});

pc = UINT32_MAX; // This ensures that no subsequent opcode will be executed.
return returndata;
}

/**************************************************************************************************
Expand Down
2 changes: 1 addition & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class AvmTraceBuilder {
uint32_t function_selector_offset);
std::vector<FF> op_return(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size);
// REVERT Opcode (that just call return under the hood for now)
std::vector<FF> op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size);
std::vector<FF> op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size_offset);

// Gadgets
void op_poseidon2_permutation(uint8_t indirect, uint32_t input_offset, uint32_t output_offset);
Expand Down
15 changes: 15 additions & 0 deletions noir-projects/aztec-nr/aztec/src/context/public_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ unconstrained fn avm_return<let N: u32>(returndata: [Field; N]) {
return_opcode(returndata)
}

// This opcode reverts using the exact data given. In general it should only be used
// to do rethrows, where the revert data is the same as the original revert data.
// For normal reverts, use Noir's `assert` which, on top of reverting, will also add
// an error selector to the revert data.
unconstrained fn avm_revert<let N: u32>(revertdata: [Field]) {
revert_opcode(revertdata)
}

unconstrained fn storage_read(storage_slot: Field) -> Field {
storage_read_opcode(storage_slot)
}
Expand Down Expand Up @@ -378,6 +386,13 @@ unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: u32, copy_size: u32)
#[oracle(avmOpcodeReturn)]
unconstrained fn return_opcode<let N: u32>(returndata: [Field; N]) {}

// This opcode reverts using the exact data given. In general it should only be used
// to do rethrows, where the revert data is the same as the original revert data.
// For normal reverts, use Noir's `assert` which, on top of reverting, will also add
// an error selector to the revert data.
#[oracle(avmOpcodeRevert)]
unconstrained fn revert_opcode(revertdata: [Field]) {}

#[oracle(avmOpcodeCall)]
unconstrained fn call_opcode<let RET_SIZE: u32>(
gas: [Field; 2], // gas allocation: [l2_gas, da_gas]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,12 @@ contract AvmTest {
[4, 5, 6] // Should not get here.
}

#[public]
fn revert_oracle() -> [Field; 3] {
dep::aztec::context::public_context::avm_revert([1, 2, 3]);
[4, 5, 6] // Should not get here.
}

/************************************************************************
* Hashing functions
************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ namespace Program {
};

struct Trap {
Program::HeapArray revert_data;
Program::HeapVector revert_data;

friend bool operator==(const Trap&, const Trap&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down
18 changes: 14 additions & 4 deletions noir/noir-repo/acvm-repo/acvm/tests/solver.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use acir::brillig::{BitSize, IntegerBitSize};
use acir::brillig::{BitSize, HeapVector, IntegerBitSize};
use acir::{
acir_field::GenericFieldElement,
brillig::{BinaryFieldOp, HeapArray, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray},
brillig::{BinaryFieldOp, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray},
circuit::{
brillig::{BrilligBytecode, BrilligFunctionId, BrilligInputs, BrilligOutputs},
opcodes::{BlackBoxFuncCall, BlockId, BlockType, FunctionInput, MemOp},
Expand Down Expand Up @@ -667,7 +667,12 @@ fn unsatisfied_opcode_resolved_brillig() {
let jmp_if_opcode =
BrilligOpcode::JumpIf { condition: MemoryAddress::direct(2), location: location_of_stop };

let trap_opcode = BrilligOpcode::Trap { revert_data: HeapArray::default() };
let trap_opcode = BrilligOpcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(3),
},
};
let stop_opcode = BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 0 };

let brillig_bytecode = BrilligBytecode {
Expand All @@ -682,6 +687,11 @@ fn unsatisfied_opcode_resolved_brillig() {
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
BrilligOpcode::Const {
destination: MemoryAddress::direct(3),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
calldata_copy_opcode,
equal_opcode,
jmp_if_opcode,
Expand Down Expand Up @@ -739,7 +749,7 @@ fn unsatisfied_opcode_resolved_brillig() {
ACVMStatus::Failure(OpcodeResolutionError::BrilligFunctionFailed {
function_id: BrilligFunctionId(0),
payload: None,
call_stack: vec![OpcodeLocation::Brillig { acir_index: 0, brillig_index: 5 }]
call_stack: vec![OpcodeLocation::Brillig { acir_index: 0, brillig_index: 6 }]
}),
"The first opcode is not satisfiable, expected an error indicating this"
);
Expand Down
2 changes: 1 addition & 1 deletion noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ pub enum BrilligOpcode<F> {
BlackBox(BlackBoxOp),
/// Used to denote execution failure, returning data after the offset
Trap {
revert_data: HeapArray,
revert_data: HeapVector,
},
/// Stop execution, returning data after the offset
Stop {
Expand Down
23 changes: 18 additions & 5 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,11 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
self.increment_program_counter()
}
Opcode::Trap { revert_data } => {
if revert_data.size > 0 {
let revert_data_size = self.memory.read(revert_data.size).to_usize();
if revert_data_size > 0 {
self.trap(
self.memory.read_ref(revert_data.pointer).unwrap_direct(),
revert_data.size,
revert_data_size,
)
} else {
self.trap(0, 0)
Expand Down Expand Up @@ -904,8 +905,18 @@ mod tests {
size_address: MemoryAddress::direct(0),
offset_address: MemoryAddress::direct(1),
},
Opcode::Jump { location: 5 },
Opcode::Trap { revert_data: HeapArray::default() },
Opcode::Jump { location: 6 },
Opcode::Const {
destination: MemoryAddress::direct(0),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
Opcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(0),
},
},
Opcode::BinaryFieldOp {
op: BinaryFieldOp::Equals,
lhs: MemoryAddress::direct(0),
Expand Down Expand Up @@ -933,6 +944,8 @@ mod tests {
assert_eq!(status, VMStatus::InProgress);
let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);
let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let output_cmp_value = vm.memory.read(MemoryAddress::direct(2));
assert_eq!(output_cmp_value.to_field(), false.into());
Expand All @@ -945,7 +958,7 @@ mod tests {
status,
VMStatus::Failure {
reason: FailureReason::Trap { revert_data_offset: 0, revert_data_size: 0 },
call_stack: vec![4]
call_stack: vec![5]
}
);

Expand Down
18 changes: 14 additions & 4 deletions noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ pub(crate) mod tests {
use std::vec;

use acvm::acir::brillig::{
BitSize, ForeignCallParam, ForeignCallResult, HeapArray, HeapVector, IntegerBitSize,
MemoryAddress, ValueOrArray,
BitSize, ForeignCallParam, ForeignCallResult, HeapVector, IntegerBitSize, MemoryAddress,
ValueOrArray,
};
use acvm::brillig_vm::brillig::HeapValueType;
use acvm::brillig_vm::{VMStatus, VM};
Expand Down Expand Up @@ -288,8 +288,18 @@ pub(crate) mod tests {
// We push a JumpIf and Trap opcode directly as the constrain instruction
// uses unresolved jumps which requires a block to be constructed in SSA and
// we don't need this for Brillig IR tests
context.push_opcode(BrilligOpcode::JumpIf { condition: r_equality, location: 8 });
context.push_opcode(BrilligOpcode::Trap { revert_data: HeapArray::default() });
context.push_opcode(BrilligOpcode::JumpIf { condition: r_equality, location: 9 });
context.push_opcode(BrilligOpcode::Const {
destination: MemoryAddress::direct(0),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
});
context.push_opcode(BrilligOpcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(0),
},
});

context.stop_instruction();

Expand Down
Loading

0 comments on commit 1bbd724

Please sign in to comment.