diff --git a/compiler/noirc_frontend/src/elaborator/scope.rs b/compiler/noirc_frontend/src/elaborator/scope.rs index 73aed9bf06c..258c32d4427 100644 --- a/compiler/noirc_frontend/src/elaborator/scope.rs +++ b/compiler/noirc_frontend/src/elaborator/scope.rs @@ -1,6 +1,6 @@ use noirc_errors::{Location, Spanned}; -use crate::ast::ERROR_IDENT; +use crate::ast::{PathKind, ERROR_IDENT}; use crate::hir::def_map::{LocalModuleId, ModuleId}; use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; @@ -43,7 +43,34 @@ impl<'context> Elaborator<'context> { } pub(super) fn resolve_path(&mut self, path: Path) -> Result { - let resolver = StandardPathResolver::new(self.module_id()); + let mut module_id = self.module_id(); + let mut path = path; + + if path.kind == PathKind::Plain && path.first_name() == SELF_TYPE_NAME { + if let Some(Type::Struct(struct_type, _)) = &self.self_type { + let struct_type = struct_type.borrow(); + if path.segments.len() == 1 { + return Ok(ModuleDefId::TypeId(struct_type.id)); + } + + module_id = struct_type.id.module_id(); + path = Path { + segments: path.segments[1..].to_vec(), + kind: PathKind::Plain, + span: path.span(), + }; + } + } + + self.resolve_path_in_module(path, module_id) + } + + fn resolve_path_in_module( + &mut self, + path: Path, + module_id: ModuleId, + ) -> Result { + let resolver = StandardPathResolver::new(module_id); let path_resolution; if self.interner.track_references { diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index a21259a4f0d..91644ff4470 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2729,3 +2729,83 @@ fn incorrect_generic_count_on_type_alias() { assert_eq!(actual, 1); assert_eq!(expected, 0); } + +#[test] +fn uses_self_type_for_struct_function_call() { + let src = r#" + struct S { } + + impl S { + fn one() -> Field { + 1 + } + + fn two() -> Field { + Self::one() + Self::one() + } + } + + fn main() {} + "#; + assert_no_errors(src); +} + +#[test] +fn uses_self_type_inside_trait() { + let src = r#" + trait Foo { + fn foo() -> Self { + Self::bar() + } + + fn bar() -> Self; + } + + impl Foo for Field { + fn bar() -> Self { + 1 + } + } + + fn main() { + let _: Field = Foo::foo(); + } + "#; + assert_no_errors(src); +} + +#[test] +fn uses_self_type_in_trait_where_clause() { + let src = r#" + trait Trait { + fn trait_func() -> bool; + } + + trait Foo where Self: Trait { + fn foo(self) -> bool { + self.trait_func() + } + } + + struct Bar { + + } + + impl Foo for Bar { + + } + + fn main() {} + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { method_name, .. }) = + &errors[0].0 + else { + panic!("Expected an unresolved method call error, got {:?}", errors[0].0); + }; + + assert_eq!(method_name, "trait_func"); +}