Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Resolve arguments to attributes #5649

Merged
merged 10 commits into from
Aug 2, 2024
352 changes: 348 additions & 4 deletions compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
use std::{collections::BTreeMap, fmt::Display};

use chumsky::Parser;
use fm::FileId;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};

use crate::{
hir::{
comptime::{Interpreter, InterpreterError, Value},
def_collector::{
dc_crate::{
CollectedItems, CompilationError, UnresolvedFunctions, UnresolvedStruct,
UnresolvedTrait, UnresolvedTraitImpl,
},
dc_mod,
},
resolution::errors::ResolverError,
},
hir_def::expr::HirIdent,
node_interner::{DependencyId, FuncId},
lexer::Lexer,
macros_api::{
Expression, ExpressionKind, HirExpression, NodeInterner, SecondaryAttribute, StructId,
},
node_interner::{DefinitionKind, DependencyId, FuncId, TraitId},
parser::{self, TopLevelStatement},
Type, TypeBindings,
};

use super::{Elaborator, FunctionContext, ResolverMeta};
Expand Down Expand Up @@ -35,13 +59,11 @@ impl<'context> Elaborator<'context> {
elaborator.introduce_generics_into_scope(meta.all_generics.clone());
}

elaborator.comptime_scopes = std::mem::take(&mut self.comptime_scopes);
elaborator.populate_scope_from_comptime_scopes();

let result = f(&mut elaborator);
elaborator.check_and_pop_function_context();

self.comptime_scopes = elaborator.comptime_scopes;
self.errors.append(&mut elaborator.errors);
result
}
Expand All @@ -50,7 +72,7 @@ impl<'context> Elaborator<'context> {
// Take the comptime scope to be our runtime scope.
// Iterate from global scope to the most local scope so that the
// later definitions will naturally shadow the former.
for scope in &self.comptime_scopes {
for scope in &self.interner.comptime_scopes {
for definition_id in scope.keys() {
let definition = self.interner.definition(*definition_id);
let name = definition.name.clone();
Expand All @@ -63,4 +85,326 @@ impl<'context> Elaborator<'context> {
}
}
}

pub(super) fn run_comptime_attributes_on_item(
&mut self,
attributes: &[SecondaryAttribute],
item: Value,
span: Span,
generated_items: &mut CollectedItems,
) {
for attribute in attributes {
if let SecondaryAttribute::Custom(name) = attribute {
if let Err(error) =
self.run_comptime_attribute_on_item(name, item.clone(), span, generated_items)
{
self.errors.push(error);
}
}
}
}

fn run_comptime_attribute_on_item(
&mut self,
attribute: &str,
item: Value,
span: Span,
generated_items: &mut CollectedItems,
) -> Result<(), (CompilationError, FileId)> {
let location = Location::new(span, self.file);
let Some((function, arguments)) = Self::parse_attribute(attribute, self.file)? else {
// Do not issue an error if the attribute is unknown
return Ok(());
};

// Elaborate the function, rolling back any errors generated in case it is unknown
let error_count = self.errors.len();
let function = self.elaborate_expression(function).0;
self.errors.truncate(error_count);

let definition_id = match self.interner.expression(&function) {
HirExpression::Ident(ident, _) => ident.id,
_ => return Ok(()),
};

let Some(definition) = self.interner.try_definition(definition_id) else {
// If there's no such function, don't return an error.
// This preserves backwards compatibility in allowing custom attributes that
// do not refer to comptime functions.
return Ok(());
};

let DefinitionKind::Function(function) = definition.kind else {
return Err((ResolverError::NonFunctionInAnnotation { span }.into(), self.file));
};

let mut interpreter = self.setup_interpreter();
let mut arguments =
Self::handle_attribute_arguments(&mut interpreter, function, arguments, location)
.map_err(|error| {
let file = error.get_location().file;
(error.into(), file)
})?;

arguments.insert(0, (item, location));

let value = interpreter
.call_function(function, arguments, TypeBindings::new(), location)
.map_err(|error| error.into_compilation_error_pair())?;

if value != Value::Unit {
let items = value
.into_top_level_items(location, self.interner)
.map_err(|error| error.into_compilation_error_pair())?;

self.add_items(items, generated_items, location);
}

Ok(())
}

/// Parses an attribute in the form of a function call (e.g. `#[foo(a b, c d)]`) into
/// the function and quoted arguments called (e.g. `("foo", vec![(a b, location), (c d, location)])`)
#[allow(clippy::type_complexity)]
fn parse_attribute(
annotation: &str,
file: FileId,
) -> Result<Option<(Expression, Vec<Expression>)>, (CompilationError, FileId)> {
let (tokens, mut lexing_errors) = Lexer::lex(annotation);
if !lexing_errors.is_empty() {
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
return Err((lexing_errors.swap_remove(0).into(), file));
}

let expression = parser::expression()
.parse(tokens)
.map_err(|mut errors| (errors.swap_remove(0).into(), file))?;

Ok(match expression.kind {
ExpressionKind::Call(call) => Some((*call.func, call.arguments)),
ExpressionKind::Variable(_) => Some((expression, Vec::new())),
_ => None,
})
}

fn handle_attribute_arguments(
interpreter: &mut Interpreter,
function: FuncId,
arguments: Vec<Expression>,
location: Location,
) -> Result<Vec<(Value, Location)>, InterpreterError> {
let meta = interpreter.elaborator.interner.function_meta(&function);
let mut parameters = vecmap(&meta.parameters.0, |(_, typ, _)| typ.clone());

// Remove the initial parameter for the comptime item since that is not included
// in `arguments` at this point.
parameters.remove(0);

// If the function is varargs, push the type of the last slice element N times
// to account for N extra arguments.
let modifiers = interpreter.elaborator.interner.function_modifiers(&function);
let is_varargs = modifiers.attributes.is_varargs();
let varargs_type = if is_varargs { parameters.pop() } else { None };

let varargs_elem_type = varargs_type.as_ref().and_then(|t| t.slice_element_type());

let mut new_arguments = Vec::with_capacity(arguments.len());
let mut varargs = im::Vector::new();

for (i, arg) in arguments.into_iter().enumerate() {
let param_type = parameters.get(i).or(varargs_elem_type).unwrap_or(&Type::Error);

let mut push_arg = |arg| {
if i >= parameters.len() {
varargs.push_back(arg);
} else {
new_arguments.push((arg, location));
}
};

if *param_type == Type::Quoted(crate::QuotedType::TraitDefinition) {
let trait_id = match arg.kind {
ExpressionKind::Variable(path) => interpreter
.elaborator
.resolve_trait_by_path(path)
.ok_or(InterpreterError::FailedToResolveTraitDefinition { location }),
_ => Err(InterpreterError::TraitDefinitionMustBeAPath { location }),
}?;
push_arg(Value::TraitDefinition(trait_id));
} else {
let expr_id = interpreter.elaborator.elaborate_expression(arg).0;
push_arg(interpreter.evaluate(expr_id)?);
}
}

if is_varargs {
let typ = varargs_type.unwrap_or(Type::Error);
new_arguments.push((Value::Slice(varargs, typ), location));
}

Ok(new_arguments)
}

fn add_items(
&mut self,
items: Vec<TopLevelStatement>,
generated_items: &mut CollectedItems,
location: Location,
) {
for item in items {
self.add_item(item, generated_items, location);
}
}

fn add_item(
&mut self,
item: TopLevelStatement,
generated_items: &mut CollectedItems,
location: Location,
) {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);

generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
self.crate_id,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
}
}
}

pub fn setup_interpreter<'local>(&'local mut self) -> Interpreter<'local, 'context> {
let current_function = match self.current_item {
Some(DependencyId::Function(function)) => Some(function),
_ => None,
};
Interpreter::new(self, self.crate_id, current_function)
}

pub(super) fn debug_comptime<T: Display, F: FnMut(&mut NodeInterner) -> T>(
&mut self,
location: Location,
mut expr_f: F,
) {
if Some(location.file) == self.debug_comptime_in_file {
let displayed_expr = expr_f(self.interner);
self.errors.push((
InterpreterError::debug_evaluate_comptime(displayed_expr, location).into(),
location.file,
));
}
}

/// Run all the attributes on each item. The ordering is unspecified to users but currently
/// we run trait attributes first to (e.g.) register derive handlers before derive is
/// called on structs.
/// Returns any new items generated by attributes.
pub(super) fn run_attributes(
&mut self,
traits: &BTreeMap<TraitId, UnresolvedTrait>,
types: &BTreeMap<StructId, UnresolvedStruct>,
functions: &[UnresolvedFunctions],
) -> CollectedItems {
let mut generated_items = CollectedItems::default();

for (trait_id, trait_) in traits {
let attributes = &trait_.trait_def.attributes;
let item = Value::TraitDefinition(*trait_id);
let span = trait_.trait_def.span;
self.local_module = trait_.module_id;
self.file = trait_.file_id;
self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items);
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
}

for (struct_id, struct_def) in types {
let attributes = &struct_def.struct_def.attributes;
let item = Value::StructDefinition(*struct_id);
let span = struct_def.struct_def.span;
self.local_module = struct_def.module_id;
self.file = struct_def.file_id;
self.run_comptime_attributes_on_item(attributes, item, span, &mut generated_items);
}

self.run_attributes_on_functions(functions, &mut generated_items);
generated_items
}

fn run_attributes_on_functions(
&mut self,
function_sets: &[UnresolvedFunctions],
generated_items: &mut CollectedItems,
) {
for function_set in function_sets {
self.file = function_set.file_id;
self.self_type = function_set.self_type.clone();

for (local_module, function_id, function) in &function_set.functions {
self.local_module = *local_module;
let attributes = function.secondary_attributes();
let item = Value::FunctionDefinition(*function_id);
let span = function.span();
self.run_comptime_attributes_on_item(attributes, item, span, generated_items);
}
}
}
}
Loading
Loading