Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 150 additions & 9 deletions tooling/lsp/src/requests/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use noirc_frontend::{
ast::{
AsTraitPath, AttributeTarget, BlockExpression, CallExpression, ConstructorExpression,
Expression, ExpressionKind, ForLoopStatement, GenericTypeArgs, Ident, IfExpression,
ItemVisibility, Lambda, LetStatement, MemberAccessExpression, MethodCallExpression,
ItemVisibility, LValue, Lambda, LetStatement, MemberAccessExpression, MethodCallExpression,
NoirFunction, NoirStruct, NoirTraitImpl, Path, PathKind, Pattern, Statement,
TraitImplItemKind, TypeImpl, UnresolvedGeneric, UnresolvedGenerics, UnresolvedType,
UnresolvedTypeData, UseTree, UseTreeKind, Visitor,
Expand All @@ -29,7 +29,7 @@ use noirc_frontend::{
node_interner::ReferenceId,
parser::{Item, ItemKind, ParsedSubModule},
token::CustomAttribute,
ParsedModule, StructType, Type,
ParsedModule, StructType, Type, TypeBinding,
};
use sort_text::underscore_sort_text;

Expand Down Expand Up @@ -551,6 +551,7 @@ impl<'a> NodeFinder<'a> {
function_completion_kind: FunctionCompletionKind,
self_prefix: bool,
) {
let typ = &typ;
match typ {
Type::Struct(struct_type, generics) => {
self.complete_struct_fields(&struct_type.borrow(), generics, prefix, self_prefix);
Expand All @@ -575,6 +576,16 @@ impl<'a> NodeFinder<'a> {
Type::Tuple(types) => {
self.complete_tuple_fields(types, self_prefix);
}
Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => {
if let TypeBinding::Bound(typ) = &*var.borrow() {
self.complete_type_fields_and_methods(
typ,
prefix,
function_completion_kind,
self_prefix,
);
}
}
Type::FieldElement
| Type::Array(_, _)
| Type::Slice(_)
Expand All @@ -583,9 +594,7 @@ impl<'a> NodeFinder<'a> {
| Type::String(_)
| Type::FmtString(_, _)
| Type::Unit
| Type::TypeVariable(_, _)
| Type::TraitAsType(_, _, _)
| Type::NamedGeneric(_, _, _)
| Type::Function(..)
| Type::Forall(_, _)
| Type::Constant(_)
Expand Down Expand Up @@ -932,7 +941,8 @@ impl<'a> NodeFinder<'a> {
if let Some(ReferenceId::Local(definition_id)) =
self.interner.find_referenced(location)
{
self.self_type = Some(self.interner.definition_type(definition_id));
self.self_type =
Some(self.interner.definition_type(definition_id).follow_bindings());
}
}
}
Expand All @@ -941,6 +951,32 @@ impl<'a> NodeFinder<'a> {
}
}

fn get_lvalue_type(&self, lvalue: &LValue) -> Option<Type> {
match lvalue {
LValue::Ident(ident) => {
let location = Location::new(ident.span(), self.file);
if let Some(ReferenceId::Local(definition_id)) =
self.interner.find_referenced(location)
{
let typ = self.interner.definition_type(definition_id);
Some(typ)
} else {
None
}
}
LValue::MemberAccess { object, field_name, .. } => {
let typ = self.get_lvalue_type(object)?;
get_field_type(&typ, &field_name.0.contents)
}
LValue::Index { array, .. } => {
let typ = self.get_lvalue_type(array)?;
get_array_element_type(typ)
}
LValue::Dereference(lvalue, ..) => self.get_lvalue_type(lvalue),
LValue::Interned(..) => None,
}
}

fn includes_span(&self, span: Span) -> bool {
span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize
}
Expand Down Expand Up @@ -1153,7 +1189,6 @@ impl<'a> Visitor for NodeFinder<'a> {
if after_dot && call_expression.func.span.end() as usize == self.byte_index - 1 {
let location = Location::new(call_expression.func.span, self.file);
if let Some(typ) = self.interner.type_at_location(location) {
let typ = typ.follow_bindings();
let prefix = "";
let self_prefix = false;
self.complete_type_fields_and_methods(
Expand Down Expand Up @@ -1184,7 +1219,6 @@ impl<'a> Visitor for NodeFinder<'a> {
if self.includes_span(method_call_expression.method_name.span()) {
let location = Location::new(method_call_expression.object.span, self.file);
if let Some(typ) = self.interner.type_at_location(location) {
let typ = typ.follow_bindings();
let prefix = method_call_expression.method_name.to_string();
let offset =
self.byte_index - method_call_expression.method_name.span().start() as usize;
Expand Down Expand Up @@ -1258,6 +1292,7 @@ impl<'a> Visitor for NodeFinder<'a> {
}

fn visit_lvalue_ident(&mut self, ident: &Ident) {
// If we have `foo.>|<` we suggest `foo`'s type fields and methods
if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 {
let location = Location::new(ident.span(), self.file);
if let Some(ReferenceId::Local(definition_id)) = self.interner.find_referenced(location)
Expand All @@ -1275,6 +1310,72 @@ impl<'a> Visitor for NodeFinder<'a> {
}
}

fn visit_lvalue_member_access(
&mut self,
object: &LValue,
field_name: &Ident,
span: Span,
) -> bool {
// If we have `foo.bar.>|<` we solve the type of `foo`, get the field `bar`,
// then suggest methods of the resulting type.
if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 {
if let Some(typ) = self.get_lvalue_type(object) {
if let Some(typ) = get_field_type(&typ, &field_name.0.contents) {
let prefix = "";
let self_prefix = false;
self.complete_type_fields_and_methods(
&typ,
prefix,
FunctionCompletionKind::NameAndParameters,
self_prefix,
);
}
}

return false;
}
true
}

fn visit_lvalue_index(&mut self, array: &LValue, _index: &Expression, span: Span) -> bool {
// If we have `foo[index].>|<` we solve the type of `foo`, then get the array/slice element type,
// then suggest methods of that type.
if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 {
if let Some(typ) = self.get_lvalue_type(array) {
if let Some(typ) = get_array_element_type(typ) {
let prefix = "";
let self_prefix = false;
self.complete_type_fields_and_methods(
&typ,
prefix,
FunctionCompletionKind::NameAndParameters,
self_prefix,
);
}
}
return false;
}
true
}

fn visit_lvalue_dereference(&mut self, lvalue: &LValue, span: Span) -> bool {
if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 {
if let Some(typ) = self.get_lvalue_type(lvalue) {
let prefix = "";
let self_prefix = false;
self.complete_type_fields_and_methods(
&typ,
prefix,
FunctionCompletionKind::NameAndParameters,
self_prefix,
);
}
return false;
}

true
}

fn visit_variable(&mut self, path: &Path, _: Span) -> bool {
self.find_in_path(path, RequestedItems::AnyItems);
false
Expand All @@ -1294,7 +1395,6 @@ impl<'a> Visitor for NodeFinder<'a> {
{
let location = Location::new(expression.span, self.file);
if let Some(typ) = self.interner.type_at_location(location) {
let typ = typ.follow_bindings();
let prefix = "";
let self_prefix = false;
self.complete_type_fields_and_methods(
Expand Down Expand Up @@ -1364,7 +1464,6 @@ impl<'a> Visitor for NodeFinder<'a> {
// Assuming member_access_expression is of the form `foo.bar`, we are right after `bar`
let location = Location::new(member_access_expression.lhs.span, self.file);
if let Some(typ) = self.interner.type_at_location(location) {
let typ = typ.follow_bindings();
let prefix = ident.to_string().to_case(Case::Snake);
let self_prefix = false;
self.complete_type_fields_and_methods(
Expand Down Expand Up @@ -1443,6 +1542,48 @@ impl<'a> Visitor for NodeFinder<'a> {
}
}

fn get_field_type(typ: &Type, name: &str) -> Option<Type> {
match typ {
Type::Struct(struct_type, generics) => {
Some(struct_type.borrow().get_field(name, generics)?.0)
}
Type::Tuple(types) => {
if let Ok(index) = name.parse::<i32>() {
types.get(index as usize).cloned()
} else {
None
}
}
Type::Alias(alias_type, generics) => Some(alias_type.borrow().get_type(generics)),
Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => {
if let TypeBinding::Bound(typ) = &*var.borrow() {
get_field_type(typ, name)
} else {
None
}
}
_ => None,
}
}

fn get_array_element_type(typ: Type) -> Option<Type> {
match typ {
Type::Array(_, typ) | Type::Slice(typ) => Some(*typ),
Type::Alias(alias_type, generics) => {
let typ = alias_type.borrow().get_type(&generics);
get_array_element_type(typ)
}
Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => {
if let TypeBinding::Bound(typ) = &*var.borrow() {
get_array_element_type(typ.clone())
} else {
None
}
}
_ => None,
}
}

/// Returns true if name matches a prefix written in code.
/// `prefix` must already be in snake case.
/// This method splits both name and prefix by underscore,
Expand Down
104 changes: 104 additions & 0 deletions tooling/lsp/src/requests/completion/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2015,4 +2015,108 @@ mod completion_tests {
)
.await;
}

#[test]
async fn test_suggests_when_assignment_follows_in_chain_1() {
let src = r#"
struct Foo {
bar: Bar
}

struct Bar {
baz: Field
}

fn f(foo: Foo) {
let mut x = 1;

foo.bar.>|<

x = 2;
}"#;

assert_completion(src, vec![field_completion_item("baz", "Field")]).await;
}

#[test]
async fn test_suggests_when_assignment_follows_in_chain_2() {
let src = r#"
struct Foo {
bar: Bar
}

struct Bar {
baz: Baz
}

struct Baz {
qux: Field
}

fn f(foo: Foo) {
let mut x = 1;

foo.bar.baz.>|<

x = 2;
}"#;

assert_completion(src, vec![field_completion_item("qux", "Field")]).await;
}

#[test]
async fn test_suggests_when_assignment_follows_in_chain_3() {
let src = r#"
struct Foo {
foo: Field
}

fn execute() {
let a = Foo { foo: 1 };
a.>|<

x = 1;
}"#;

assert_completion(src, vec![field_completion_item("foo", "Field")]).await;
}

#[test]
async fn test_suggests_when_assignment_follows_in_chain_4() {
let src = r#"
struct Foo {
bar: Bar
}

struct Bar {
baz: Field
}

fn execute() {
let foo = Foo { foo: 1 };
foo.bar.>|<

x = 1;
}"#;

assert_completion(src, vec![field_completion_item("baz", "Field")]).await;
}

#[test]
async fn test_suggests_when_assignment_follows_in_chain_with_index() {
let src = r#"
struct Foo {
bar: Field
}

fn f(foos: [Foo; 3]) {
let mut x = 1;

foos[0].>|<

x = 2;
}"#;

assert_completion(src, vec![field_completion_item("bar", "Field")]).await;
}
}