Skip to content

Commit

Permalink
feat: check trait where clause (#6325)
Browse files Browse the repository at this point in the history
# Description

## Problem

Resolves #6023

## Summary

It turns out that checking a trait's `where` clause is very similar to
checking parent traits: for parent traits the type to check if `Self`,
for where clause it's the specified type. So `trait Foo where Self:
Constraint` is the same as `trait Foo: Constraint`. I thought about
unifying the code to only have a single list of constraints but I don't
know if it's worth it (we could maybe give different errors in one case
or another, though right now the errors are the same).

There's a chance this PR is a breaking change, because when I finished
implementing it one test broke and I had to amend it.

## Additional Context


## Documentation

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
asterite authored Oct 24, 2024
1 parent 3299c25 commit 0de3241
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 10 deletions.
100 changes: 96 additions & 4 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1041,11 +1041,14 @@ impl<'context> Elaborator<'context> {
self.file = trait_impl.file_id;
self.local_module = trait_impl.module_id;

self.check_parent_traits_are_implemented(&trait_impl);

self.generics = trait_impl.resolved_generics;
self.generics = trait_impl.resolved_generics.clone();
self.current_trait_impl = trait_impl.impl_id;

self.add_trait_impl_assumed_trait_implementations(trait_impl.impl_id);
self.check_trait_impl_where_clause_matches_trait_where_clause(&trait_impl);
self.check_parent_traits_are_implemented(&trait_impl);
self.remove_trait_impl_assumed_trait_implementations(trait_impl.impl_id);

for (module, function, _) in &trait_impl.methods.functions {
self.local_module = *module;
let errors = check_trait_impl_method_matches_declaration(self.interner, *function);
Expand All @@ -1059,6 +1062,95 @@ impl<'context> Elaborator<'context> {
self.generics.clear();
}

fn add_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
let trait_bound = &trait_constrain.trait_bound;
self.interner.add_assumed_trait_implementation(
trait_constrain.typ.clone(),
trait_bound.trait_id,
trait_bound.trait_generics.clone(),
);
}
}
}
}

fn remove_trait_impl_assumed_trait_implementations(&mut self, impl_id: Option<TraitImplId>) {
if let Some(impl_id) = impl_id {
if let Some(trait_implementation) = self.interner.try_get_trait_implementation(impl_id)
{
for trait_constrain in &trait_implementation.borrow().where_clause {
self.interner.remove_assumed_trait_implementations_for_trait(
trait_constrain.trait_bound.trait_id,
);
}
}
}
}

fn check_trait_impl_where_clause_matches_trait_where_clause(
&mut self,
trait_impl: &UnresolvedTraitImpl,
) {
let Some(trait_id) = trait_impl.trait_id else {
return;
};

let Some(the_trait) = self.interner.try_get_trait(trait_id) else {
return;
};

if the_trait.where_clause.is_empty() {
return;
}

let impl_trait = the_trait.name.to_string();
let the_trait_file = the_trait.location.file;

let mut bindings = TypeBindings::new();
bind_ordered_generics(
&the_trait.generics,
&trait_impl.resolved_trait_generics,
&mut bindings,
);

// Check that each of the trait's where clause constraints is satisfied
for trait_constraint in the_trait.where_clause.clone() {
let Some(trait_constraint_trait) =
self.interner.try_get_trait(trait_constraint.trait_bound.trait_id)
else {
continue;
};

let trait_constraint_type = trait_constraint.typ.substitute(&bindings);
let trait_bound = &trait_constraint.trait_bound;

if self
.interner
.try_lookup_trait_implementation(
&trait_constraint_type,
trait_bound.trait_id,
&trait_bound.trait_generics.ordered,
&trait_bound.trait_generics.named,
)
.is_err()
{
let missing_trait =
format!("{}{}", trait_constraint_trait.name, trait_bound.trait_generics);
self.push_err(ResolverError::TraitNotImplemented {
impl_trait: impl_trait.clone(),
missing_trait,
type_missing_trait: trait_constraint_type.to_string(),
span: trait_impl.object_type.span,
missing_trait_location: Location::new(trait_bound.span, the_trait_file),
});
}
}
}

fn check_parent_traits_are_implemented(&mut self, trait_impl: &UnresolvedTraitImpl) {
let Some(trait_id) = trait_impl.trait_id else {
return;
Expand Down Expand Up @@ -1182,7 +1274,7 @@ impl<'context> Elaborator<'context> {
trait_id,
trait_generics,
file: trait_impl.file_id,
where_clause: where_clause.clone(),
where_clause,
methods,
});

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl<'context> Elaborator<'context> {
&resolved_generics,
);

let where_clause =
this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause);

// Each associated type in this trait is also an implicit generic
for associated_type in &this.interner.get_trait(*trait_id).associated_types {
this.generics.push(associated_type.clone());
Expand All @@ -48,6 +51,7 @@ impl<'context> Elaborator<'context> {
this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub struct Trait {

/// The resolved trait bounds (for example in `trait Foo: Bar + Baz`, this would be `Bar + Baz`)
pub trait_bounds: Vec<ResolvedTraitBound>,

pub where_clause: Vec<TraitConstraint>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -154,6 +156,10 @@ impl Trait {
self.trait_bounds = trait_bounds;
}

pub fn set_where_clause(&mut self, where_clause: Vec<TraitConstraint>) {
self.where_clause = where_clause;
}

pub fn find_method(&self, name: &str) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if &method.name == name {
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ impl NodeInterner {
method_ids: unresolved_trait.method_ids.clone(),
associated_types,
trait_bounds: Vec::new(),
where_clause: Vec::new(),
};

self.traits.insert(type_id, new_trait);
Expand Down
15 changes: 9 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2970,9 +2970,7 @@ fn uses_self_type_in_trait_where_clause() {
}
}
struct Bar {
}
struct Bar {}
impl Foo for Bar {
Expand All @@ -2984,12 +2982,17 @@ fn uses_self_type_in_trait_where_clause() {
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert_eq!(errors.len(), 2);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0
else {
panic!("Expected a trait not implemented error, got {:?}", errors[0].0);
};

let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) =
&errors[0].0
&errors[1].0
else {
panic!("Expected an unresolved method call error, got {:?}", errors[0].0);
panic!("Expected an unresolved method call error, got {:?}", errors[1].0);
};

assert_eq!(method_name, "trait_func");
Expand Down
113 changes: 113 additions & 0 deletions compiler/noirc_frontend/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,116 @@ fn trait_inheritance_missing_parent_implementation() {
assert_eq!(typ, "Struct");
assert_eq!(impl_trait, "Bar");
}

#[test]
fn errors_on_unknown_type_in_trait_where_clause() {
let src = r#"
pub trait Foo<T> where T: Unknown {}
fn main() {
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_concrete_type() {
let src = r#"
pub trait Greeter {
fn greet(self);
}
pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}
pub struct SomeGreeter;
impl Greeter for SomeGreeter {
fn greet(self) {}
}
pub struct Bar;
impl Foo<SomeGreeter> for Bar {}
fn main() {}
"#;
assert_no_errors(src);
}

#[test]
fn does_not_error_if_impl_trait_constraint_is_satisfied_for_type_variable() {
let src = r#"
pub trait Greeter {
fn greet(self);
}
pub trait Foo<T> where T: Greeter {
fn greet(object: T) {
object.greet();
}
}
pub struct Bar;
impl<T> Foo<T> for Bar where T: Greeter {
}
fn main() {
}
"#;
assert_no_errors(src);
}
#[test]
fn errors_if_impl_trait_constraint_is_not_satisfied() {
let src = r#"
pub trait Greeter {
fn greet(self);
}
pub trait Foo<T>
where
T: Greeter,
{
fn greet<U>(object: U)
where
U: Greeter,
{
object.greet();
}
}
pub struct SomeGreeter;
pub struct Bar;
impl Foo<SomeGreeter> for Bar {}
fn main() {}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented {
impl_trait,
missing_trait: the_trait,
type_missing_trait: typ,
..
}) = &errors[0].0
else {
panic!("Expected a TraitNotImplemented error, got {:?}", &errors[0].0);
};

assert_eq!(the_trait, "Greeter");
assert_eq!(typ, "SomeGreeter");
assert_eq!(impl_trait, "Foo");
}

0 comments on commit 0de3241

Please sign in to comment.