diff --git a/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs b/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs index 7b40c35960a..5d19f9629ba 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs @@ -48,25 +48,51 @@ impl RangeOptimizer { /// only store the fact that we have constrained it to /// be 16 bits. fn collect_ranges(circuit: &Circuit) -> BTreeMap { - let mut witness_to_bit_sizes = BTreeMap::new(); + let mut witness_to_bit_sizes: BTreeMap = BTreeMap::new(); for opcode in &circuit.opcodes { - // Extract the witness index and number of bits, - // if it is a range constraint - let (witness, num_bits) = match extract_range_opcode(opcode) { - Some(func_inputs) => func_inputs, - None => continue, + let Some((witness, num_bits)) = (match opcode { + Opcode::AssertZero(expr) => { + // If the opcode is constraining a witness to be equal to a value then it can be considered + // as a range opcode for the number of bits required to hold that value. + if expr.is_degree_one_univariate() { + let (k, witness) = expr.linear_combinations[0]; + let constant = expr.q_c; + let witness_value = -constant / k; + + if witness_value.is_zero() { + Some((witness, 0)) + } else { + // We subtract off 1 bit from the implied witness value to give the weakest range constraint + // which would be stricter than the constraint imposed by this opcode. + let implied_range_constraint_bits = witness_value.num_bits() - 1; + Some((witness, implied_range_constraint_bits)) + } + } else { + None + } + } + + + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { + input: FunctionInput { witness, num_bits }, + }) => { + Some((*witness, *num_bits)) + } + + _ => None, + }) else { + continue; }; // Check if the witness has already been recorded and if the witness // size is more than the current one, we replace it - let should_replace = match witness_to_bit_sizes.get(&witness).copied() { - Some(old_range_bits) => old_range_bits > num_bits, - None => true, - }; - if should_replace { - witness_to_bit_sizes.insert(witness, num_bits); - } + witness_to_bit_sizes + .entry(witness) + .and_modify(|old_range_bits| { + *old_range_bits = std::cmp::min(*old_range_bits, num_bits); + }) + .or_insert(num_bits); } witness_to_bit_sizes } @@ -116,16 +142,10 @@ impl RangeOptimizer { /// Extract the range opcode from the `Opcode` enum /// Returns None, if `Opcode` is not the range opcode. fn extract_range_opcode(opcode: &Opcode) -> Option<(Witness, u32)> { - // Range constraints are blackbox function calls - // so we first extract the function call - let func_call = match opcode { - acir::circuit::Opcode::BlackBoxFuncCall(func_call) => func_call, - _ => return None, - }; - - // Skip if it is not a range constraint - match func_call { - BlackBoxFuncCall::RANGE { input } => Some((input.witness, input.num_bits)), + match opcode { + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => { + Some((input.witness, input.num_bits)) + } _ => None, } } @@ -246,4 +266,17 @@ mod tests { let (optimized_circuit, _) = optimizer.replace_redundant_ranges(acir_opcode_positions); assert_eq!(optimized_circuit.opcodes.len(), 5); } + + #[test] + fn constant_implied_ranges() { + // The optimizer should use knowledge about constant witness assignments to remove range opcodes. + let mut circuit = test_circuit(vec![(Witness(1), 16)]); + + circuit.opcodes.push(Opcode::AssertZero(Witness(1).into())); + let acir_opcode_positions = circuit.opcodes.iter().enumerate().map(|(i, _)| i).collect(); + let optimizer = RangeOptimizer::new(circuit); + let (optimized_circuit, _) = optimizer.replace_redundant_ranges(acir_opcode_positions); + assert_eq!(optimized_circuit.opcodes.len(), 1); + assert_eq!(optimized_circuit.opcodes[0], Opcode::AssertZero(Witness(1).into())); + } }