From ab7b1c6eeb30b6353a066e134425cce706b35275 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Fri, 1 Dec 2023 15:15:45 -0600 Subject: [PATCH 01/11] Fix try_unify --- compiler/noirc_evaluator/src/ssa.rs | 2 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 99 +++-- .../src/hir/def_collector/dc_crate.rs | 5 +- .../src/hir/resolution/functions.rs | 4 +- .../src/hir/resolution/resolver.rs | 6 +- .../noirc_frontend/src/hir/type_check/expr.rs | 16 +- .../noirc_frontend/src/hir/type_check/stmt.rs | 5 +- compiler/noirc_frontend/src/hir_def/types.rs | 356 +++++++++++------- .../src/monomorphization/mod.rs | 26 +- compiler/noirc_frontend/src/node_interner.rs | 59 ++- 10 files changed, 375 insertions(+), 203 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 8e1c62edc69..a6a6a48ea70 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -44,7 +44,7 @@ pub(crate) fn optimize_into_acir( let ssa_builder = SsaBuilder::new(program, print_ssa_passes)? .run_pass(Ssa::defunctionalize, "After Defunctionalization:") - .run_pass(Ssa::inline_functions, "After Inlining:") + .try_run_pass(Ssa::inline_functions, "After Inlining:")? // Run mem2reg with the CFG separated into blocks .run_pass(Ssa::mem2reg, "After Mem2Reg:") .try_run_pass(Ssa::evaluate_assert_constant, "After Assert Constant:")? diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index ed2484febac..e8a7d1d357c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -4,18 +4,21 @@ //! be a single function remaining when the pass finishes. use std::collections::{BTreeSet, HashSet}; -use iter_extended::{btree_map, vecmap}; - -use crate::ssa::{ - function_builder::FunctionBuilder, - ir::{ - basic_block::BasicBlockId, - dfg::{CallStack, InsertInstructionResult}, - function::{Function, FunctionId, RuntimeType}, - instruction::{Instruction, InstructionId, TerminatorInstruction}, - value::{Value, ValueId}, +use iter_extended::{try_btree_map, vecmap}; + +use crate::{ + errors::RuntimeError, + ssa::{ + function_builder::FunctionBuilder, + ir::{ + basic_block::BasicBlockId, + dfg::{CallStack, InsertInstructionResult}, + function::{Function, FunctionId, RuntimeType}, + instruction::{Instruction, InstructionId, TerminatorInstruction}, + value::{Value, ValueId}, + }, + ssa_gen::Ssa, }, - ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -36,13 +39,13 @@ impl Ssa { /// changes. This is because if the function's id later becomes known by a later /// pass, we would need to re-run all of inlining anyway to inline it, so we might /// as well save the work for later instead of performing it twice. - pub(crate) fn inline_functions(mut self) -> Ssa { - self.functions = btree_map(get_entry_point_functions(&self), |entry_point| { - let new_function = InlineContext::new(&self, entry_point).inline_all(&self); - (entry_point, new_function) - }); + pub(crate) fn inline_functions(mut self) -> Result { + self.functions = try_btree_map(get_entry_point_functions(&self), |entry_point| { + let new_function = InlineContext::new(&self, entry_point).inline_all(&self)?; + Ok::<(FunctionId, Function), RuntimeError>((entry_point, new_function)) + })?; - self + Ok(self) } } @@ -119,7 +122,7 @@ impl InlineContext { } /// Start inlining the entry point function and all functions reachable from it. - fn inline_all(mut self, ssa: &Ssa) -> Function { + fn inline_all(mut self, ssa: &Ssa) -> Result { let entry_point = &ssa.functions[&self.entry_point]; let mut context = PerFunctionContext::new(&mut self, entry_point); @@ -138,12 +141,12 @@ impl InlineContext { } context.blocks.insert(context.source_function.entry_block(), entry_block); - context.inline_blocks(ssa); + context.inline_blocks(ssa)?; // Finally, we should have 1 function left representing the inlined version of the target function. let mut new_ssa = self.builder.finish(); assert_eq!(new_ssa.functions.len(), 1); - new_ssa.functions.pop_first().unwrap().1 + Ok(new_ssa.functions.pop_first().unwrap().1) } /// Inlines a function into the current function and returns the translated return values @@ -153,7 +156,7 @@ impl InlineContext { ssa: &Ssa, id: FunctionId, arguments: &[ValueId], - ) -> Vec { + ) -> Result, RuntimeError> { self.recursion_level += 1; if self.recursion_level > RECURSION_LIMIT { @@ -172,9 +175,9 @@ impl InlineContext { let current_block = context.context.builder.current_block(); context.blocks.insert(source_function.entry_block(), current_block); - let return_values = context.inline_blocks(ssa); + let return_values = context.inline_blocks(ssa)?; self.recursion_level -= 1; - return_values + Ok(return_values) } } @@ -278,7 +281,7 @@ impl<'function> PerFunctionContext<'function> { } /// Inline all reachable blocks within the source_function into the destination function. - fn inline_blocks(&mut self, ssa: &Ssa) -> Vec { + fn inline_blocks(&mut self, ssa: &Ssa) -> Result, RuntimeError> { let mut seen_blocks = HashSet::new(); let mut block_queue = vec![self.source_function.entry_block()]; @@ -294,7 +297,7 @@ impl<'function> PerFunctionContext<'function> { self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); - self.inline_block_instructions(ssa, source_block_id); + self.inline_block_instructions(ssa, source_block_id)?; if let Some((block, values)) = self.handle_terminator_instruction(source_block_id, &mut block_queue) @@ -303,7 +306,7 @@ impl<'function> PerFunctionContext<'function> { } } - self.handle_function_returns(function_returns) + Ok(self.handle_function_returns(function_returns)) } /// Handle inlining a function's possibly multiple return instructions. @@ -339,13 +342,17 @@ impl<'function> PerFunctionContext<'function> { /// Inline each instruction in the given block into the function being inlined into. /// This may recurse if it finds another function to inline if a call instruction is within this block. - fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) { + fn inline_block_instructions( + &mut self, + ssa: &Ssa, + block_id: BasicBlockId, + ) -> Result<(), RuntimeError> { let block = &self.source_function.dfg[block_id]; for id in block.instructions() { match &self.source_function.dfg[*id] { Instruction::Call { func, arguments } => match self.get_function(*func) { Some(function) => match ssa.functions[&function].runtime() { - RuntimeType::Acir => self.inline_function(ssa, *id, function, arguments), + RuntimeType::Acir => self.inline_function(ssa, *id, function, arguments)?, RuntimeType::Brillig => self.push_instruction(*id), }, None => self.push_instruction(*id), @@ -353,6 +360,7 @@ impl<'function> PerFunctionContext<'function> { _ => self.push_instruction(*id), } } + Ok(()) } /// Inline a function call and remember the inlined return values in the values map @@ -362,7 +370,7 @@ impl<'function> PerFunctionContext<'function> { call_id: InstructionId, function: FunctionId, arguments: &[ValueId], - ) { + ) -> Result<(), RuntimeError> { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); @@ -374,14 +382,32 @@ impl<'function> PerFunctionContext<'function> { self.context.call_stack.push_back(location); } - let new_results = self.context.inline_function(ssa, function, &arguments); + let c = self.context.call_stack.clone(); + + let new_results = self.context.inline_function(ssa, function, &arguments)?; if has_location { self.context.call_stack.pop_back(); } let new_results = InsertInstructionResult::Results(call_id, &new_results); + + if old_results.len() != new_results.len() { + println!( + "Function {} has {} results, but call site has {}", + self.source_function.name(), + new_results.len(), + old_results.len() + ); + + return Err(RuntimeError::UnInitialized { + name: "bad inlined function".into(), + call_stack: c, + }); + } + Self::insert_new_instruction_results(&mut self.values, old_results, new_results); + Ok(()) } /// Push the given instruction from the source_function into the current block of the @@ -402,6 +428,11 @@ impl<'function> PerFunctionContext<'function> { self.context.builder.set_call_stack(call_stack); let new_results = self.context.builder.insert_instruction(instruction, ctrl_typevars); + + if results.len() != new_results.len() { + println!("In function {}", self.source_function.name()); + } + Self::insert_new_instruction_results(&mut self.values, &results, new_results); } @@ -543,7 +574,7 @@ mod test { let ssa = builder.finish(); assert_eq!(ssa.functions.len(), 2); - let inlined = ssa.inline_functions(); + let inlined = ssa.inline_functions().unwrap(); assert_eq!(inlined.functions.len(), 1); } @@ -609,7 +640,7 @@ mod test { let ssa = builder.finish(); assert_eq!(ssa.functions.len(), 4); - let inlined = ssa.inline_functions(); + let inlined = ssa.inline_functions().unwrap(); assert_eq!(inlined.functions.len(), 1); } @@ -683,7 +714,7 @@ mod test { // b6(): // return Field 120 // } - let inlined = ssa.inline_functions(); + let inlined = ssa.inline_functions().unwrap(); assert_eq!(inlined.functions.len(), 1); let main = inlined.main(); @@ -766,7 +797,7 @@ mod test { builder.switch_to_block(join_block); builder.terminate_with_return(vec![join_param]); - let ssa = builder.finish().inline_functions(); + let ssa = builder.finish().inline_functions().unwrap(); // Expected result: // fn main f3 { // b0(v0: u1): diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 86122530cde..6a56ebc5c06 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -415,7 +415,8 @@ pub(crate) fn check_methods_signatures( let self_type = resolver.get_self_type().expect("trait impl must have a Self type"); // Temporarily bind the trait's Self type to self_type so we can type check - let _ = the_trait.self_type_typevar.borrow_mut().bind_to(self_type.clone(), the_trait.span); + the_trait.self_type_typevar.try_bind(self_type.clone(), the_trait.span) + .expect("Failed binding Self type of trait"); for (file_id, func_id) in impl_methods { let impl_method = resolver.interner.function_meta(func_id); @@ -494,5 +495,5 @@ pub(crate) fn check_methods_signatures( } } - the_trait.self_type_typevar.borrow_mut().unbind(the_trait.self_type_typevar_id); + the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id); } diff --git a/compiler/noirc_frontend/src/hir/resolution/functions.rs b/compiler/noirc_frontend/src/hir/resolution/functions.rs index 387f94e129c..e63de9b9173 100644 --- a/compiler/noirc_frontend/src/hir/resolution/functions.rs +++ b/compiler/noirc_frontend/src/hir/resolution/functions.rs @@ -11,7 +11,7 @@ use crate::{ def_map::{CrateDefMap, ModuleId}, }, node_interner::{FuncId, NodeInterner, TraitImplId}, - Shared, Type, TypeBinding, + Type, TypeVariable, }; use super::{path_resolver::StandardPathResolver, resolver::Resolver}; @@ -24,7 +24,7 @@ pub(crate) fn resolve_function_set( mut unresolved_functions: UnresolvedFunctions, self_type: Option, trait_impl_id: Option, - impl_generics: Vec<(Rc, Shared, Span)>, + impl_generics: Vec<(Rc, TypeVariable, Span)>, errors: &mut Vec<(CompilationError, FileId)>, ) -> Vec<(FileId, FuncId)> { let file_id = unresolved_functions.file_id; diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 52d592404c8..2ac6f817720 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -571,7 +571,7 @@ impl<'a> Resolver<'a> { match length { None => { let id = self.interner.next_type_variable_id(); - let typevar = Shared::new(TypeBinding::Unbound(id)); + let typevar = TypeVariable::unbound(id); new_variables.push((id, typevar.clone())); // 'Named'Generic is a bit of a misnomer here, we want a type variable that @@ -681,7 +681,7 @@ impl<'a> Resolver<'a> { vecmap(generics, |generic| { // Map the generic to a fresh type variable let id = self.interner.next_type_variable_id(); - let typevar = Shared::new(TypeBinding::Unbound(id)); + let typevar = TypeVariable::unbound(id); let span = generic.0.span(); // Check for name collisions of this generic @@ -929,7 +929,7 @@ impl<'a> Resolver<'a> { fn find_numeric_generics_in_type( typ: &Type, - found: &mut BTreeMap>, + found: &mut BTreeMap, ) { match typ { Type::FieldElement diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 74f076212fa..386bef50060 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -11,7 +11,7 @@ use crate::{ types::Type, }, node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitMethodId}, - BinaryOpKind, Signedness, TypeBinding, TypeVariableKind, UnaryOp, + BinaryOpKind, Signedness, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -778,7 +778,11 @@ impl<'interner> TypeChecker<'interner> { })); } - if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error { + let mut bindings = TypeBindings::new(); + if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok() + || other == &Type::Error + { + Type::apply_type_bindings(bindings); Ok(Bool) } else { Err(TypeCheckError::TypeMismatchWithSource { @@ -1009,7 +1013,7 @@ impl<'interner> TypeChecker<'interner> { let env_type = self.interner.next_type_variable(); let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); - if let Err(error) = binding.borrow_mut().bind_to(expected, span) { + if let Err(error) = binding.try_bind(expected, span) { self.errors.push(error); } ret @@ -1077,7 +1081,11 @@ impl<'interner> TypeChecker<'interner> { })); } - if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error { + let mut bindings = TypeBindings::new(); + if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok() + || other == &Type::Error + { + Type::apply_type_bindings(bindings); Ok(other.clone()) } else { Err(TypeCheckError::TypeMismatchWithSource { diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index e289ae0fc9d..78c287df926 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -8,7 +8,6 @@ use crate::hir_def::stmt::{ }; use crate::hir_def::types::Type; use crate::node_interner::{DefinitionId, ExprId, StmtId}; -use crate::{Shared, TypeBinding, TypeVariableKind}; use super::errors::{Source, TypeCheckError}; use super::TypeChecker; @@ -71,9 +70,7 @@ impl<'interner> TypeChecker<'interner> { expr_span: range_span, }); - let fresh_id = self.interner.next_type_variable_id(); - let type_variable = Shared::new(TypeBinding::Unbound(fresh_id)); - let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField); + let expected_type = Type::polymorphic_integer(self.interner); self.unify(&start_range_type, &expected_type, || { TypeCheckError::TypeCannotBeUsed { diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 46818626a16..8c36132b584 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -298,7 +298,7 @@ impl std::fmt::Display for TypeAliasType { write!(f, "{}", self.name)?; if !self.generics.is_empty() { - let generics = vecmap(&self.generics, |(_, binding)| binding.borrow().to_string()); + let generics = vecmap(&self.generics, |(_, binding)| binding.0.borrow().to_string()); write!(f, "{}", generics.join(", "))?; } @@ -413,7 +413,46 @@ pub enum TypeVariableKind { /// A TypeVariable is a mutable reference that is either /// bound to some type, or unbound with a given TypeVariableId. -pub type TypeVariable = Shared; +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct TypeVariable(Shared); + +impl TypeVariable { + pub fn unbound(id: TypeVariableId) -> Self { + TypeVariable(Shared::new(TypeBinding::Unbound(id))) + } + + pub fn bind(&self, typ: Type) { + let id = match &*self.0.borrow() { + TypeBinding::Bound(binding) => unreachable!("Expected unbound, found bound to {binding}"), + TypeBinding::Unbound(id) => *id, + }; + + assert!(!typ.occurs(id)); + *self.0.borrow_mut() = TypeBinding::Bound(typ); + } + + pub fn try_bind(&self, binding: Type, span: Span) -> Result<(), TypeCheckError> { + let id = match &*self.0.borrow() { + TypeBinding::Bound(binding) => unreachable!("Expected unbound, found bound to {binding}"), + TypeBinding::Unbound(id) => *id, + }; + + if binding.occurs(id) { + Err(TypeCheckError::TypeAnnotationsNeeded { span }) + } else { + *self.0.borrow_mut() = TypeBinding::Bound(binding); + Ok(()) + } + } + + pub fn unbind(&self, id: TypeVariableId) { + *self.0.borrow_mut() = TypeBinding::Unbound(id); + } + + pub fn borrow(&self) -> std::cell::Ref { + self.0.borrow() + } +} /// TypeBindings are the mutable insides of a TypeVariable. /// They are either bound to some type, or are unbound. @@ -427,24 +466,6 @@ impl TypeBinding { pub fn is_unbound(&self) -> bool { matches!(self, TypeBinding::Unbound(_)) } - - pub fn bind_to(&mut self, binding: Type, span: Span) -> Result<(), TypeCheckError> { - match self { - TypeBinding::Bound(_) => panic!("Tried to bind an already bound type variable!"), - TypeBinding::Unbound(id) => { - if binding.occurs(*id) { - Err(TypeCheckError::TypeAnnotationsNeeded { span }) - } else { - *self = TypeBinding::Bound(binding); - Ok(()) - } - } - } - } - - pub fn unbind(&mut self, id: TypeVariableId) { - *self = TypeBinding::Unbound(id); - } } /// A unique ID used to differentiate different type variables @@ -461,7 +482,8 @@ impl Type { } pub fn type_variable(id: TypeVariableId) -> Type { - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), TypeVariableKind::Normal) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, TypeVariableKind::Normal) } /// Returns a TypeVariable(_, TypeVariableKind::Constant(length)) to bind to @@ -469,13 +491,15 @@ impl Type { pub fn constant_variable(length: u64, interner: &mut NodeInterner) -> Type { let id = interner.next_type_variable_id(); let kind = TypeVariableKind::Constant(length); - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), kind) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, kind) } pub fn polymorphic_integer(interner: &mut NodeInterner) -> Type { let id = interner.next_type_variable_id(); let kind = TypeVariableKind::IntegerOrField; - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), kind) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, kind) } /// A bit of an awkward name for this function - this function returns @@ -484,7 +508,7 @@ impl Type { /// they shouldn't be bound over until monomorphization. pub fn is_bindable(&self) -> bool { match self { - Type::TypeVariable(binding, _) => match &*binding.borrow() { + Type::TypeVariable(binding, _) => match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.is_bindable(), TypeBinding::Unbound(_) => true, }, @@ -508,7 +532,7 @@ impl Type { // True if the given type is a NamedGeneric with the target_id let named_generic_id_matches_target = |typ: &Type| { if let Type::NamedGeneric(type_variable, _) = typ { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(_) => { unreachable!("Named generics should not be bound until monomorphization") } @@ -608,7 +632,7 @@ impl Type { match self { Type::Forall(generics, _) => generics.len(), Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(binding) => binding.generic_count(), TypeBinding::Unbound(_) => 0, } @@ -661,23 +685,23 @@ impl std::fmt::Display for Type { Signedness::Signed => write!(f, "i{num_bits}"), Signedness::Unsigned => write!(f, "u{num_bits}"), }, - Type::TypeVariable(id, TypeVariableKind::Normal) => write!(f, "{}", id.borrow()), + Type::TypeVariable(var, TypeVariableKind::Normal) => write!(f, "{}", var.0.borrow()), Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { - if let TypeBinding::Unbound(_) = &*binding.borrow() { + if let TypeBinding::Unbound(_) = &*binding.0.borrow() { // Show a Field by default if this TypeVariableKind::IntegerOrField is unbound, since that is // what they bind to by default anyway. It is less confusing than displaying it // as a generic. write!(f, "Field") } else { - write!(f, "{}", binding.borrow()) + write!(f, "{}", binding.0.borrow()) } } Type::TypeVariable(binding, TypeVariableKind::Constant(n)) => { - if let TypeBinding::Unbound(_) = &*binding.borrow() { + if let TypeBinding::Unbound(_) = &*binding.0.borrow() { // TypeVariableKind::Constant(n) binds to Type::Constant(n) by default, so just show that. write!(f, "{n}") } else { - write!(f, "{}", binding.borrow()) + write!(f, "{}", binding.0.borrow()) } } Type::Struct(s, args) => { @@ -702,7 +726,7 @@ impl std::fmt::Display for Type { } Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), - Type::NamedGeneric(binding, name) => match &*binding.borrow() { + Type::NamedGeneric(binding, name) => match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.fmt(f), TypeBinding::Unbound(_) if name.is_empty() => write!(f, "_"), TypeBinding::Unbound(_) => write!(f, "{name}"), @@ -761,58 +785,65 @@ pub struct UnificationError; impl Type { /// Try to bind a MaybeConstant variable to self, succeeding if self is a Constant, - /// MaybeConstant, or type variable. - pub fn try_bind_to_maybe_constant( + /// MaybeConstant, or type variable. If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + fn try_bind_to_maybe_constant( &self, var: &TypeVariable, target_length: u64, + bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - match self { - Type::Constant(length) if *length == target_length => { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + match self.substitute(bindings) { + Type::Constant(length) if length == target_length => { + assert!(!self.occurs(target_id)); + bindings.insert(target_id, (var.clone(), self.clone())); Ok(()) } Type::NotConstant => { - *var.borrow_mut() = TypeBinding::Bound(Type::NotConstant); + assert!(!self.occurs(target_id)); + bindings.insert(target_id, (var.clone(), Type::NotConstant)); Ok(()) } - Type::TypeVariable(binding, kind) => { - let borrow = binding.borrow(); + // A TypeVariable is less specific than a MaybeConstant, so we bind + // to the other type variable instead. + Type::TypeVariable(new_var, kind) => { + let borrow = new_var.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_maybe_constant(var, target_length), + TypeBinding::Bound(typ) => { + typ.try_bind_to_maybe_constant(var, target_length, bindings) + } // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), - TypeBinding::Unbound(_) => match kind { + TypeBinding::Unbound(new_target_id) => match kind { TypeVariableKind::Normal => { - drop(borrow); let clone = Type::TypeVariable( var.clone(), TypeVariableKind::Constant(target_length), ); - *binding.borrow_mut() = TypeBinding::Bound(clone); + assert!(!clone.occurs(*new_target_id)); + bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } - TypeVariableKind::Constant(length) if *length == target_length => { - drop(borrow); + TypeVariableKind::Constant(length) if length == target_length => { let clone = Type::TypeVariable( var.clone(), TypeVariableKind::Constant(target_length), ); - *binding.borrow_mut() = TypeBinding::Bound(clone); + assert!(!clone.occurs(*new_target_id)); + bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } // The lengths don't match, but neither are set in stone so we can // just set them both to NotConstant. See issue 2370 TypeVariableKind::Constant(_) => { // *length != target_length - drop(borrow); - *var.borrow_mut() = TypeBinding::Bound(Type::NotConstant); - *binding.borrow_mut() = TypeBinding::Bound(Type::NotConstant); + bindings.insert(target_id, (var.clone(), Type::NotConstant)); + bindings.insert(*new_target_id, (new_var.clone(), Type::NotConstant)); Ok(()) } TypeVariableKind::IntegerOrField => Err(UnificationError), @@ -824,44 +855,50 @@ impl Type { } /// Try to bind a PolymorphicInt variable to self, succeeding if self is an integer, field, - /// other PolymorphicInt type, or type variable. - pub fn try_bind_to_polymorphic_int(&self, var: &TypeVariable) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + /// other PolymorphicInt type, or type variable. If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + pub fn try_bind_to_polymorphic_int( + &self, + var: &TypeVariable, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - match self { + match self.substitute(bindings) { Type::FieldElement | Type::Integer(..) => { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + assert!(!self.occurs(target_id)); + bindings.insert(target_id, (var.clone(), self.clone())); Ok(()) } Type::TypeVariable(self_var, TypeVariableKind::IntegerOrField) => { - let borrow = self_var.borrow(); + let borrow = self_var.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var), + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var, bindings), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(_) => { - drop(borrow); - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + assert!(!self.occurs(target_id)); + bindings.insert(target_id, (var.clone(), self.clone())); Ok(()) } } } Type::TypeVariable(binding, TypeVariableKind::Normal) => { - let borrow = binding.borrow(); + let borrow = binding.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var), + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var, bindings), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), - TypeBinding::Unbound(_) => { - drop(borrow); - // PolymorphicInt is more specific than TypeVariable so we bind the type - // variable to PolymorphicInt instead. + TypeBinding::Unbound(new_target_id) => { + // IntegerOrField is more specific than TypeVariable so we bind the type + // variable to IntegerOrField instead. let clone = Type::TypeVariable(var.clone(), TypeVariableKind::IntegerOrField); - *binding.borrow_mut() = TypeBinding::Bound(clone); + assert!(!clone.occurs(*new_target_id)); + bindings.insert(*new_target_id, (binding.clone(), clone)); Ok(()) } } @@ -870,102 +907,114 @@ impl Type { } } - pub fn try_bind_to(&self, var: &TypeVariable) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + /// Try to bind the given type variable to self. Although the given type variable + /// is expected to be of TypeVariableKind::Normal, this binding can still fail + /// if the given type variable occurs within `self` as that would create a recursive type. + /// + /// If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + fn try_bind_to( + &self, + var: &TypeVariable, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - if let Some(binding) = self.get_inner_type_variable() { + let this = self.substitute(bindings); + + if let Some(binding) = this.get_inner_type_variable() { match &*binding.borrow() { - TypeBinding::Bound(typ) => return typ.try_bind_to(var), + TypeBinding::Bound(typ) => return typ.try_bind_to(var, bindings), // Don't recursively bind the same id to itself TypeBinding::Unbound(id) if *id == target_id => return Ok(()), _ => (), } } - // Check if the target id occurs within self before binding. Otherwise this could + // Check if the target id occurs within `this` before binding. Otherwise this could // cause infinitely recursive types - if self.occurs(target_id) { + if this.occurs(target_id) { Err(UnificationError) } else { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + assert!(!this.occurs(target_id)); + bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } } fn get_inner_type_variable(&self) -> Option> { match self { - Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.clone()), + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.0.clone()), _ => None, } } /// Try to unify this type with another, setting any type variables found - /// equal to the other type in the process. Unification is more strict - /// than sub-typing but less strict than Eq. Returns true if the unification - /// succeeded. Note that any bindings performed in a failed unification are - /// not undone. This may cause further type errors later on. + /// equal to the other type in the process. When comparing types, unification + /// (including try_unify) are almost always preferred over Type::eq as unification + /// will correctly handle generic types. pub fn unify( &self, expected: &Type, errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, ) { - if let Err(UnificationError) = self.try_unify(expected) { - errors.push(make_error()); + let mut bindings = TypeBindings::new(); + + match self.try_unify(expected, &mut bindings) { + Ok(()) => { + // Commit any type bindings on success + Self::apply_type_bindings(bindings); + } + Err(UnificationError) => errors.push(make_error()), } } /// `try_unify` is a bit of a misnomer since although errors are not committed, /// any unified bindings are on success. - pub fn try_unify(&self, other: &Type) -> Result<(), UnificationError> { + pub fn try_unify( + &self, + other: &Type, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { use Type::*; use TypeVariableKind as Kind; match (self, other) { (Error, _) | (_, Error) => Ok(()), - (TypeVariable(binding, Kind::IntegerOrField), other) - | (other, TypeVariable(binding, Kind::IntegerOrField)) => { - // If it is already bound, unify against what it is bound to - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - // Otherwise, check it is unified against an integer and bind it - other.try_bind_to_polymorphic_int(binding) + (TypeVariable(var, Kind::IntegerOrField), other) + | (other, TypeVariable(var, Kind::IntegerOrField)) => { + other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to_polymorphic_int(var, bindings) + }) } - (TypeVariable(binding, Kind::Normal), other) - | (other, TypeVariable(binding, Kind::Normal)) => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - other.try_bind_to(binding) + (TypeVariable(var, Kind::Normal), other) | (other, TypeVariable(var, Kind::Normal)) => { + other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to(var, bindings) + }) } - (TypeVariable(binding, Kind::Constant(length)), other) - | (other, TypeVariable(binding, Kind::Constant(length))) => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - other.try_bind_to_maybe_constant(binding, *length) - } + (TypeVariable(var, Kind::Constant(length)), other) + | (other, TypeVariable(var, Kind::Constant(length))) => other + .try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to_maybe_constant(var, *length, bindings) + }), (Array(len_a, elem_a), Array(len_b, elem_b)) => { - len_a.try_unify(len_b)?; - elem_a.try_unify(elem_b) + len_a.try_unify(len_b, bindings)?; + elem_a.try_unify(elem_b, bindings) } - (String(len_a), String(len_b)) => len_a.try_unify(len_b), + (String(len_a), String(len_b)) => len_a.try_unify(len_b, bindings), (FmtString(len_a, elements_a), FmtString(len_b, elements_b)) => { - len_a.try_unify(len_b)?; - elements_a.try_unify(elements_b) + len_a.try_unify(len_b, bindings)?; + elements_a.try_unify(elements_b, bindings) } (Tuple(elements_a), Tuple(elements_b)) => { @@ -973,7 +1022,7 @@ impl Type { Err(UnificationError) } else { for (a, b) in elements_a.iter().zip(elements_b) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } Ok(()) } @@ -985,7 +1034,7 @@ impl Type { (Struct(id_a, args_a), Struct(id_b, args_b)) => { if id_a == id_b && args_a.len() == args_b.len() { for (a, b) in args_a.iter().zip(args_b) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } Ok(()) } else { @@ -993,17 +1042,17 @@ impl Type { } } - (NamedGeneric(binding, _), other) if !binding.borrow().is_unbound() => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - link.try_unify(other) + (NamedGeneric(binding, _), other) if !binding.0.borrow().is_unbound() => { + if let TypeBinding::Bound(link) = &*binding.0.borrow() { + link.try_unify(other, bindings) } else { unreachable!("If guard ensures binding is bound") } } - (other, NamedGeneric(binding, _)) if !binding.borrow().is_unbound() => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - other.try_unify(link) + (other, NamedGeneric(binding, _)) if !binding.0.borrow().is_unbound() => { + if let TypeBinding::Bound(link) = &*binding.0.borrow() { + other.try_unify(link, bindings) } else { unreachable!("If guard ensures binding is bound") } @@ -1011,8 +1060,8 @@ impl Type { (NamedGeneric(binding_a, name_a), NamedGeneric(binding_b, name_b)) => { // Unbound NamedGenerics are caught by the checks above - assert!(binding_a.borrow().is_unbound()); - assert!(binding_b.borrow().is_unbound()); + assert!(binding_a.0.borrow().is_unbound()); + assert!(binding_b.0.borrow().is_unbound()); if name_a == name_b { Ok(()) @@ -1024,17 +1073,19 @@ impl Type { (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b.iter()) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } - env_a.try_unify(env_b)?; - ret_b.try_unify(ret_a) + env_a.try_unify(env_b, bindings)?; + ret_b.try_unify(ret_a, bindings) } else { Err(UnificationError) } } - (MutableReference(elem_a), MutableReference(elem_b)) => elem_a.try_unify(elem_b), + (MutableReference(elem_a), MutableReference(elem_b)) => { + elem_a.try_unify(elem_b, bindings) + } (other_a, other_b) => { if other_a == other_b { @@ -1046,6 +1097,34 @@ impl Type { } } + /// Try to unify a type variable to `self`. + /// This is a helper function factored out from try_unify. + fn try_unify_to_type_variable( + &self, + type_variable: &TypeVariable, + bindings: &mut TypeBindings, + + // Bind the type variable to a type. This is factored out since depending on the + // TypeVariableKind, there are different methods to check whether the variable can + // bind to the given type or not. + bind_variable: impl FnOnce(&mut TypeBindings) -> Result<(), UnificationError>, + ) -> Result<(), UnificationError> { + match &*type_variable.0.borrow() { + // If it is already bound, unify against what it is bound to + TypeBinding::Bound(link) => link.try_unify(self, bindings), + TypeBinding::Unbound(id) => { + // We may have already "bound" this type variable in this call to + // try_unify, so check those bindings as well. + match bindings.get(id) { + Some((_, binding)) => binding.clone().try_unify(self, bindings), + + // Otherwise, bind it + None => bind_variable(bindings), + } + } + } + } + /// Similar to `unify` but if the check fails this will attempt to coerce the /// argument to the target type. When this happens, the given expression is wrapped in /// a new expression to convert its type. E.g. `array` -> `array.as_slice()` @@ -1059,7 +1138,9 @@ impl Type { errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, ) { - if let Err(UnificationError) = self.try_unify(expected) { + let mut bindings = TypeBindings::new(); + + if let Err(UnificationError) = self.try_unify(expected, &mut bindings) { if !self.try_array_to_slice_coercion(expected, expression, interner) { errors.push(make_error()); } @@ -1085,8 +1166,10 @@ impl Type { if matches!(size1, Type::Constant(_)) && matches!(size2, Type::NotConstant) { // Still have to ensure the element types match. // Don't need to issue an error here if not, it will be done in unify_with_coercions - if element1.try_unify(element2).is_ok() { + let mut bindings = TypeBindings::new(); + if element1.try_unify(element2, &mut bindings).is_ok() { convert_array_expression_to_slice(expression, this, target, interner); + Self::apply_type_bindings(bindings); return true; } } @@ -1094,6 +1177,15 @@ impl Type { false } + /// Apply the given type bindings, making them permanently visible for each + /// clone of each type variable bound. + pub fn apply_type_bindings(bindings: TypeBindings) { + for (id, (type_variable, binding)) in &bindings { + assert!(!binding.occurs(*id)); + type_variable.bind(binding.clone()); + } + } + /// If this type is a Type::Constant (used in array lengths), or is bound /// to a Type::Constant, return the constant as a u64. pub fn evaluate_to_u64(&self) -> Option { @@ -1229,7 +1321,7 @@ impl Type { } Type::Forall(_, typ) => typ.find_all_unbound_type_variables(type_variables), Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(binding) => { binding.find_all_unbound_type_variables(type_variables); } @@ -1251,7 +1343,7 @@ impl Type { return self.clone(); } - let substitute_binding = |binding: &TypeVariable| match &*binding.borrow() { + let substitute_binding = |binding: &TypeVariable| match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.substitute(type_bindings), TypeBinding::Unbound(id) => match type_bindings.get(id) { Some((_, binding)) => binding.clone(), @@ -1331,7 +1423,7 @@ impl Type { Type::Struct(_, generic_args) => generic_args.iter().any(|arg| arg.occurs(target_id)), Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { - match &*binding.borrow() { + match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.occurs(target_id), TypeBinding::Unbound(id) => *id == target_id, } @@ -1380,7 +1472,7 @@ impl Type { } Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())), TypeVariable(var, _) | NamedGeneric(var, _) => { - if let TypeBinding::Bound(typ) = &*var.borrow() { + if let TypeBinding::Bound(typ) = &*var.0.borrow() { return typ.follow_bindings(); } self.clone() @@ -1485,7 +1577,7 @@ impl From<&Type> for PrintableType { Signedness::Signed => PrintableType::SignedInteger { width: *bit_width }, }, Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { - match &*binding.borrow() { + match &*binding.0.borrow() { TypeBinding::Bound(typ) => typ.into(), TypeBinding::Unbound(_) => Type::default_int_type().into(), } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 57e4e6cdeb0..248dd533ec3 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -215,7 +215,7 @@ impl<'interner> Monomorphizer<'interner> { fn function(&mut self, f: node_interner::FuncId, id: FuncId) { if let Some((self_type, trait_id)) = self.interner.get_function_trait(&f) { let the_trait = self.interner.get_trait(trait_id); - *the_trait.self_type_typevar.borrow_mut() = TypeBinding::Bound(self_type); + the_trait.self_type_typevar.bind(self_type); } let meta = self.interner.function_meta(&f); @@ -716,7 +716,7 @@ impl<'interner> Monomorphizer<'interner> { // like automatic solving of traits. It should be fine since it is strictly // after type checking, but care should be taken that it doesn't change which // impls are chosen. - *binding.borrow_mut() = TypeBinding::Bound(HirType::default_int_type()); + binding.bind(HirType::default_int_type()); ast::Type::Field } @@ -740,7 +740,7 @@ impl<'interner> Monomorphizer<'interner> { }; let monomorphized_default = self.convert_type(&default); - *binding.borrow_mut() = TypeBinding::Bound(default); + binding.bind(default); monomorphized_default } @@ -830,6 +830,13 @@ impl<'interner> Monomorphizer<'interner> { node_interner::TraitImplKind::Assumed { object_type } => { match self.interner.lookup_trait_implementation(&object_type, method.trait_id) { Ok(TraitImplKind::Normal(impl_id)) => { + let id = self.interner.get_trait_implementation(impl_id).borrow().methods + [method.method_index]; + + let name = self.interner.function_name(&id); + + println!("Looked up assumed trait impl for {}::{name}", object_type); + self.interner.get_trait_implementation(impl_id).borrow().methods [method.method_index] } @@ -858,6 +865,8 @@ impl<'interner> Monomorphizer<'interner> { let the_trait = self.interner.get_trait(method.trait_id); + println!("Function type = {:?} => {}", function_type, self.convert_type(&function_type)); + ast::Expression::Ident(ast::Ident { definition: Definition::Function(func_id), mutable: false, @@ -1430,12 +1439,19 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, binding) in bindings.values() { - *var.borrow_mut() = TypeBinding::Bound(binding.clone()); + let id = match &*var.borrow() { + TypeBinding::Bound(original) => { + panic!("Binding over already bound type! {} <- {}", original, binding) + } + TypeBinding::Unbound(id) => *id, + }; + println!(" Applying {:?} <- {}", id, binding); + var.bind(binding.clone()); } } fn undo_instantiation_bindings(bindings: TypeBindings) { for (id, (var, _)) in bindings { - *var.borrow_mut() = TypeBinding::Unbound(id); + var.unbind(id); } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index e66a6d57605..3871e692fd7 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -21,7 +21,7 @@ use crate::hir_def::{ use crate::token::{Attributes, SecondaryAttribute}; use crate::{ ContractFunctionType, FunctionDefinition, FunctionVisibility, Generics, Shared, TypeAliasType, - TypeBinding, TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, + TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, }; /// An arbitrary number to limit the recursion depth when searching for trait impls. @@ -463,7 +463,6 @@ impl NodeInterner { pub fn push_empty_trait(&mut self, type_id: TraitId, typ: &UnresolvedTrait) { let self_type_typevar_id = self.next_type_variable_id(); - let self_type_typevar = Shared::new(TypeBinding::Unbound(self_type_typevar_id)); self.traits.insert( type_id, @@ -478,10 +477,10 @@ impl NodeInterner { // can refer to it with generic arguments before the generic parameters themselves // are resolved. let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }), self_type_typevar_id, - self_type_typevar, + TypeVariable::unbound(self_type_typevar_id), ), ); } @@ -503,7 +502,7 @@ impl NodeInterner { // can refer to it with generic arguments before the generic parameters themselves // are resolved. let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }); let new_struct = StructType::new(struct_id, name, typ.struct_def.span, no_fields, generics); @@ -522,7 +521,7 @@ impl NodeInterner { Type::Error, vecmap(&typ.type_alias_def.generics, |_| { let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }), )); @@ -973,13 +972,32 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, ) -> Result> { - self.lookup_trait_implementation_helper(object_type, trait_id, IMPL_SEARCH_RECURSION_LIMIT) + let (impl_kind, bindings) = self.try_lookup_trait_implementation(object_type, trait_id)?; + Type::apply_type_bindings(bindings); + Ok(impl_kind) + } + + /// Similar to `lookup_trait_implementation` but does not apply any type bindings on success. + pub fn try_lookup_trait_implementation( + &self, + object_type: &Type, + trait_id: TraitId, + ) -> Result<(TraitImplKind, TypeBindings), Vec> { + let mut bindings = TypeBindings::new(); + let impl_kind = self.lookup_trait_implementation_helper( + object_type, + trait_id, + &mut bindings, + IMPL_SEARCH_RECURSION_LIMIT, + )?; + Ok((impl_kind, bindings)) } fn lookup_trait_implementation_helper( &self, object_type: &Type, trait_id: TraitId, + type_bindings: &mut TypeBindings, recursion_limit: u32, ) -> Result> { let make_constraint = || TraitConstraint::new(object_type.clone(), trait_id); @@ -993,16 +1011,22 @@ impl NodeInterner { self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; for (existing_object_type, impl_kind) in impls { - let (existing_object_type, type_bindings) = existing_object_type.instantiate(self); + let (existing_object_type, instantiation_bindings) = + existing_object_type.instantiate(self); + let mut fresh_bindings = TypeBindings::new(); + + if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { + // The unification was successful so we can append fresh_bindings to our bindings list + type_bindings.extend(instantiation_bindings); + type_bindings.extend(fresh_bindings); - if object_type.try_unify(&existing_object_type).is_ok() { if let TraitImplKind::Normal(impl_id) = impl_kind { let trait_impl = self.get_trait_implementation(*impl_id); let trait_impl = trait_impl.borrow(); if let Err(mut errors) = self.validate_where_clause( &trait_impl.where_clause, - &type_bindings, + type_bindings, recursion_limit, ) { errors.push(make_constraint()); @@ -1022,7 +1046,7 @@ impl NodeInterner { fn validate_where_clause( &self, where_clause: &[TraitConstraint], - type_bindings: &TypeBindings, + type_bindings: &mut TypeBindings, recursion_limit: u32, ) -> Result<(), Vec> { for constraint in where_clause { @@ -1030,6 +1054,7 @@ impl NodeInterner { self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, + type_bindings, recursion_limit - 1, )?; } @@ -1050,7 +1075,7 @@ impl NodeInterner { trait_id: TraitId, ) -> bool { // Make sure there are no overlapping impls - if self.lookup_trait_implementation(&object_type, trait_id).is_ok() { + if self.try_lookup_trait_implementation(&object_type, trait_id).is_ok() { return false; } @@ -1078,8 +1103,8 @@ impl NodeInterner { let (instantiated_object_type, substitutions) = object_type.instantiate_type_variables(self); - if let Ok(TraitImplKind::Normal(existing)) = - self.lookup_trait_implementation(&instantiated_object_type, trait_id) + if let Ok((TraitImplKind::Normal(existing), _)) = + self.try_lookup_trait_implementation(&instantiated_object_type, trait_id) { let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); @@ -1227,8 +1252,10 @@ impl Methods { match interner.function_meta(&method).typ.instantiate(interner).0 { Type::Function(args, _, _) => { if let Some(object) = args.get(0) { - // TODO #3089: This is dangerous! try_unify may commit type bindings even on failure - if object.try_unify(typ).is_ok() { + let mut bindings = TypeBindings::new(); + + if object.try_unify(typ, &mut bindings).is_ok() { + Type::apply_type_bindings(bindings); return Some(method); } } From 5608c64e827ef252d1b426eb51d6f465b4e0b25a Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Fri, 1 Dec 2023 15:57:23 -0600 Subject: [PATCH 02/11] Now there is a different error in protocol-circuits --- compiler/noirc_frontend/src/hir_def/types.rs | 34 +++++++++++++------- compiler/noirc_frontend/src/node_interner.rs | 3 ++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 8c36132b584..f3371488a51 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -423,7 +423,13 @@ impl TypeVariable { pub fn bind(&self, typ: Type) { let id = match &*self.0.borrow() { - TypeBinding::Bound(binding) => unreachable!("Expected unbound, found bound to {binding}"), + TypeBinding::Bound(binding) => { + if *binding == typ { + return; + } else { + unreachable!("TypeVariable::bind, cannot bind bound var {} to {}", binding, typ); + } + } TypeBinding::Unbound(id) => *id, }; @@ -798,14 +804,16 @@ impl Type { TypeBinding::Unbound(id) => *id, }; - match self.substitute(bindings) { - Type::Constant(length) if length == target_length => { - assert!(!self.occurs(target_id)); - bindings.insert(target_id, (var.clone(), self.clone())); + let this = self.substitute(bindings); + + match &this { + Type::Constant(length) if *length == target_length => { + assert!(!this.occurs(target_id)); + bindings.insert(target_id, (var.clone(), this)); Ok(()) } Type::NotConstant => { - assert!(!self.occurs(target_id)); + assert!(!this.occurs(target_id)); bindings.insert(target_id, (var.clone(), Type::NotConstant)); Ok(()) } @@ -829,7 +837,7 @@ impl Type { bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } - TypeVariableKind::Constant(length) if length == target_length => { + TypeVariableKind::Constant(length) if *length == target_length => { let clone = Type::TypeVariable( var.clone(), TypeVariableKind::Constant(target_length), @@ -867,10 +875,12 @@ impl Type { TypeBinding::Unbound(id) => *id, }; - match self.substitute(bindings) { + let this = self.substitute(bindings); + + match &this { Type::FieldElement | Type::Integer(..) => { - assert!(!self.occurs(target_id)); - bindings.insert(target_id, (var.clone(), self.clone())); + assert!(!this.occurs(target_id)); + bindings.insert(target_id, (var.clone(), this)); Ok(()) } Type::TypeVariable(self_var, TypeVariableKind::IntegerOrField) => { @@ -880,8 +890,8 @@ impl Type { // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(_) => { - assert!(!self.occurs(target_id)); - bindings.insert(target_id, (var.clone(), self.clone())); + assert!(!this.occurs(target_id)); + bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 3871e692fd7..ac279b392aa 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1007,6 +1007,8 @@ impl NodeInterner { return Err(vec![make_constraint()]); } + let object_type = object_type.substitute(type_bindings); + let impls = self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; @@ -1051,6 +1053,7 @@ impl NodeInterner { ) -> Result<(), Vec> { for constraint in where_clause { let constraint_type = constraint.typ.substitute(type_bindings); + self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, From 46c707767159cf2d354138212bbddc375c6e97ea Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 4 Dec 2023 13:28:13 -0600 Subject: [PATCH 03/11] Fix missed apply_type_bindings --- compiler/noirc_frontend/src/hir_def/types.rs | 2 ++ .../noirc_frontend/src/monomorphization/mod.rs | 16 ---------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index f3371488a51..554ad55bb0a 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1154,6 +1154,8 @@ impl Type { if !self.try_array_to_slice_coercion(expected, expression, interner) { errors.push(make_error()); } + } else { + Type::apply_type_bindings(bindings); } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 248dd533ec3..dc5c1bcb314 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -830,13 +830,6 @@ impl<'interner> Monomorphizer<'interner> { node_interner::TraitImplKind::Assumed { object_type } => { match self.interner.lookup_trait_implementation(&object_type, method.trait_id) { Ok(TraitImplKind::Normal(impl_id)) => { - let id = self.interner.get_trait_implementation(impl_id).borrow().methods - [method.method_index]; - - let name = self.interner.function_name(&id); - - println!("Looked up assumed trait impl for {}::{name}", object_type); - self.interner.get_trait_implementation(impl_id).borrow().methods [method.method_index] } @@ -865,8 +858,6 @@ impl<'interner> Monomorphizer<'interner> { let the_trait = self.interner.get_trait(method.trait_id); - println!("Function type = {:?} => {}", function_type, self.convert_type(&function_type)); - ast::Expression::Ident(ast::Ident { definition: Definition::Function(func_id), mutable: false, @@ -1439,13 +1430,6 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, binding) in bindings.values() { - let id = match &*var.borrow() { - TypeBinding::Bound(original) => { - panic!("Binding over already bound type! {} <- {}", original, binding) - } - TypeBinding::Unbound(id) => *id, - }; - println!(" Applying {:?} <- {}", id, binding); var.bind(binding.clone()); } } From 8e7f90fa5c345fe4779767d68e973947c8bd7b92 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Mon, 4 Dec 2023 15:40:28 -0600 Subject: [PATCH 04/11] Trying to debug lookup_trait_implementation bindings --- compiler/noirc_frontend/src/hir_def/types.rs | 3 +++ compiler/noirc_frontend/src/node_interner.rs | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 554ad55bb0a..d2acd787bb2 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -950,6 +950,7 @@ impl Type { Err(UnificationError) } else { assert!(!this.occurs(target_id)); + println!("Binding {:?} {:?} <- {}", target_id, var, this); bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } @@ -1192,7 +1193,9 @@ impl Type { /// Apply the given type bindings, making them permanently visible for each /// clone of each type variable bound. pub fn apply_type_bindings(bindings: TypeBindings) { + // println!("apply_type_bindings {} bindings", bindings.len()); for (id, (type_variable, binding)) in &bindings { + // println!(" {:?} {:?} <- {:?}", id, type_variable, binding); assert!(!binding.occurs(*id)); type_variable.bind(binding.clone()); } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index ac279b392aa..077855f40ec 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1013,14 +1013,22 @@ impl NodeInterner { self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; for (existing_object_type, impl_kind) in impls { - let (existing_object_type, instantiation_bindings) = + let (existing_object_type, mut instantiation_bindings) = existing_object_type.instantiate(self); + let mut fresh_bindings = TypeBindings::new(); + println!("try_unify {:?}", object_type); + println!(" == {:?}", existing_object_type); + if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { + println!("fresh_bindings len = {}", fresh_bindings.len()); + println!("insta_bindings len = {}", instantiation_bindings.len()); + println!(""); + // The unification was successful so we can append fresh_bindings to our bindings list - type_bindings.extend(instantiation_bindings); type_bindings.extend(fresh_bindings); + type_bindings.extend(instantiation_bindings.clone()); if let TraitImplKind::Normal(impl_id) = impl_kind { let trait_impl = self.get_trait_implementation(*impl_id); @@ -1029,6 +1037,7 @@ impl NodeInterner { if let Err(mut errors) = self.validate_where_clause( &trait_impl.where_clause, type_bindings, + &mut instantiation_bindings, recursion_limit, ) { errors.push(make_constraint()); @@ -1037,6 +1046,8 @@ impl NodeInterner { } return Ok(impl_kind.clone()); + } else { + println!("try_unify failed"); } } @@ -1049,10 +1060,11 @@ impl NodeInterner { &self, where_clause: &[TraitConstraint], type_bindings: &mut TypeBindings, + type_bindings2: &mut TypeBindings, recursion_limit: u32, ) -> Result<(), Vec> { for constraint in where_clause { - let constraint_type = constraint.typ.substitute(type_bindings); + let constraint_type = constraint.typ.substitute(type_bindings2); self.lookup_trait_implementation_helper( &constraint_type, From b12c7c4911bc8e488d33dd1cc4abd7febadc44ec Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 5 Dec 2023 13:27:27 -0600 Subject: [PATCH 05/11] Fix remainder of bugs --- .../src/hir/def_collector/dc_crate.rs | 4 +- .../src/hir/resolution/resolver.rs | 5 +- compiler/noirc_frontend/src/hir_def/types.rs | 57 ++++++++++--------- .../src/monomorphization/mod.rs | 12 +--- compiler/noirc_frontend/src/node_interner.rs | 21 ++----- 5 files changed, 42 insertions(+), 57 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 6a56ebc5c06..01ed0f72a73 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -415,7 +415,9 @@ pub(crate) fn check_methods_signatures( let self_type = resolver.get_self_type().expect("trait impl must have a Self type"); // Temporarily bind the trait's Self type to self_type so we can type check - the_trait.self_type_typevar.try_bind(self_type.clone(), the_trait.span) + the_trait + .self_type_typevar + .try_bind(self_type.clone(), the_trait.span) .expect("Failed binding Self type of trait"); for (file_id, func_id) in impl_methods { diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 2ac6f817720..7a776ca411d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -927,10 +927,7 @@ impl<'a> Resolver<'a> { found.into_iter().collect() } - fn find_numeric_generics_in_type( - typ: &Type, - found: &mut BTreeMap, - ) { + fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { match typ { Type::FieldElement | Type::Integer(_, _) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index d2acd787bb2..c7b3fcc499c 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -421,13 +421,21 @@ impl TypeVariable { TypeVariable(Shared::new(TypeBinding::Unbound(id))) } + /// Bind this type variable to a value. + /// + /// Panics if this TypeVariable is already Bound. + /// Also Panics if the ID of this TypeVariable occurs within the given + /// binding, as that would cause an infinitely recursive type. pub fn bind(&self, typ: Type) { let id = match &*self.0.borrow() { TypeBinding::Bound(binding) => { if *binding == typ { return; } else { - unreachable!("TypeVariable::bind, cannot bind bound var {} to {}", binding, typ); + unreachable!( + "TypeVariable::bind, cannot bind bound var {} to {}", + binding, typ + ); } } TypeBinding::Unbound(id) => *id, @@ -439,7 +447,9 @@ impl TypeVariable { pub fn try_bind(&self, binding: Type, span: Span) -> Result<(), TypeCheckError> { let id = match &*self.0.borrow() { - TypeBinding::Bound(binding) => unreachable!("Expected unbound, found bound to {binding}"), + TypeBinding::Bound(binding) => { + unreachable!("Expected unbound, found bound to {binding}") + } TypeBinding::Unbound(id) => *id, }; @@ -451,12 +461,23 @@ impl TypeVariable { } } + /// Borrows this TypeVariable to (e.g.) manually match on the inner TypeBinding. + pub fn borrow(&self) -> std::cell::Ref { + self.0.borrow() + } + + /// Unbind this type variable, setting it to Unbound(id). + /// + /// This is generally a logic error to use outside of monomorphization. pub fn unbind(&self, id: TypeVariableId) { *self.0.borrow_mut() = TypeBinding::Unbound(id); } - pub fn borrow(&self) -> std::cell::Ref { - self.0.borrow() + /// Forcibly bind a type variable to a new type - even if the type + /// variable is already bound to a different type. This generally + /// a logic error to use outside of monomorphization. + pub fn force_bind(&self, typ: Type) { + *self.0.borrow_mut() = TypeBinding::Bound(typ); } } @@ -808,12 +829,10 @@ impl Type { match &this { Type::Constant(length) if *length == target_length => { - assert!(!this.occurs(target_id)); bindings.insert(target_id, (var.clone(), this)); Ok(()) } Type::NotConstant => { - assert!(!this.occurs(target_id)); bindings.insert(target_id, (var.clone(), Type::NotConstant)); Ok(()) } @@ -833,7 +852,6 @@ impl Type { var.clone(), TypeVariableKind::Constant(target_length), ); - assert!(!clone.occurs(*new_target_id)); bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } @@ -842,7 +860,6 @@ impl Type { var.clone(), TypeVariableKind::Constant(target_length), ); - assert!(!clone.occurs(*new_target_id)); bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } @@ -879,7 +896,6 @@ impl Type { match &this { Type::FieldElement | Type::Integer(..) => { - assert!(!this.occurs(target_id)); bindings.insert(target_id, (var.clone(), this)); Ok(()) } @@ -890,7 +906,6 @@ impl Type { // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(_) => { - assert!(!this.occurs(target_id)); bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } @@ -907,7 +922,6 @@ impl Type { // variable to IntegerOrField instead. let clone = Type::TypeVariable(var.clone(), TypeVariableKind::IntegerOrField); - assert!(!clone.occurs(*new_target_id)); bindings.insert(*new_target_id, (binding.clone(), clone)); Ok(()) } @@ -949,8 +963,6 @@ impl Type { if this.occurs(target_id) { Err(UnificationError) } else { - assert!(!this.occurs(target_id)); - println!("Binding {:?} {:?} <- {}", target_id, var, this); bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } @@ -1053,7 +1065,9 @@ impl Type { } } - (NamedGeneric(binding, _), other) if !binding.0.borrow().is_unbound() => { + (NamedGeneric(binding, _), other) | (other, NamedGeneric(binding, _)) + if !binding.0.borrow().is_unbound() => + { if let TypeBinding::Bound(link) = &*binding.0.borrow() { link.try_unify(other, bindings) } else { @@ -1061,16 +1075,8 @@ impl Type { } } - (other, NamedGeneric(binding, _)) if !binding.0.borrow().is_unbound() => { - if let TypeBinding::Bound(link) = &*binding.0.borrow() { - other.try_unify(link, bindings) - } else { - unreachable!("If guard ensures binding is bound") - } - } - (NamedGeneric(binding_a, name_a), NamedGeneric(binding_b, name_b)) => { - // Unbound NamedGenerics are caught by the checks above + // Bound NamedGenerics are caught by the check above assert!(binding_a.0.borrow().is_unbound()); assert!(binding_b.0.borrow().is_unbound()); @@ -1193,10 +1199,7 @@ impl Type { /// Apply the given type bindings, making them permanently visible for each /// clone of each type variable bound. pub fn apply_type_bindings(bindings: TypeBindings) { - // println!("apply_type_bindings {} bindings", bindings.len()); - for (id, (type_variable, binding)) in &bindings { - // println!(" {:?} {:?} <- {:?}", id, type_variable, binding); - assert!(!binding.occurs(*id)); + for (type_variable, binding) in bindings.values() { type_variable.bind(binding.clone()); } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index dc5c1bcb314..a566a43ab49 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -215,7 +215,7 @@ impl<'interner> Monomorphizer<'interner> { fn function(&mut self, f: node_interner::FuncId, id: FuncId) { if let Some((self_type, trait_id)) = self.interner.get_function_trait(&f) { let the_trait = self.interner.get_trait(trait_id); - the_trait.self_type_typevar.bind(self_type); + the_trait.self_type_typevar.force_bind(self_type); } let meta = self.interner.function_meta(&f); @@ -712,10 +712,6 @@ impl<'interner> Monomorphizer<'interner> { // Default any remaining unbound type variables. // This should only happen if the variable in question is unused // and within a larger generic type. - // NOTE: Make sure to review this if there is ever type-directed dispatch, - // like automatic solving of traits. It should be fine since it is strictly - // after type checking, but care should be taken that it doesn't change which - // impls are chosen. binding.bind(HirType::default_int_type()); ast::Type::Field } @@ -728,10 +724,6 @@ impl<'interner> Monomorphizer<'interner> { // Default any remaining unbound type variables. // This should only happen if the variable in question is unused // and within a larger generic type. - // NOTE: Make sure to review this if there is ever type-directed dispatch, - // like automatic solving of traits. It should be fine since it is strictly - // after type checking, but care should be taken that it doesn't change which - // impls are chosen. let default = if self.is_range_loop && matches!(kind, TypeVariableKind::IntegerOrField) { Type::default_range_loop_type() @@ -1430,7 +1422,7 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, binding) in bindings.values() { - var.bind(binding.clone()); + var.force_bind(binding.clone()); } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 077855f40ec..48a8379f1d1 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1013,22 +1013,14 @@ impl NodeInterner { self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; for (existing_object_type, impl_kind) in impls { - let (existing_object_type, mut instantiation_bindings) = + let (existing_object_type, instantiation_bindings) = existing_object_type.instantiate(self); let mut fresh_bindings = TypeBindings::new(); - println!("try_unify {:?}", object_type); - println!(" == {:?}", existing_object_type); - if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { - println!("fresh_bindings len = {}", fresh_bindings.len()); - println!("insta_bindings len = {}", instantiation_bindings.len()); - println!(""); - // The unification was successful so we can append fresh_bindings to our bindings list type_bindings.extend(fresh_bindings); - type_bindings.extend(instantiation_bindings.clone()); if let TraitImplKind::Normal(impl_id) = impl_kind { let trait_impl = self.get_trait_implementation(*impl_id); @@ -1037,7 +1029,7 @@ impl NodeInterner { if let Err(mut errors) = self.validate_where_clause( &trait_impl.where_clause, type_bindings, - &mut instantiation_bindings, + &instantiation_bindings, recursion_limit, ) { errors.push(make_constraint()); @@ -1046,8 +1038,6 @@ impl NodeInterner { } return Ok(impl_kind.clone()); - } else { - println!("try_unify failed"); } } @@ -1060,16 +1050,17 @@ impl NodeInterner { &self, where_clause: &[TraitConstraint], type_bindings: &mut TypeBindings, - type_bindings2: &mut TypeBindings, + instantiation_bindings: &TypeBindings, recursion_limit: u32, ) -> Result<(), Vec> { for constraint in where_clause { - let constraint_type = constraint.typ.substitute(type_bindings2); + let constraint_type = constraint.typ.substitute(instantiation_bindings); + let constraint_type = constraint_type.substitute(type_bindings); self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, - type_bindings, + &mut TypeBindings::new(), recursion_limit - 1, )?; } From 034c871cf9e543723e9bdd9f07a8495326b4ab5b Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 5 Dec 2023 13:42:55 -0600 Subject: [PATCH 06/11] Remove unused debug variable --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index b34802f28e5..f16d9b7d447 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -377,8 +377,6 @@ impl<'function> PerFunctionContext<'function> { self.context.call_stack.push_back(location); } - let c = self.context.call_stack.clone(); - let new_results = self.context.inline_function(ssa, function, &arguments); if has_location { From 9aad96341f4ebd81c14efd93e133d1b82bca5959 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 5 Dec 2023 13:44:20 -0600 Subject: [PATCH 07/11] Remove last of inlining debug code --- compiler/noirc_evaluator/src/ssa/opt/inlining.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index f16d9b7d447..b4f12b2f897 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -38,7 +38,8 @@ impl Ssa { /// as well save the work for later instead of performing it twice. pub(crate) fn inline_functions(mut self) -> Ssa { self.functions = btree_map(get_entry_point_functions(&self), |entry_point| { - (entry_point, InlineContext::new(&self, entry_point).inline_all(&self)) + let new_function = InlineContext::new(&self, entry_point).inline_all(&self); + (entry_point, new_function) }); self @@ -405,11 +406,6 @@ impl<'function> PerFunctionContext<'function> { self.context.builder.set_call_stack(call_stack); let new_results = self.context.builder.insert_instruction(instruction, ctrl_typevars); - - if results.len() != new_results.len() { - println!("In function {}", self.source_function.name()); - } - Self::insert_new_instruction_results(&mut self.values, &results, new_results); } From 6c5e3d6495b6cacb94298480bc43573fabd4ac0f Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 5 Dec 2023 13:50:08 -0600 Subject: [PATCH 08/11] Small cleanup --- compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs | 5 +---- compiler/noirc_frontend/src/node_interner.rs | 2 ++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 01ed0f72a73..d6eddeffc07 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -415,10 +415,7 @@ pub(crate) fn check_methods_signatures( let self_type = resolver.get_self_type().expect("trait impl must have a Self type"); // Temporarily bind the trait's Self type to self_type so we can type check - the_trait - .self_type_typevar - .try_bind(self_type.clone(), the_trait.span) - .expect("Failed binding Self type of trait"); + the_trait.self_type_typevar.bind(self_type.clone()); for (file_id, func_id) in impl_methods { let impl_method = resolver.interner.function_meta(func_id); diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index dd0a85e90d0..0be1c93f478 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1087,6 +1087,8 @@ impl NodeInterner { self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, + // Use a fresh set of type bindings here since the constraint_type originates from + // our impl list, which we don't want to bind to. &mut TypeBindings::new(), recursion_limit - 1, )?; From 20bd680871796062964fe6e57e19ad3c958fd7c7 Mon Sep 17 00:00:00 2001 From: jfecher Date: Thu, 7 Dec 2023 06:43:00 -0600 Subject: [PATCH 09/11] fix: Make trait functions generic over `Self` (#3702) # Description ## Problem & Summary Checking a trait function's type previously was very ad-hoc. When a `TraitMethodReference` was found, it'd manually create a function type from the trait function's definition and instantiate it - even though instantiating a `Type::Function` does nothing. I've changed the TraitFunction struct to add a `Type` field for the whole function type, which includes generics. The type now should usually be a `Type::Forall`. In addition, I've added the trait's `Self` type to the list of generics in the Type::Forall. This fixes the issue where the `Self` type would be bound after certain function calls which would affect callsites but not a function body leading to a type mismatch between function return type and a call's return type. This is the issue we were encountering after adding traits to noir-protocol-circuits. ## Additional Information No tests are provided since I was unable to make a minimal repro from the thousands of lines of noir-protocol-circuits unfortunately. I'll keep experimenting while this PR is up. I'm also merging into jf/fix-3089 since this was somewhat accidentally built on top of it. The two changes are otherwise unrelated. ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 1 - .../src/hir/def_collector/dc_crate.rs | 16 +++++--- .../src/hir/resolution/traits.rs | 23 ++++++++--- .../noirc_frontend/src/hir/type_check/expr.rs | 11 +---- compiler/noirc_frontend/src/hir_def/traits.rs | 41 +++++++++++++------ compiler/noirc_frontend/src/hir_def/types.rs | 21 +--------- .../src/monomorphization/mod.rs | 15 ++++--- 7 files changed, 66 insertions(+), 62 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index 41327c988d2..d7e6b8b0a3d 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -591,7 +591,6 @@ impl<'a> FunctionContext<'a> { } self.codegen_intrinsic_call_checks(function, &arguments, call.location); - Ok(self.insert_call(function, arguments, &call.return_type, call.location)) } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index d6eddeffc07..0806a8eb757 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -433,7 +433,10 @@ pub(crate) fn check_methods_signatures( let impl_method_generic_count = impl_method.typ.generic_count() - trait_impl_generic_count; - let trait_method_generic_count = trait_method.generics.len(); + + // We subtract 1 here to account for the implicit generic `Self` type that is on all + // traits (and thus trait methods) but is not required (or allowed) for users to specify. + let trait_method_generic_count = trait_method.generics().len() - 1; if impl_method_generic_count != trait_method_generic_count { let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { @@ -447,9 +450,9 @@ pub(crate) fn check_methods_signatures( } if let Type::Function(impl_params, _, _) = impl_function_type.0 { - if trait_method.arguments.len() == impl_params.len() { + if trait_method.arguments().len() == impl_params.len() { // Check the parameters of the impl method against the parameters of the trait method - let args = trait_method.arguments.iter(); + let args = trait_method.arguments().iter(); let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0); for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in @@ -468,7 +471,7 @@ pub(crate) fn check_methods_signatures( } else { let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters { actual_num_parameters: impl_method.parameters.0.len(), - expected_num_parameters: trait_method.arguments.len(), + expected_num_parameters: trait_method.arguments().len(), trait_name: the_trait.name.to_string(), method_name: func_name.to_string(), span: impl_method.location.span, @@ -481,11 +484,12 @@ pub(crate) fn check_methods_signatures( let resolved_return_type = resolver.resolve_type(impl_method.return_type.get_type().into_owned()); - trait_method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || { + // TODO: This is not right since it may bind generic return types + trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || { let ret_type_span = impl_method.return_type.get_type().span; let expr_span = ret_type_span.expect("return type must always have a span"); - let expected_typ = trait_method.return_type.to_string(); + let expected_typ = trait_method.return_type().to_string(); let expr_typ = impl_method.return_type().to_string(); TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span } }); diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 702e96362a6..7a6cbccb081 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -18,7 +18,7 @@ use crate::{ }, hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType}, node_interner::{FuncId, NodeInterner, TraitId}, - Path, Shared, TraitItem, Type, TypeVariableKind, + Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind, }; use super::{ @@ -111,8 +111,17 @@ fn resolve_trait_methods( resolver.set_self_type(Some(self_type)); let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone())); - let resolved_return_type = resolver.resolve_type(return_type.get_type().into_owned()); - let generics = resolver.get_generics().to_vec(); + let return_type = resolver.resolve_type(return_type.get_type().into_owned()); + + let mut generics = vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var + .borrow() + { + TypeBinding::Unbound(id) => (*id, type_var.clone()), + TypeBinding::Bound(binding) => unreachable!("Trait generic was bound to {binding}"), + }); + + // Ensure the trait is generic over the Self type as well + generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar)); let name = name.clone(); let span: Span = name.span(); @@ -128,11 +137,13 @@ fn resolve_trait_methods( None }; + let no_environment = Box::new(Type::Unit); + let function_type = Type::Function(arguments, Box::new(return_type), no_environment); + let typ = Type::Forall(generics, Box::new(function_type)); + let f = TraitFunction { name, - generics, - arguments, - return_type: resolved_return_type, + typ, span, default_impl, default_impl_file_id: unresolved_trait.file_id, diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 1fe1eaa899c..9a64ab55196 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -289,14 +289,7 @@ impl<'interner> TypeChecker<'interner> { } HirExpression::TraitMethodReference(method) => { let the_trait = self.interner.get_trait(method.trait_id); - let method = &the_trait.methods[method.method_index]; - - let typ = Type::Function( - method.arguments.clone(), - Box::new(method.return_type.clone()), - Box::new(Type::Unit), - ); - + let typ = &the_trait.methods[method.method_index].typ; let (typ, bindings) = typ.instantiate(self.interner); self.interner.store_instantiation_bindings(*expr_id, bindings); typ @@ -546,7 +539,7 @@ impl<'interner> TypeChecker<'interner> { HirMethodReference::TraitMethodId(method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; - (method.get_type(), method.arguments.len()) + (method.typ.clone(), method.arguments().len()) } }; diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 5f0bf49ca0f..e6c46a46073 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, @@ -11,9 +9,7 @@ use noirc_errors::Span; #[derive(Clone, Debug, PartialEq, Eq)] pub struct TraitFunction { pub name: Ident, - pub generics: Vec<(Rc, TypeVariable, Span)>, - pub arguments: Vec, - pub return_type: Type, + pub typ: Type, pub span: Span, pub default_impl: Option>, pub default_impl_file_id: fm::FileId, @@ -145,12 +141,33 @@ impl std::fmt::Display for Trait { } impl TraitFunction { - pub fn get_type(&self) -> Type { - Type::Function( - self.arguments.clone(), - Box::new(self.return_type.clone()), - Box::new(Type::Unit), - ) - .generalize() + pub fn arguments(&self) -> &[Type] { + match &self.typ { + Type::Function(args, _, _) => args, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn generics(&self) -> &[(TypeVariableId, TypeVariable)] { + match &self.typ { + Type::Function(..) => &[], + Type::Forall(generics, _) => generics, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn return_type(&self) -> &Type { + match &self.typ { + Type::Function(_, return_type, _) => return_type, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(_, return_type, _) => return_type, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index c7b3fcc499c..ff5e157cec4 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -75,7 +75,7 @@ pub enum Type { /// the environment should be `Unit` by default, /// for closures it should contain a `Tuple` type with the captured /// variable types. - Function(Vec, Box, Box), + Function(Vec, /*return_type:*/ Box, /*environment:*/ Box), /// &mut T MutableReference(Box), @@ -668,31 +668,12 @@ impl Type { } } - /// Takes a monomorphic type and generalizes it over each of the given type variables. - pub(crate) fn generalize_from_variables( - self, - type_vars: HashMap, - ) -> Type { - let polymorphic_type_vars = vecmap(type_vars, |type_var| type_var); - Type::Forall(polymorphic_type_vars, Box::new(self)) - } - /// Takes a monomorphic type and generalizes it over each of the type variables in the /// given type bindings, ignoring what each type variable is bound to in the TypeBindings. pub(crate) fn generalize_from_substitutions(self, type_bindings: TypeBindings) -> Type { let polymorphic_type_vars = vecmap(type_bindings, |(id, (type_var, _))| (id, type_var)); Type::Forall(polymorphic_type_vars, Box::new(self)) } - - /// Takes a monomorphic type and generalizes it over each type variable found within. - /// - /// Note that Noir's type system assumes any Type::Forall are only present at top-level, - /// and thus all type variable's within a type are free. - pub(crate) fn generalize(self) -> Type { - let mut type_variables = HashMap::new(); - self.find_all_unbound_type_variables(&mut type_variables); - self.generalize_from_variables(type_variables) - } } impl std::fmt::Display for Type { diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index e72c3a2a948..52ed0c746e1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -886,7 +886,6 @@ impl<'interner> Monomorphizer<'interner> { let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); - let func: Box; let return_type = self.interner.id_type(id); let return_type = self.convert_type(&return_type); @@ -907,7 +906,8 @@ impl<'interner> Monomorphizer<'interner> { let func_type = self.interner.id_type(call.func); let func_type = self.convert_type(&func_type); let is_closure = self.is_function_closure(func_type); - if is_closure { + + let func = if is_closure { let local_id = self.next_local_id(); // store the function in a temporary variable before calling it @@ -929,14 +929,13 @@ impl<'interner> Monomorphizer<'interner> { typ: self.convert_type(&self.interner.id_type(call.func)), }); - func = Box::new(ast::Expression::ExtractTupleField( - Box::new(extracted_func.clone()), - 1usize, - )); - let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + let env_argument = + ast::Expression::ExtractTupleField(Box::new(extracted_func.clone()), 0usize); arguments.insert(0, env_argument); + + Box::new(ast::Expression::ExtractTupleField(Box::new(extracted_func), 1usize)) } else { - func = original_func.clone(); + original_func.clone() }; let call = self From 90fd5f35bf18f45af5b48f57e816162a14ec999b Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 7 Dec 2023 08:56:28 -0600 Subject: [PATCH 10/11] Remove equality check --- compiler/noirc_frontend/src/hir_def/types.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index ff5e157cec4..9634a1e9d88 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -429,14 +429,7 @@ impl TypeVariable { pub fn bind(&self, typ: Type) { let id = match &*self.0.borrow() { TypeBinding::Bound(binding) => { - if *binding == typ { - return; - } else { - unreachable!( - "TypeVariable::bind, cannot bind bound var {} to {}", - binding, typ - ); - } + unreachable!("TypeVariable::bind, cannot bind bound var {} to {}", binding, typ) } TypeBinding::Unbound(id) => *id, }; From 6425e564ba120cf73add7fd3eadaf9c467f5a354 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Thu, 7 Dec 2023 09:35:52 -0600 Subject: [PATCH 11/11] Add regression test --- .../method_call_regression/Nargo.toml | 7 ++++++ .../method_call_regression/src/main.nr | 25 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 test_programs/compile_success_empty/method_call_regression/Nargo.toml create mode 100644 test_programs/compile_success_empty/method_call_regression/src/main.nr diff --git a/test_programs/compile_success_empty/method_call_regression/Nargo.toml b/test_programs/compile_success_empty/method_call_regression/Nargo.toml new file mode 100644 index 00000000000..92c9b942008 --- /dev/null +++ b/test_programs/compile_success_empty/method_call_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "short" +type = "bin" +authors = [""] +compiler_version = ">=0.19.4" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/method_call_regression/src/main.nr b/test_programs/compile_success_empty/method_call_regression/src/main.nr new file mode 100644 index 00000000000..8bb7ebcac45 --- /dev/null +++ b/test_programs/compile_success_empty/method_call_regression/src/main.nr @@ -0,0 +1,25 @@ +use dep::std; + +fn main() { + // s: Struct + let s = Struct { b: () }; + // Regression for #3089 + s.foo(); +} + +struct Struct { b: B } + +// Before the fix, this candidate is searched first, binding ? to `u8` permanently. +impl Struct { + fn foo(self) {} +} + +// Then this candidate would be searched next but would not be a valid +// candidate since `Struct` != `Struct`. +// +// With the fix, the type of `s` correctly no longer changes until a +// method is actually selected. So this candidate is now valid since +// `Struct` unifies with `Struct` with `? = u32`. +impl Struct { + fn foo(self) {} +}