Skip to content

Commit 0c6c637

Browse files
jfecherasterite
andauthored
feat(experimental): Implement enum tag constants (#7183)
Co-authored-by: Ary Borenszweig <asterite@gmail.com>
1 parent 5841122 commit 0c6c637

11 files changed

Lines changed: 144 additions & 40 deletions

File tree

compiler/noirc_frontend/src/ast/enumeration.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ impl NoirEnumeration {
3030
#[derive(Clone, Debug, PartialEq, Eq)]
3131
pub struct EnumVariant {
3232
pub name: Ident,
33-
pub parameters: Vec<UnresolvedType>,
33+
34+
/// This is None for tag variants without parameters.
35+
/// A value of `Some(vec![])` corresponds to a variant defined as `Foo()`
36+
/// with parenthesis but no parameters.
37+
pub parameters: Option<Vec<UnresolvedType>>,
3438
}
3539

3640
impl Display for NoirEnumeration {
@@ -41,8 +45,12 @@ impl Display for NoirEnumeration {
4145
writeln!(f, "enum {}{} {{", self.name, generics)?;
4246

4347
for variant in self.variants.iter() {
44-
let parameters = vecmap(&variant.item.parameters, ToString::to_string).join(", ");
45-
writeln!(f, " {}({}),", variant.item.name, parameters)?;
48+
if let Some(parameters) = &variant.item.parameters {
49+
let parameters = vecmap(parameters, ToString::to_string).join(", ");
50+
writeln!(f, " {}({}),", variant.item.name, parameters)?;
51+
} else {
52+
writeln!(f, " {},", variant.item.name)?;
53+
}
4654
}
4755

4856
write!(f, "}}")

compiler/noirc_frontend/src/ast/visitor.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,8 +795,10 @@ impl NoirEnumeration {
795795
}
796796

797797
for variant in &self.variants {
798-
for parameter in &variant.item.parameters {
799-
parameter.accept(visitor);
798+
if let Some(parameters) = &variant.item.parameters {
799+
for parameter in parameters {
800+
parameter.accept(visitor);
801+
}
800802
}
801803
}
802804
}

compiler/noirc_frontend/src/elaborator/enums.rs

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,104 @@ use crate::{
88
function::{FuncMeta, FunctionBody, HirFunction, Parameters},
99
stmt::HirPattern,
1010
},
11-
node_interner::{DefinitionKind, FuncId, FunctionModifiers, TypeId},
11+
node_interner::{DefinitionKind, ExprId, FunctionModifiers, GlobalValue, TypeId},
1212
token::Attributes,
1313
DataType, Shared, Type,
1414
};
1515

1616
use super::Elaborator;
1717

1818
impl Elaborator<'_> {
19+
/// Defines the value of an enum variant that we resolve an enum
20+
/// variant expression to. E.g. `Foo::Bar` in `Foo::Bar(baz)`.
21+
///
22+
/// If the variant requires arguments we should define a function,
23+
/// otherwise we define a polymorphic global containing the tag value.
1924
#[allow(clippy::too_many_arguments)]
20-
pub(super) fn define_enum_variant_function(
25+
pub(super) fn define_enum_variant_constructor(
26+
&mut self,
27+
enum_: &NoirEnumeration,
28+
type_id: TypeId,
29+
variant: &EnumVariant,
30+
variant_arg_types: Option<Vec<Type>>,
31+
variant_index: usize,
32+
datatype: &Shared<DataType>,
33+
self_type: &Type,
34+
self_type_unresolved: UnresolvedType,
35+
) {
36+
match variant_arg_types {
37+
Some(args) => self.define_enum_variant_function(
38+
enum_,
39+
type_id,
40+
variant,
41+
args,
42+
variant_index,
43+
datatype,
44+
self_type,
45+
self_type_unresolved,
46+
),
47+
None => self.define_enum_variant_global(
48+
enum_,
49+
type_id,
50+
variant,
51+
variant_index,
52+
datatype,
53+
self_type,
54+
),
55+
}
56+
}
57+
58+
#[allow(clippy::too_many_arguments)]
59+
fn define_enum_variant_global(
60+
&mut self,
61+
enum_: &NoirEnumeration,
62+
type_id: TypeId,
63+
variant: &EnumVariant,
64+
variant_index: usize,
65+
datatype: &Shared<DataType>,
66+
self_type: &Type,
67+
) {
68+
let name = &variant.name;
69+
let location = Location::new(variant.name.span(), self.file);
70+
71+
let global_id = self.interner.push_empty_global(
72+
name.clone(),
73+
type_id.local_module_id(),
74+
type_id.krate(),
75+
self.file,
76+
Vec::new(),
77+
false,
78+
false,
79+
);
80+
81+
let mut typ = self_type.clone();
82+
if !datatype.borrow().generics.is_empty() {
83+
let typevars = vecmap(&datatype.borrow().generics, |generic| generic.type_var.clone());
84+
typ = Type::Forall(typevars, Box::new(typ));
85+
}
86+
87+
let definition_id = self.interner.get_global(global_id).definition_id;
88+
self.interner.push_definition_type(definition_id, typ.clone());
89+
90+
let no_parameters = Parameters(Vec::new());
91+
let global_body =
92+
self.make_enum_variant_constructor(datatype, variant_index, &no_parameters, location);
93+
let let_statement = crate::hir_def::stmt::HirStatement::Expression(global_body);
94+
95+
let statement_id = self.interner.get_global(global_id).let_statement;
96+
self.interner.replace_statement(statement_id, let_statement);
97+
98+
self.interner.get_global_mut(global_id).value = GlobalValue::Resolved(
99+
crate::hir::comptime::Value::Enum(variant_index, Vec::new(), typ),
100+
);
101+
102+
Self::get_module_mut(self.def_maps, type_id.module_id())
103+
.declare_global(name.clone(), enum_.visibility, global_id)
104+
.ok();
105+
}
106+
107+
#[allow(clippy::too_many_arguments)]
108+
fn define_enum_variant_function(
21109
&mut self,
22110
enum_: &NoirEnumeration,
23111
type_id: TypeId,
@@ -48,7 +136,10 @@ impl Elaborator<'_> {
48136

49137
let hir_name = HirIdent::non_trait_method(definition_id, location);
50138
let parameters = self.make_enum_variant_parameters(variant_arg_types, location);
51-
self.push_enum_variant_function_body(id, datatype, variant_index, &parameters, location);
139+
140+
let body =
141+
self.make_enum_variant_constructor(datatype, variant_index, &parameters, location);
142+
self.interner.update_fn(id, HirFunction::unchecked_from_expr(body));
52143

53144
let function_type =
54145
datatype_ref.variant_function_type_with_forall(variant_index, datatype.clone());
@@ -106,14 +197,13 @@ impl Elaborator<'_> {
106197
// }
107198
// }
108199
// ```
109-
fn push_enum_variant_function_body(
200+
fn make_enum_variant_constructor(
110201
&mut self,
111-
id: FuncId,
112202
self_type: &Shared<DataType>,
113203
variant_index: usize,
114204
parameters: &Parameters,
115205
location: Location,
116-
) {
206+
) -> ExprId {
117207
// Each parameter of the enum variant function is used as a parameter of the enum
118208
// constructor expression
119209
let arguments = vecmap(&parameters.0, |(pattern, typ, _)| match pattern {
@@ -126,18 +216,18 @@ impl Elaborator<'_> {
126216
_ => unreachable!(),
127217
});
128218

129-
let enum_generics = self_type.borrow().generic_types();
130-
let construct_variant = HirExpression::EnumConstructor(HirEnumConstructorExpression {
219+
let constructor = HirExpression::EnumConstructor(HirEnumConstructorExpression {
131220
r#type: self_type.clone(),
132221
arguments,
133222
variant_index,
134223
});
135-
let body = self.interner.push_expr(construct_variant);
136-
self.interner.update_fn(id, HirFunction::unchecked_from_expr(body));
137224

225+
let body = self.interner.push_expr(constructor);
226+
let enum_generics = self_type.borrow().generic_types();
138227
let typ = Type::DataType(self_type.clone(), enum_generics);
139228
self.interner.push_expr_type(body, typ);
140229
self.interner.push_expr_location(body, location.span, location.file);
230+
body
141231
}
142232

143233
fn make_enum_variant_parameters(

compiler/noirc_frontend/src/elaborator/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,16 +1841,16 @@ impl<'context> Elaborator<'context> {
18411841
let module_id = ModuleId { krate: self.crate_id, local_id: typ.module_id };
18421842

18431843
for (i, variant) in typ.enum_def.variants.iter().enumerate() {
1844-
let types = vecmap(&variant.item.parameters, |typ| self.resolve_type(typ.clone()));
1844+
let parameters = variant.item.parameters.as_ref();
1845+
let types =
1846+
parameters.map(|params| vecmap(params, |typ| self.resolve_type(typ.clone())));
18451847
let name = variant.item.name.clone();
18461848

1847-
// false here is for the eventual change to allow enum "constants" rather than
1848-
// always having them be called as functions. This can be replaced with an actual
1849-
// check once #7172 is implemented.
1850-
datatype.borrow_mut().push_variant(EnumVariant::new(name, types.clone(), false));
1849+
let is_function = types.is_some();
1850+
let params = types.clone().unwrap_or_default();
1851+
datatype.borrow_mut().push_variant(EnumVariant::new(name, params, is_function));
18511852

1852-
// Define a function for each variant to construct it
1853-
self.define_enum_variant_function(
1853+
self.define_enum_variant_constructor(
18541854
&typ.enum_def,
18551855
*type_id,
18561856
&variant.item,

compiler/noirc_frontend/src/hir/comptime/interpreter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
12941294
id: ExprId,
12951295
) -> IResult<Value> {
12961296
let fields = try_vecmap(constructor.arguments, |arg| self.evaluate(arg))?;
1297-
let typ = self.elaborator.interner.id_type(id).follow_bindings();
1297+
let typ = self.elaborator.interner.id_type(id).unwrap_forall().1.follow_bindings();
12981298
Ok(Value::Enum(constructor.variant_index, fields, typ))
12991299
}
13001300

compiler/noirc_frontend/src/hir/comptime/value.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ impl Value {
405405
})
406406
}
407407
Value::Enum(variant_index, args, typ) => {
408-
let r#type = match typ.follow_bindings() {
408+
// Enum constants can have generic types but aren't functions
409+
let r#type = match typ.unwrap_forall().1.follow_bindings() {
409410
Type::DataType(def, _) => def,
410411
_ => return Err(InterpreterError::NonEnumInConstructor { typ, location }),
411412
};

compiler/noirc_frontend/src/monomorphization/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2209,7 +2209,7 @@ fn unwrap_enum_type(
22092209
typ: &HirType,
22102210
location: Location,
22112211
) -> Result<Vec<(String, Vec<HirType>)>, MonomorphizationError> {
2212-
match typ.follow_bindings() {
2212+
match typ.unwrap_forall().1.follow_bindings() {
22132213
HirType::DataType(def, args) => {
22142214
// Some of args might not be mentioned in fields, so we need to check that they aren't unbound.
22152215
for arg in &args {

compiler/noirc_frontend/src/parser/parser/enums.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,10 @@ impl<'a> Parser<'a> {
9292
self.bump();
9393
}
9494

95-
let mut parameters = Vec::new();
96-
97-
if self.eat_left_paren() {
95+
let parameters = self.eat_left_paren().then(|| {
9896
let comma_separated = separated_by_comma_until_right_paren();
99-
parameters = self.parse_many("variant parameters", comma_separated, Self::parse_type);
100-
}
97+
self.parse_many("variant parameters", comma_separated, Self::parse_type)
98+
});
10199

102100
Some(Documented::new(EnumVariant { name, parameters }, doc_comments))
103101
}
@@ -189,18 +187,19 @@ mod tests {
189187
let variant = noir_enum.variants.remove(0).item;
190188
assert_eq!("X", variant.name.to_string());
191189
assert!(matches!(
192-
variant.parameters[0].typ,
190+
variant.parameters.as_ref().unwrap()[0].typ,
193191
UnresolvedTypeData::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo)
194192
));
195193

196194
let variant = noir_enum.variants.remove(0).item;
197195
assert_eq!("y", variant.name.to_string());
198-
assert!(matches!(variant.parameters[0].typ, UnresolvedTypeData::FieldElement));
199-
assert!(matches!(variant.parameters[1].typ, UnresolvedTypeData::Integer(..)));
196+
let parameters = variant.parameters.as_ref().unwrap();
197+
assert!(matches!(parameters[0].typ, UnresolvedTypeData::FieldElement));
198+
assert!(matches!(parameters[1].typ, UnresolvedTypeData::Integer(..)));
200199

201200
let variant = noir_enum.variants.remove(0).item;
202201
assert_eq!("Z", variant.name.to_string());
203-
assert_eq!(variant.parameters.len(), 0);
202+
assert!(variant.parameters.is_none());
204203
}
205204

206205
#[test]

test_programs/compile_success_empty/comptime_enums/src/main.nr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ fn main() {
22
comptime {
33
let _two = Foo::Couple(1, 2);
44
let _one = Foo::One(3);
5-
let _none = Foo::None();
5+
let _none = Foo::None;
66
}
77
}
88

test_programs/compile_success_empty/enums/src/main.nr

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ fn main() {
33
let _b: Foo<u16> = Foo::B(3);
44
let _c = Foo::C(4);
55

6-
// (#7172): Single variant enums must be called as functions currently
76
let _d: fn() -> Foo<(i32, i32)> = Foo::D;
87
let _d: Foo<(i32, i32)> = Foo::D();
8+
let _e: Foo<u16> = Foo::E;
9+
let _e: Foo<u32> = Foo::E; // Ensure we can still use Foo::E polymorphically
910

1011
// Enum variants are functions and can be passed around as such
1112
let _many_cs = [1, 2, 3].map(Foo::C);
@@ -15,5 +16,6 @@ enum Foo<T> {
1516
A(Field, Field),
1617
B(u32),
1718
C(T),
18-
D,
19+
D(),
20+
E,
1921
}

0 commit comments

Comments
 (0)