diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 56cb76adbe4..808cf7533c9 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -48,6 +48,7 @@ pub(crate) fn optimize_into_acir( let ssa_gen_span_guard = ssa_gen_span.enter(); let ssa = SsaBuilder::new(program, print_ssa_passes, force_brillig_output)? .run_pass(Ssa::defunctionalize, "After Defunctionalization:") + .run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:") .run_pass(Ssa::inline_functions, "After Inlining:") // Run mem2reg with the CFG separated into blocks .run_pass(Ssa::mem2reg, "After Mem2Reg:") @@ -59,10 +60,7 @@ pub(crate) fn optimize_into_acir( // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores .run_pass(Ssa::mem2reg, "After Mem2Reg:") .run_pass(Ssa::fold_constants, "After Constant Folding:") - .run_pass( - Ssa::fold_constants_using_constraints, - "After Constant Folding With Constraint Info:", - ) + .run_pass(Ssa::fold_constants_using_constraints, "After Constraint Folding:") .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") .finish(); diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 2c39c83b342..aa5a7fedd92 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -195,12 +195,9 @@ impl FunctionBuilder { self.call_stack.clone() } - /// Insert a Load instruction at the end of the current block, loading from the given offset - /// of the given address which should point to a previous Allocate instruction. Note that - /// this is limited to loading a single value. Loading multiple values (such as a tuple) - /// will require multiple loads. - /// 'offset' is in units of FieldElements here. So loading the fourth FieldElement stored in - /// an array will have an offset of 3. 
+ /// Insert a Load instruction at the end of the current block, loading from the given address + /// which should point to a previous Allocate instruction. Note that this is limited to loading + /// a single value. Loading multiple values (such as a tuple) will require multiple loads. /// Returns the element that was loaded. pub(crate) fn insert_load(&mut self, address: ValueId, type_to_load: Type) -> ValueId { self.insert_instruction(Instruction::Load { address }, Some(vec![type_to_load])).first() @@ -221,11 +218,9 @@ impl FunctionBuilder { operator: BinaryOp, rhs: ValueId, ) -> ValueId { - assert_eq!( - self.type_of_value(lhs), - self.type_of_value(rhs), - "ICE - Binary instruction operands must have the same type" - ); + let lhs_type = self.type_of_value(lhs); + let rhs_type = self.type_of_value(rhs); + assert_eq!(lhs_type, rhs_type, "ICE - Binary instruction operands must have the same type"); let instruction = Instruction::Binary(Binary { lhs, rhs, operator }); self.insert_instruction(instruction, None).first() } @@ -309,6 +304,18 @@ impl FunctionBuilder { self.insert_instruction(Instruction::ArraySet { array, index, value }, None).first() } + /// Insert an instruction to increment an array's reference count. This only has an effect + /// in unconstrained code where arrays are reference counted and copy on write. + pub(crate) fn insert_inc_rc(&mut self, value: ValueId) { + self.insert_instruction(Instruction::IncrementRc { value }, None); + } + + /// Insert an instruction to decrement an array's reference count. This only has an effect + /// in unconstrained code where arrays are reference counted and copy on write. 
+ pub(crate) fn insert_dec_rc(&mut self, value: ValueId) { + self.insert_instruction(Instruction::DecrementRc { value }, None); + } + /// Terminates the current block with the given terminator instruction fn terminate_block_with(&mut self, terminator: TerminatorInstruction) { self.current_function.dfg.set_block_terminator(self.current_block, terminator); @@ -384,51 +391,65 @@ impl FunctionBuilder { /// within the given value. If the given value is not an array and does not contain /// any arrays, this does nothing. pub(crate) fn increment_array_reference_count(&mut self, value: ValueId) { - self.update_array_reference_count(value, true); + self.update_array_reference_count(value, true, None); } /// Insert instructions to decrement the reference count of any array(s) stored /// within the given value. If the given value is not an array and does not contain /// any arrays, this does nothing. pub(crate) fn decrement_array_reference_count(&mut self, value: ValueId) { - self.update_array_reference_count(value, false); + self.update_array_reference_count(value, false, None); } /// Increment or decrement the given value's reference count if it is an array. /// If it is not an array, this does nothing. Note that inc_rc and dec_rc instructions /// are ignored outside of unconstrained code. - pub(crate) fn update_array_reference_count(&mut self, value: ValueId, increment: bool) { + fn update_array_reference_count( + &mut self, + value: ValueId, + increment: bool, + load_address: Option<ValueId>, + ) { match self.type_of_value(value) { Type::Numeric(_) => (), Type::Function => (), Type::Reference(element) => { if element.contains_an_array() { - let value = self.insert_load(value, element.as_ref().clone()); - self.increment_array_reference_count(value); + let reference = value; + let value = self.insert_load(reference, element.as_ref().clone()); + self.update_array_reference_count(value, increment, Some(reference)); } } typ @ Type::Array(..) | typ @ Type::Slice(..) 
=> { // If there are nested arrays or slices, we wait until ArrayGet // is issued to increment the count of that array. - let instruction = if increment { - Instruction::IncrementRc { value } - } else { - Instruction::DecrementRc { value } + let update_rc = |this: &mut Self, value| { + if increment { + this.insert_inc_rc(value); + } else { + this.insert_dec_rc(value); + } }; - self.insert_instruction(instruction, None); + + update_rc(self, value); + let dfg = &self.current_function.dfg; // This is a bit odd, but in brillig the inc_rc instruction operates on // a copy of the array's metadata, so we need to re-store a loaded array // even if there have been no other changes to it. - if let Value::Instruction { instruction, .. } = &self.current_function.dfg[value] { - let instruction = &self.current_function.dfg[*instruction]; + if let Some(address) = load_address { + // If we already have a load from the Type::Reference case, avoid inserting + // another load and rc update. + self.insert_store(address, value); + } else if let Value::Instruction { instruction, .. } = &dfg[value] { + let instruction = &dfg[*instruction]; if let Instruction::Load { address } = instruction { // We can't re-use `value` in case the original address was stored // to again in the meantime. So introduce another load. 
let address = *address; - let value = self.insert_load(address, typ); - self.insert_instruction(Instruction::IncrementRc { value }, None); - self.insert_store(address, value); + let new_load = self.insert_load(address, typ); + update_rc(self, new_load); + self.insert_store(address, new_load); } } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index a315695f7db..8f98b3fb17f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -12,6 +12,7 @@ mod die; pub(crate) mod flatten_cfg; mod inlining; mod mem2reg; +mod rc; mod remove_bit_shifts; mod simplify_cfg; mod unrolling; diff --git a/compiler/noirc_evaluator/src/ssa/opt/rc.rs b/compiler/noirc_evaluator/src/ssa/opt/rc.rs new file mode 100644 index 00000000000..4766bc3e8d2 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/rc.rs @@ -0,0 +1,327 @@ +use std::collections::{HashMap, HashSet}; + +use crate::ssa::{ + ir::{ + basic_block::BasicBlockId, + function::Function, + instruction::{Instruction, InstructionId, TerminatorInstruction}, + types::Type, + value::ValueId, + }, + ssa_gen::Ssa, +}; + +impl Ssa { + /// This pass removes `inc_rc` and `dec_rc` instructions + /// as long as there are no `array_set` instructions to an array + /// of the same type in between. + /// + /// Note that this pass is very conservative since the array_set + /// instruction does not need to be to the same array. This is because + /// the given array may alias another array (e.g. function parameters or + /// a `load`ed array from a reference). + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn remove_paired_rc(mut self) -> Ssa { + for function in self.functions.values_mut() { + remove_paired_rc(function); + } + self + } +} + +#[derive(Default)] +struct Context { + // All inc_rc instructions encountered without a corresponding dec_rc. + // These are only searched for in the first block of a function. 
+ // + // The type of the array being operated on is recorded. + // If an array_set to that array type is encountered, that is also recorded. + inc_rcs: HashMap<Type, Vec<IncRc>>, +} + +struct IncRc { + id: InstructionId, + array: ValueId, + possibly_mutated: bool, +} + +/// This function is very simplistic for now. It takes advantage of the fact that dec_rc +/// instructions are currently issued only at the end of a function for parameters and will +/// only check the first and last block for inc & dec rc instructions to be removed. The rest +/// of the function is still checked for array_set instructions. +/// +/// This restriction lets this function largely ignore merging intermediate results from other +/// blocks and handling loops. +fn remove_paired_rc(function: &mut Function) { + // `dec_rc` is only issued for parameters currently so we can speed things + // up a bit by skipping any functions without them. + if !contains_array_parameter(function) { + return; + } + + let mut context = Context::default(); + + context.find_rcs_in_entry_block(function); + context.scan_for_array_sets(function); + let to_remove = context.find_rcs_to_remove(function); + remove_instructions(to_remove, function); +} + +fn contains_array_parameter(function: &Function) -> bool { + let mut parameters = function.parameters().iter(); + parameters.any(|parameter| function.dfg.type_of_value(*parameter).contains_an_array()) +} + +impl Context { + fn find_rcs_in_entry_block(&mut self, function: &Function) { + let entry = function.entry_block(); + + for instruction in function.dfg[entry].instructions() { + if let Instruction::IncrementRc { value } = &function.dfg[*instruction] { + let typ = function.dfg.type_of_value(*value); + + // We assume arrays aren't mutated until we find an array_set + let inc_rc = IncRc { id: *instruction, array: *value, possibly_mutated: false }; + self.inc_rcs.entry(typ).or_default().push(inc_rc); + } + } + } + + /// Find each array_set instruction in the function and mark any 
arrays used + /// by the inc_rc instructions as possibly mutated if they're the same type. + fn scan_for_array_sets(&mut self, function: &Function) { + for block in function.reachable_blocks() { + for instruction in function.dfg[block].instructions() { + if let Instruction::ArraySet { array, .. } = function.dfg[*instruction] { + let typ = function.dfg.type_of_value(array); + if let Some(inc_rcs) = self.inc_rcs.get_mut(&typ) { + for inc_rc in inc_rcs { + inc_rc.possibly_mutated = true; + } + } + } + } + } + } + + /// Find each dec_rc instruction and if the earliest unpaired inc_rc instruction for the same + /// value is not possibly mutated, then we can remove them both. Returns each such pair. + fn find_rcs_to_remove(&mut self, function: &Function) -> HashSet<InstructionId> { + let last_block = Self::find_last_block(function); + let mut to_remove = HashSet::new(); + + for instruction in function.dfg[last_block].instructions() { + if let Instruction::DecrementRc { value } = &function.dfg[*instruction] { + if let Some(inc_rc) = self.pop_rc_for(*value, function) { + if !inc_rc.possibly_mutated { + to_remove.insert(inc_rc.id); + to_remove.insert(*instruction); + } + } + } + } + + to_remove + } + + /// Finds the block of the function with the Return instruction + fn find_last_block(function: &Function) -> BasicBlockId { + for block in function.reachable_blocks() { + if matches!( + function.dfg[block].terminator(), + Some(TerminatorInstruction::Return { .. }) + ) { + return block; + } + } + + unreachable!("SSA Function {} has no reachable return instruction!", function.id()) + } + + /// Finds and pops the IncRc for the given array value if possible. 
+ fn pop_rc_for(&mut self, value: ValueId, function: &Function) -> Option<IncRc> { + let typ = function.dfg.type_of_value(value); + + let rcs = self.inc_rcs.get_mut(&typ)?; + let position = rcs.iter().position(|inc_rc| inc_rc.array == value)?; + + Some(rcs.remove(position)) + } +} + +fn remove_instructions(to_remove: HashSet<InstructionId>, function: &mut Function) { + if !to_remove.is_empty() { + for block in function.reachable_blocks() { + function.dfg[block] + .instructions_mut() + .retain(|instruction| !to_remove.contains(instruction)); + } + } +} + +#[cfg(test)] +mod test { + use std::rc::Rc; + + use crate::ssa::{ + function_builder::FunctionBuilder, + ir::{ + basic_block::BasicBlockId, dfg::DataFlowGraph, function::RuntimeType, + instruction::Instruction, map::Id, types::Type, + }, + }; + + fn count_inc_rcs(block: BasicBlockId, dfg: &DataFlowGraph) -> usize { + dfg[block] + .instructions() + .iter() + .filter(|instruction_id| { + matches!(dfg[**instruction_id], Instruction::IncrementRc { .. }) + }) + .count() + } + + fn count_dec_rcs(block: BasicBlockId, dfg: &DataFlowGraph) -> usize { + dfg[block] + .instructions() + .iter() + .filter(|instruction_id| { + matches!(dfg[**instruction_id], Instruction::DecrementRc { .. 
 }) + }) + .count() + } + + #[test] + fn single_block_fn_return_array() { + // This is the output for the program with a function: + // unconstrained fn foo(x: [Field; 2]) -> [[Field; 2]; 1] { + // [x] + // } + // + // fn foo { + // b0(v0: [Field; 2]): + // inc_rc v0 + // inc_rc v0 + // dec_rc v0 + // return [v0] + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("foo".into(), main_id, RuntimeType::Brillig); + + let inner_array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let v0 = builder.add_parameter(inner_array_type.clone()); + + builder.insert_inc_rc(v0); + builder.insert_inc_rc(v0); + builder.insert_dec_rc(v0); + + let outer_array_type = Type::Array(Rc::new(vec![inner_array_type]), 1); + let array = builder.array_constant(vec![v0].into(), outer_array_type); + builder.terminate_with_return(vec![array]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 0); + } + + #[test] + fn single_block_mutation() { + // fn mutator(mut array: [Field; 2]) { + // array[0] = 5; + // } + // + // fn mutator { + // b0(v0: [Field; 2]): + // v1 = allocate + // store v0 at v1 + // inc_rc v0 + // v2 = load v1 + // v7 = array_set v2, index u64 0, value Field 5 + // store v7 at v1 + // dec_rc v0 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("mutator".into(), main_id, RuntimeType::Acir); + + let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let v0 = builder.add_parameter(array_type.clone()); + + let v1 = builder.insert_allocate(array_type.clone()); + builder.insert_store(v1, v0); + builder.insert_inc_rc(v0); + let v2 = builder.insert_load(v1, array_type); + + let zero = builder.numeric_constant(0u128, Type::unsigned(64)); + let five = builder.field_constant(5u128); + let v7 = builder.insert_array_set(v2, zero, five); + + 
+ builder.insert_store(v1, v7); + builder.insert_dec_rc(v0); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + // No changes, the array is possibly mutated + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 1); + } + + // Similar to single_block_mutation but for a function which + // uses a mutable reference parameter. + #[test] + fn single_block_mutation_through_reference() { + // fn mutator2(array: &mut [Field; 2]) { + // array[0] = 5; + // } + // + // fn mutator2 { + // b0(v0: &mut [Field; 2]): + // v1 = load v0 + // inc_rc v1 + // store v1 at v0 + // v2 = load v0 + // v7 = array_set v2, index u64 0, value Field 5 + // store v7 at v0 + // v8 = load v0 + // dec_rc v8 + // store v8 at v0 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("mutator2".into(), main_id, RuntimeType::Acir); + + let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let reference_type = Type::Reference(Rc::new(array_type.clone())); + + let v0 = builder.add_parameter(reference_type); + + let v1 = builder.insert_load(v0, array_type.clone()); + builder.insert_inc_rc(v1); + builder.insert_store(v0, v1); + + let v2 = builder.insert_load(v0, array_type.clone()); + let zero = builder.numeric_constant(0u128, Type::unsigned(64)); + let five = builder.field_constant(5u128); + let v7 = builder.insert_array_set(v2, zero, five); + + builder.insert_store(v0, v7); + let v8 = builder.insert_load(v0, array_type); + builder.insert_dec_rc(v8); + builder.insert_store(v0, v8); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + // No changes, the array is possibly mutated + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 1); + } +} diff --git 
a/compiler/noirc_frontend/src/monomorphization/debug.rs b/compiler/noirc_frontend/src/monomorphization/debug.rs index cf4e0ab792e..3a03177f8ec 100644 --- a/compiler/noirc_frontend/src/monomorphization/debug.rs +++ b/compiler/noirc_frontend/src/monomorphization/debug.rs @@ -195,8 +195,8 @@ fn element_type_at_index(ptype: &PrintableType, i: usize) -> &PrintableType { PrintableType::Tuple { types } => &types[i], PrintableType::Struct { name: _name, fields } => &fields[i].1, PrintableType::String { length: _length } => &PrintableType::UnsignedInteger { width: 8 }, - _ => { - panic!["expected type with sub-fields, found terminal type"] + other => { + panic!["expected type with sub-fields, found terminal type: {other:?}"] } } }