Skip to content

Commit 742b8d0

Browse files
committed
feat!: type-check trait default methods
1 parent 011fbc1 commit 742b8d0

File tree

12 files changed

+233
-73
lines changed

12 files changed

+233
-73
lines changed

compiler/noirc_frontend/src/ast/expression.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -821,8 +821,8 @@ impl FunctionDefinition {
821821
is_unconstrained: bool,
822822
generics: &UnresolvedGenerics,
823823
parameters: &[(Ident, UnresolvedType)],
824-
body: &BlockExpression,
825-
where_clause: &[UnresolvedTraitConstraint],
824+
body: BlockExpression,
825+
where_clause: Vec<UnresolvedTraitConstraint>,
826826
return_type: &FunctionReturnType,
827827
) -> FunctionDefinition {
828828
let p = parameters
@@ -843,9 +843,9 @@ impl FunctionDefinition {
843843
visibility: ItemVisibility::Private,
844844
generics: generics.clone(),
845845
parameters: p,
846-
body: body.clone(),
846+
body,
847847
span: name.span(),
848-
where_clause: where_clause.to_vec(),
848+
where_clause,
849849
return_type: return_type.clone(),
850850
return_visibility: Visibility::Private,
851851
}

compiler/noirc_frontend/src/elaborator/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ impl<'context> Elaborator<'context> {
328328
self.elaborate_functions(functions);
329329
}
330330

331+
for (trait_id, unresolved_trait) in items.traits {
332+
self.current_trait = Some(trait_id);
333+
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
334+
}
335+
self.current_trait = None;
336+
331337
for impls in items.impls.into_values() {
332338
self.elaborate_impls(impls);
333339
}

compiler/noirc_frontend/src/elaborator/patterns.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ impl<'context> Elaborator<'context> {
856856

857857
let impl_kind = match method {
858858
HirMethodReference::FuncId(_) => ImplKind::NotATraitMethod,
859-
HirMethodReference::TraitMethodId(method_id, generics) => {
859+
HirMethodReference::TraitMethodId(method_id, generics, _) => {
860860
let mut constraint =
861861
self.interner.get_trait(method_id.trait_id).as_constraint(span);
862862
constraint.trait_bound.trait_generics = generics;

compiler/noirc_frontend/src/elaborator/traits.rs

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ impl<'context> Elaborator<'context> {
2828
self.recover_generics(|this| {
2929
this.current_trait = Some(*trait_id);
3030

31+
let the_trait = this.interner.get_trait(*trait_id);
32+
let self_typevar = the_trait.self_type_typevar.clone();
33+
let self_type = Type::TypeVariable(self_typevar.clone());
34+
this.self_type = Some(self_type.clone());
35+
3136
let resolved_generics = this.interner.get_trait(*trait_id).generics.clone();
3237
this.add_existing_generics(
3338
&unresolved_trait.trait_def.generics,
@@ -48,12 +53,15 @@ impl<'context> Elaborator<'context> {
4853
.add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id);
4954
}
5055

56+
this.interner.update_trait(*trait_id, |trait_def| {
57+
trait_def.set_trait_bounds(resolved_trait_bounds);
58+
trait_def.set_where_clause(where_clause);
59+
});
60+
5161
let methods = this.resolve_trait_methods(*trait_id, unresolved_trait);
5262

5363
this.interner.update_trait(*trait_id, |trait_def| {
5464
trait_def.set_methods(methods);
55-
trait_def.set_trait_bounds(resolved_trait_bounds);
56-
trait_def.set_where_clause(where_clause);
5765
});
5866
});
5967

@@ -94,7 +102,7 @@ impl<'context> Elaborator<'context> {
94102
parameters,
95103
return_type,
96104
where_clause,
97-
body: _,
105+
body,
98106
is_unconstrained,
99107
visibility: _,
100108
is_comptime: _,
@@ -103,7 +111,6 @@ impl<'context> Elaborator<'context> {
103111
self.recover_generics(|this| {
104112
let the_trait = this.interner.get_trait(trait_id);
105113
let self_typevar = the_trait.self_type_typevar.clone();
106-
let self_type = Type::TypeVariable(self_typevar.clone());
107114
let name_span = the_trait.name.span();
108115

109116
this.add_existing_generic(
@@ -115,9 +122,12 @@ impl<'context> Elaborator<'context> {
115122
span: name_span,
116123
},
117124
);
118-
this.self_type = Some(self_type.clone());
119125

120126
let func_id = unresolved_trait.method_ids[&name.0.contents];
127+
let mut where_clause = where_clause.to_vec();
128+
129+
// Attach any trait constraints on the trait to the function
130+
where_clause.extend(unresolved_trait.trait_def.where_clause.clone());
121131

122132
this.resolve_trait_function(
123133
trait_id,
@@ -127,6 +137,7 @@ impl<'context> Elaborator<'context> {
127137
parameters,
128138
return_type,
129139
where_clause,
140+
body,
130141
func_id,
131142
);
132143

@@ -188,20 +199,22 @@ impl<'context> Elaborator<'context> {
188199
generics: &UnresolvedGenerics,
189200
parameters: &[(Ident, UnresolvedType)],
190201
return_type: &FunctionReturnType,
191-
where_clause: &[UnresolvedTraitConstraint],
202+
where_clause: Vec<UnresolvedTraitConstraint>,
203+
body: &Option<BlockExpression>,
192204
func_id: FuncId,
193205
) {
194-
let old_generic_count = self.generics.len();
195-
196-
self.scopes.start_function();
206+
let body = match body {
207+
Some(body) => body.clone(),
208+
None => BlockExpression { statements: Vec::new() },
209+
};
197210

198211
let kind = FunctionKind::Normal;
199212
let mut def = FunctionDefinition::normal(
200213
name,
201214
is_unconstrained,
202215
generics,
203216
parameters,
204-
&BlockExpression { statements: Vec::new() },
217+
body,
205218
where_clause,
206219
return_type,
207220
);
@@ -210,10 +223,6 @@ impl<'context> Elaborator<'context> {
210223

211224
let mut function = NoirFunction { kind, def };
212225
self.define_function_meta(&mut function, func_id, Some(trait_id));
213-
self.elaborate_function(func_id);
214-
let _ = self.scopes.end_function();
215-
// Don't check the scope tree for unused variables, they can't be used in a declaration anyway.
216-
self.generics.truncate(old_generic_count);
217226
}
218227
}
219228

compiler/noirc_frontend/src/elaborator/types.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,12 +566,17 @@ impl<'context> Elaborator<'context> {
566566
}
567567

568568
// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
569+
// or inside a trait default method.
569570
//
570571
// Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not
571572
// E.g. `t.method()` with `where T: Foo<Bar>` in scope will return `(Foo::method, T, vec![Bar])`
572573
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<TraitPathResolution> {
573-
let trait_impl = self.current_trait_impl?;
574-
let trait_id = self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id;
574+
let trait_id = if let Some(current_trait) = self.current_trait {
575+
current_trait
576+
} else {
577+
let trait_impl = self.current_trait_impl?;
578+
self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id
579+
};
575580

576581
if path.kind == PathKind::Plain && path.segments.len() == 2 {
577582
let name = &path.segments[0].ident.0.contents;
@@ -1395,6 +1400,25 @@ impl<'context> Elaborator<'context> {
13951400
};
13961401
let func_meta = self.interner.function_meta(&func_id);
13971402

1403+
// If inside a trait method, check if it's a method on `self`
1404+
if let Some(trait_id) = func_meta.trait_id {
1405+
if Some(object_type) == self.self_type.as_ref() {
1406+
let the_trait = self.interner.get_trait(trait_id);
1407+
let constraint = the_trait.as_constraint(the_trait.name.span());
1408+
if let Some(HirMethodReference::TraitMethodId(method_id, generics, _)) = self
1409+
.lookup_method_in_trait(
1410+
the_trait,
1411+
method_name,
1412+
&constraint.trait_bound,
1413+
the_trait.id,
1414+
)
1415+
{
1416+
// If it is, it's an assumed trait
1417+
return Some(HirMethodReference::TraitMethodId(method_id, generics, true));
1418+
}
1419+
}
1420+
}
1421+
13981422
for constraint in &func_meta.trait_constraints {
13991423
if *object_type == constraint.typ {
14001424
if let Some(the_trait) =
@@ -1432,6 +1456,7 @@ impl<'context> Elaborator<'context> {
14321456
return Some(HirMethodReference::TraitMethodId(
14331457
trait_method,
14341458
trait_bound.trait_generics.clone(),
1459+
false,
14351460
));
14361461
}
14371462

compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,8 +518,8 @@ impl<'a> ModCollector<'a> {
518518
*is_unconstrained,
519519
generics,
520520
parameters,
521-
body,
522-
where_clause,
521+
body.clone(),
522+
where_clause.clone(),
523523
return_type,
524524
));
525525
unresolved_functions.push_fn(

compiler/noirc_frontend/src/hir_def/expr.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,14 @@ pub enum HirMethodReference {
209209
/// Or a method can come from a Trait impl block, in which case
210210
/// the actual function called will depend on the instantiated type,
211211
/// which can be only known during monomorphization.
212-
TraitMethodId(TraitMethodId, TraitGenerics),
212+
TraitMethodId(TraitMethodId, TraitGenerics, bool /* assumed */),
213213
}
214214

215215
impl HirMethodReference {
216216
pub fn func_id(&self, interner: &NodeInterner) -> Option<FuncId> {
217217
match self {
218218
HirMethodReference::FuncId(func_id) => Some(*func_id),
219-
HirMethodReference::TraitMethodId(method_id, _) => {
219+
HirMethodReference::TraitMethodId(method_id, _, _) => {
220220
let id = interner.trait_method_id(*method_id);
221221
match &interner.try_definition(id)?.kind {
222222
DefinitionKind::Function(func_id) => Some(*func_id),
@@ -246,7 +246,7 @@ impl HirMethodCallExpression {
246246
HirMethodReference::FuncId(func_id) => {
247247
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
248248
}
249-
HirMethodReference::TraitMethodId(method_id, trait_generics) => {
249+
HirMethodReference::TraitMethodId(method_id, trait_generics, assumed) => {
250250
let id = interner.trait_method_id(method_id);
251251
let constraint = TraitConstraint {
252252
typ: object_type,
@@ -256,7 +256,8 @@ impl HirMethodCallExpression {
256256
span: location.span,
257257
},
258258
};
259-
(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false }))
259+
260+
(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed }))
260261
}
261262
};
262263
let func_var = HirIdent { location, id, impl_kind };

compiler/noirc_frontend/src/hir_def/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,12 @@ pub enum FunctionBody {
175175

176176
impl FuncMeta {
177177
/// A stub function does not have a body. This includes Builtin, LowLevel,
178-
/// and Oracle functions in addition to method declarations within a trait.
178+
/// and Oracle functions.
179179
///
180180
/// We don't check the return type of these functions since it will always have
181181
/// an empty body, and we don't check for unused parameters.
182182
pub fn is_stub(&self) -> bool {
183-
self.kind.can_ignore_return_type() || self.trait_id.is_some()
183+
self.kind.can_ignore_return_type()
184184
}
185185

186186
pub fn function_signature(&self) -> FunctionSignature {

compiler/noirc_frontend/src/node_interner.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use std::borrow::Cow;
22
use std::fmt;
33
use std::hash::Hash;
44
use std::marker::Copy;
5-
use std::ops::Deref;
65

76
use fm::FileId;
87
use iter_extended::vecmap;
@@ -1478,25 +1477,6 @@ impl NodeInterner {
14781477
Ok(impl_kind)
14791478
}
14801479

1481-
/// Given a `ObjectType: TraitId` pair, find all implementations without taking constraints into account or
1482-
/// applying any type bindings. Useful to look for a specific trait in a type that is used in a macro.
1483-
pub fn lookup_all_trait_implementations(
1484-
&self,
1485-
object_type: &Type,
1486-
trait_id: TraitId,
1487-
) -> Vec<&TraitImplKind> {
1488-
let trait_impl = self.trait_implementation_map.get(&trait_id);
1489-
1490-
let trait_impl = trait_impl.map(|trait_impl| {
1491-
let impls = trait_impl.iter().filter_map(|(typ, impl_kind)| match &typ {
1492-
Type::Forall(_, typ) => (typ.deref() == object_type).then_some(impl_kind),
1493-
_ => None,
1494-
});
1495-
impls.collect()
1496-
});
1497-
trait_impl.unwrap_or_default()
1498-
}
1499-
15001480
/// Similar to `lookup_trait_implementation` but does not apply any type bindings on success.
15011481
/// On error returns either:
15021482
/// - 1+ failing trait constraints, including the original.

compiler/noirc_frontend/src/tests.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2942,7 +2942,7 @@ fn uses_self_type_inside_trait() {
29422942
fn uses_self_type_in_trait_where_clause() {
29432943
let src = r#"
29442944
pub trait Trait {
2945-
fn trait_func() -> bool;
2945+
fn trait_func(self) -> bool;
29462946
}
29472947
29482948
pub trait Foo where Self: Trait {
@@ -2963,6 +2963,7 @@ fn uses_self_type_in_trait_where_clause() {
29632963
"#;
29642964

29652965
let errors = get_program_errors(src);
2966+
dbg!(&errors);
29662967
assert_eq!(errors.len(), 2);
29672968

29682969
let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0

0 commit comments

Comments
 (0)