Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ssa): Track all local allocations during flattening #6619

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
//! v11 = mul v4, Field 12
//! v12 = add v10, v11
//! store v12 at v5 (new store)
use fxhash::FxHashMap as HashMap;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement};
use iter_extended::vecmap;
Expand Down Expand Up @@ -201,6 +201,15 @@ struct Context<'f> {
/// When processing a block, we pop this stack to get its arguments
/// and at the end we push the arguments for its successor
arguments_stack: Vec<Vec<ValueId>>,

/// Stores all allocations local to the current branch.
///
/// Since these allocations are local to the current branch (i.e. only defined within one branch of
/// an if expression), they should not be merged with their previous value or stored value in
/// the other branch since there is no such value.
///
/// The `ValueId` here is that which is returned by the allocate instruction.
local_allocations: HashSet<ValueId>,
}

#[derive(Clone)]
Expand All @@ -211,6 +220,8 @@ struct ConditionalBranch {
old_condition: ValueId,
// The condition of the branch
condition: ValueId,
// The allocations accumulated when processing the branch
local_allocations: HashSet<ValueId>,
}

struct ConditionalContext {
Expand Down Expand Up @@ -243,6 +254,7 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap<Functio
slice_sizes: HashMap::default(),
condition_stack: Vec::new(),
arguments_stack: Vec::new(),
local_allocations: HashSet::default(),
};
context.flatten(no_predicates);
}
Expand Down Expand Up @@ -317,7 +329,6 @@ impl<'f> Context<'f> {
// If this is not a separate variable, clippy gets confused and says the to_vec is
// unnecessary, when removing it actually causes an aliasing/mutability error.
let instructions = self.inserter.function.dfg[block].instructions().to_vec();
let mut previous_allocate_result = None;

for instruction in instructions.iter() {
if self.is_no_predicate(no_predicates, instruction) {
Expand All @@ -332,10 +343,10 @@ impl<'f> Context<'f> {
None,
im::Vector::new(),
);
self.push_instruction(*instruction, &mut previous_allocate_result);
self.push_instruction(*instruction);
self.insert_current_side_effects_enabled();
} else {
self.push_instruction(*instruction, &mut previous_allocate_result);
self.push_instruction(*instruction);
}
}
}
Expand Down Expand Up @@ -405,10 +416,12 @@ impl<'f> Context<'f> {
let old_condition = *condition;
let then_condition = self.inserter.resolve(old_condition);

let old_allocations = std::mem::take(&mut self.local_allocations);
let branch = ConditionalBranch {
old_condition,
condition: self.link_condition(then_condition),
last_block: *then_destination,
local_allocations: old_allocations,
};
let cond_context = ConditionalContext {
condition: then_condition,
Expand All @@ -435,11 +448,14 @@ impl<'f> Context<'f> {
);
let else_condition = self.link_condition(else_condition);

let old_allocations = std::mem::take(&mut self.local_allocations);
let else_branch = ConditionalBranch {
old_condition: cond_context.then_branch.old_condition,
condition: else_condition,
last_block: *block,
local_allocations: old_allocations,
};
cond_context.then_branch.local_allocations.clear();
cond_context.else_branch = Some(else_branch);
self.condition_stack.push(cond_context);

Expand All @@ -461,6 +477,7 @@ impl<'f> Context<'f> {
}

let mut else_branch = cond_context.else_branch.unwrap();
self.local_allocations = std::mem::take(&mut else_branch.local_allocations);
else_branch.last_block = *block;
cond_context.else_branch = Some(else_branch);

Expand Down Expand Up @@ -593,22 +610,19 @@ impl<'f> Context<'f> {
/// Any `allocate` result is recorded in `self.local_allocations` so that later stores to that
/// address can tell there is no previous value to load, and thus no merge of store values
/// across branches is needed for it.
fn push_instruction(
&mut self,
id: InstructionId,
previous_allocate_result: &mut Option<ValueId>,
) {
fn push_instruction(&mut self, id: InstructionId) {
let (instruction, call_stack) = self.inserter.map_instruction(id);
let instruction = self.handle_instruction_side_effects(
instruction,
call_stack.clone(),
*previous_allocate_result,
);
let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone());

let instruction_is_allocate = matches!(&instruction, Instruction::Allocate);
let entry = self.inserter.function.entry_block();
let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack);
*previous_allocate_result = instruction_is_allocate.then(|| results.first());

// Remember an allocate was created local to this branch so that we do not try to merge store
// values across branches for it later.
if instruction_is_allocate {
self.local_allocations.insert(results.first());
}
}

/// If we are currently in a branch, we need to modify constrain instructions
Expand All @@ -621,7 +635,6 @@ impl<'f> Context<'f> {
&mut self,
instruction: Instruction,
call_stack: CallStack,
previous_allocate_result: Option<ValueId>,
) -> Instruction {
if let Some(condition) = self.get_last_condition() {
match instruction {
Expand Down Expand Up @@ -652,7 +665,7 @@ impl<'f> Context<'f> {
Instruction::Store { address, value } => {
// If this stores to an allocation that is local to the current branch, there is
// no previous value to load and we don't need a merge anyway.
if Some(address) == previous_allocate_result {
if self.local_allocations.contains(&address) {
Instruction::Store { address, value }
} else {
// Instead of storing `value`, store `if condition { value } else { previous_value }`
Expand Down
120 changes: 41 additions & 79 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
//!
//! Repeating this algorithm for each block in the function in program order should result in
//! optimizing out most known loads. However, identifying all aliases correctly has been proven
//! undecidable in general (Landi, 1992). So this pass will not always optimize out all loads

Check warning on line 59 in compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Landi)
//! that could theoretically be optimized out. This pass can be performed at any time in the
//! SSA optimization pipeline, although it will be more successful the simpler the program's CFG is.
//! This pass is currently performed several times to enable other passes - most notably being
Expand Down Expand Up @@ -117,7 +117,7 @@
/// Load and Store instructions that should be removed at the end of the pass.
///
/// We avoid removing individual instructions as we go since removing elements
/// from the middle of Vecs many times will be slower than a single call to `retain`.

Check warning on line 120 in compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Vecs)
instructions_to_remove: HashSet<InstructionId>,

/// Track a value's last load across all blocks.
Expand Down Expand Up @@ -415,13 +415,11 @@
let address = self.inserter.function.dfg.resolve(*address);
let value = self.inserter.function.dfg.resolve(*value);

// FIXME: This causes errors in the sha256 tests
//
// If there was another store to this instruction without any (unremoved) loads or
// function calls in-between, we can remove the previous store.
// if let Some(last_store) = references.last_stores.get(&address) {
// self.instructions_to_remove.insert(*last_store);
// }
if let Some(last_store) = references.last_stores.get(&address) {
self.instructions_to_remove.insert(*last_store);
}

if self.inserter.function.dfg.value_is_reference(value) {
if let Some(expression) = references.expressions.get(&value) {
Expand Down Expand Up @@ -614,6 +612,8 @@
map::Id,
types::Type,
},
opt::assert_normalized_ssa_equals,
Ssa,
};

#[test]
Expand Down Expand Up @@ -824,91 +824,53 @@
// is later stored in a successor block
#[test]
fn load_aliases_in_predecessor_block() {
// fn main {
// b0():
// v0 = allocate
// store Field 0 at v0
// v2 = allocate
// store v0 at v2
// v3 = load v2
// v4 = load v2
// jmp b1()
// b1():
// store Field 1 at v3
// store Field 2 at v4
// v7 = load v3
// v8 = eq v7, Field 2
// return
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);

let v0 = builder.insert_allocate(Type::field());

let zero = builder.field_constant(0u128);
builder.insert_store(v0, zero);

let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field())));
builder.insert_store(v2, v0);

let v3 = builder.insert_load(v2, Type::field());
let v4 = builder.insert_load(v2, Type::field());
let b1 = builder.insert_block();
builder.terminate_with_jmp(b1, vec![]);

builder.switch_to_block(b1);

let one = builder.field_constant(1u128);
builder.insert_store(v3, one);

let two = builder.field_constant(2u128);
builder.insert_store(v4, two);

let v8 = builder.insert_load(v3, Type::field());
let _ = builder.insert_binary(v8, BinaryOp::Eq, two);

builder.terminate_with_return(vec![]);

let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 2);
let src = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 0 at v0
v2 = allocate -> &mut &mut Field
store v0 at v2
v3 = load v2 -> &mut Field
v4 = load v2 -> &mut Field
jmp b1()
b1():
store Field 1 at v3
store Field 2 at v4
v7 = load v3 -> Field
v8 = eq v7, Field 2
return
}
";

// Expected result:
// acir fn main f0 {
// b0():
// v9 = allocate
// store Field 0 at v9
// v10 = allocate
// jmp b1()
// b1():
// return
// }
let ssa = ssa.mem2reg();
println!("{}", ssa);
let mut ssa = Ssa::from_str(src).unwrap();
let main = ssa.main_mut();

let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 2);
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 6); // The final return is not counted

// All loads should be removed
assert_eq!(count_loads(main.entry_block(), &main.dfg), 0);
assert_eq!(count_loads(b1, &main.dfg), 0);

// The first store is not removed as it is used as a nested reference in another store.
// We would need to track whether the store where `v9` is the store value gets removed to know whether
// We would need to track whether the store where `v0` is the store value gets removed to know whether
// to remove it.
assert_eq!(count_stores(main.entry_block(), &main.dfg), 1);

// The first store in b1 is removed since there is another store to the same reference
// in the same block, and the store is not needed before the later store.
// The rest of the stores are also removed as no loads are done within any blocks
// to the stored values.
//
// NOTE: This store is not removed due to the FIXME when handling Instruction::Store.
assert_eq!(count_stores(b1, &main.dfg), 1);

let b1_instructions = main.dfg[b1].instructions();
let expected = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 0 at v0
v2 = allocate -> &mut &mut Field
jmp b1()
b1():
return
}
";

// We expect the last eq to be optimized out, only the store from above remains
assert_eq!(b1_instructions.len(), 1);
let ssa = ssa.mem2reg();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
Expand Down