From 31130dc7aec24a7a7b9f342df94b14f295eb2103 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 16 Sep 2024 13:12:39 -0300 Subject: [PATCH] feat: add `Expr::as_lambda` (#6048) # Description ## Problem Part of #5668 ## Summary ## Additional Contex ## Documentation Check one: - [ ] No documentation needed. - [x] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../src/hir/comptime/interpreter/builtin.rs | 63 +++++++++++++++++++ docs/docs/noir/standard_library/meta/expr.md | 6 ++ noir_stdlib/src/meta/expr.nr | 43 +++++++++++++ .../comptime_expr/src/main.nr | 42 +++++++++++++ 4 files changed, 154 insertions(+) diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 407cdc216db..7298514b53c 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -83,6 +83,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "expr_as_if" => expr_as_if(interner, arguments, return_type, location), "expr_as_index" => expr_as_index(interner, arguments, return_type, location), "expr_as_integer" => expr_as_integer(interner, arguments, return_type, location), + "expr_as_lambda" => expr_as_lambda(interner, arguments, return_type, location), "expr_as_let" => expr_as_let(interner, arguments, return_type, location), "expr_as_member_access" => { expr_as_member_access(interner, arguments, return_type, location) @@ -1612,6 +1613,68 @@ fn expr_as_integer( }) } +// fn as_lambda(self) -> Option<([(Expr, Option)], Option, Expr)> +fn expr_as_lambda( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type.clone(), location, |expr| { + if let ExprValue::Expression(ExpressionKind::Lambda(lambda)) = expr { + // ([(Expr, Option)], Option, Expr) + let option_type = extract_option_generic_type(return_type); + let Type::Tuple(mut tuple_types) = option_type else { + panic!("Expected the return type option generic arg to be a tuple"); + }; + assert_eq!(tuple_types.len(), 3); + + // Expr + tuple_types.pop().unwrap(); + + // Option + let option_unresolved_type = tuple_types.pop().unwrap(); + + let parameters = lambda + .parameters + .into_iter() + .map(|(pattern, typ)| { + let pattern = Value::pattern(pattern); + let typ = if let UnresolvedTypeData::Unspecified = typ.typ { + None + } else { + Some(Value::UnresolvedType(typ.typ)) + }; + let typ = option(option_unresolved_type.clone(), typ).unwrap(); + Value::Tuple(vec![pattern, typ]) + }) + .collect(); + let parameters = Value::Slice( + parameters, + Type::Slice(Box::new(Type::Tuple(vec![ + Type::Quoted(QuotedType::Expr), + Type::Quoted(QuotedType::UnresolvedType), + ]))), + ); + + let return_type = lambda.return_type.typ; + let return_type = if let UnresolvedTypeData::Unspecified = return_type { + None + } else { + Some(return_type) + }; + let return_type = return_type.map(Value::UnresolvedType); + let return_type = option(option_unresolved_type, return_type).ok()?; + + let body = Value::expression(lambda.body.kind); + + Some(Value::Tuple(vec![parameters, return_type, body])) + } else { + None + } + }) +} + // fn as_let(self) -> Option<(Expr, Option, Expr)> fn expr_as_let( interner: &NodeInterner, diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index ddbbcd7cdde..4e2d09102b5 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -116,6 +116,12 @@ array and the index. If this expression is an integer literal, return the integer as a field as well as whether the integer is negative (true) or not (false). +### as_lambda + +#include_code as_lambda noir_stdlib/src/meta/expr.nr rust + +If this expression is a lambda, returns the parameters, return type and body. + ### as_let #include_code as_let noir_stdlib/src/meta/expr.nr rust diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index caf3fa172c4..c96f7d27442 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -115,6 +115,12 @@ impl Expr { comptime fn as_integer(self) -> Option<(Field, bool)> {} // docs:end:as_integer + /// If this expression is a lambda, returns the parameters, return type and body. + #[builtin(expr_as_lambda)] + // docs:start:as_lambda + comptime fn as_lambda(self) -> Option<([(Expr, Option)], Option, Expr)> {} + // docs:end:as_lambda + /// If this expression is a let statement, returns the let pattern as an `Expr`, /// the optional type annotation, and the assigned expression. #[builtin(expr_as_let)] @@ -234,6 +240,7 @@ impl Expr { let result = result.or_else(|| modify_index(self, f)); let result = result.or_else(|| modify_for(self, f)); let result = result.or_else(|| modify_for_range(self, f)); + let result = result.or_else(|| modify_lambda(self, f)); let result = result.or_else(|| modify_let(self, f)); let result = result.or_else(|| modify_function_call(self, f)); let result = result.or_else(|| modify_member_access(self, f)); @@ -427,6 +434,17 @@ comptime fn modify_for_range(expr: Expr, f: fn[Env](Expr) -> Option) ) } +comptime fn modify_lambda(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_lambda().map( + |expr: ([(Expr, Option)], Option, Expr)| { + let (params, return_type, body) = expr; + let params = params.map(|param: (Expr, Option)| (param.0.modify(f), param.1)); + let body = body.modify(f); + new_lambda(params, return_type, body) + } + ) +} + comptime fn modify_let(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_let().map( |expr: (Expr, Option, Expr)| { @@ -599,6 +617,31 @@ comptime fn new_index(object: Expr, index: Expr) -> Expr { quote { $object[$index] }.as_expr().unwrap() } +comptime fn new_lambda( + params: [(Expr, Option)], + return_type: Option, + body: Expr +) -> Expr { + let params = params.map( + |param: (Expr, Option)| { + let (name, typ) = param; + if typ.is_some() { + let typ = typ.unwrap(); + quote { $name: $typ } + } else { + quote { $name } + } + } + ).join(quote { , }); + + if return_type.is_some() { + let return_type = return_type.unwrap(); + quote { |$params| -> $return_type { $body } }.as_expr().unwrap() + } else { + quote { |$params| { $body } }.as_expr().unwrap() + } +} + comptime fn new_let(pattern: Expr, typ: Option, expr: Expr) -> Expr { if typ.is_some() { let typ = typ.unwrap(); diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 50b10c45e59..709180879a0 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -635,6 +635,48 @@ mod tests { } } + #[test] + fn test_expr_as_lambda() { + comptime + { + let expr = quote { |x: Field| -> Field { 1 } }.as_expr().unwrap(); + let (params, return_type, body) = expr.as_lambda().unwrap(); + assert_eq(params.len(), 1); + assert(params[0].1.unwrap().is_field()); + assert(return_type.unwrap().is_field()); + assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (1, false)); + + let expr = quote { |x| { 1 } }.as_expr().unwrap(); + let (params, return_type, body) = expr.as_lambda().unwrap(); + assert_eq(params.len(), 1); + assert(params[0].1.is_none()); + assert(return_type.is_none()); + assert_eq(body.as_block().unwrap()[0].as_integer().unwrap(), (1, false)); + } + } + + #[test] + fn test_expr_modify_lambda() { + comptime + { + let expr = quote { |x: Field| -> Field { 1 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (params, return_type, body) = expr.as_lambda().unwrap(); + assert_eq(params.len(), 1); + assert(params[0].1.unwrap().is_field()); + assert(return_type.unwrap().is_field()); + assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (2, false)); + + let expr = quote { |x| { 1 } }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (params, return_type, body) = expr.as_lambda().unwrap(); + assert_eq(params.len(), 1); + assert(params[0].1.is_none()); + assert(return_type.is_none()); + assert_eq(body.as_block().unwrap()[0].as_block().unwrap()[0].as_integer().unwrap(), (2, false)); + } + } + #[test] fn test_expr_as_let() { comptime