diff --git a/apollo-federation/src/connectors/expand/mod.rs b/apollo-federation/src/connectors/expand/mod.rs index b312302b2d..79165d90d1 100644 --- a/apollo-federation/src/connectors/expand/mod.rs +++ b/apollo-federation/src/connectors/expand/mod.rs @@ -409,20 +409,20 @@ mod helpers { // without at least a root-level Query) let parent_pos = ObjectTypeDefinitionPosition { - type_name: parent_type.name.clone(), + type_name: parent_type.name().clone(), }; self.insert_object_and_field(&mut schema, &parent_pos, field_def)?; self.ensure_query_root_type( &mut schema, &query_alias, - Some(&parent_type.name), + Some(parent_type.name()), )?; if let Some(mutation_alias) = mutation_alias { self.ensure_mutation_root_type( &mut schema, &mutation_alias, - &parent_type.name, + parent_type.name(), )?; } @@ -430,11 +430,11 @@ mod helpers { self.process_outputs( &mut schema, connector, - parent_type.name.clone(), + parent_type.name().clone(), field_def.ty.inner_named_type().clone(), )?; } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { SchemaVisitor::new( self.original_schema, &mut schema, @@ -442,7 +442,7 @@ mod helpers { ) .walk(( ObjectTypeDefinitionPosition { - type_name: type_def.name.clone(), + type_name: type_ref.name().clone(), }, connector .selection @@ -460,8 +460,8 @@ mod helpers { self.process_outputs( &mut schema, connector, - type_def.name.clone(), - type_def.name.clone(), + type_ref.name().clone(), + type_ref.name().clone(), )?; } } diff --git a/apollo-federation/src/connectors/id.rs b/apollo-federation/src/connectors/id.rs index 169039618e..daf7e92461 100644 --- a/apollo-federation/src/connectors/id.rs +++ b/apollo-federation/src/connectors/id.rs @@ -4,14 +4,12 @@ use std::fmt::Formatter; use std::hash::Hash; use apollo_compiler::Name; -use apollo_compiler::Node; use apollo_compiler::Schema; use apollo_compiler::ast::FieldDefinition; use apollo_compiler::ast::NamedType; use apollo_compiler::schema::Component; -use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::error::FederationError; use crate::schema::position::ObjectOrInterfaceFieldDirectivePosition; @@ -37,16 +35,7 @@ impl ConnectorPosition { ) -> Result, FederationError> { match self { Self::Field(pos) => Ok(ConnectedElement::Field { - parent_type: schema - .types - .get(pos.field.parent().type_name()) - .and_then(|ty| { - if let ExtendedType::Object(obj) = ty { - Some(obj) - } else { - None - } - }) + parent_type: SchemaTypeRef::new(schema, pos.field.parent().type_name()) .ok_or_else(|| { FederationError::internal("Parent type for connector not found") })?, @@ -62,16 +51,7 @@ impl ConnectorPosition { }, }), Self::Type(pos) => Ok(ConnectedElement::Type { - type_def: schema - .types - .get(&pos.type_name) - .and_then(|ty| { - if let ExtendedType::Object(obj) = ty { - Some(obj) - } else { - None - } - }) + type_ref: SchemaTypeRef::new(schema, &pos.type_name) .ok_or_else(|| FederationError::internal("Type for connector not found"))?, }), } @@ -173,12 +153,12 @@ impl ConnectorPosition { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ConnectedElement<'schema> { Field { - parent_type: &'schema Node, + parent_type: SchemaTypeRef<'schema>, field_def: &'schema Component, parent_category: ObjectCategory, }, Type { - type_def: &'schema Node, + type_ref: SchemaTypeRef<'schema>, }, } @@ -186,7 +166,7 @@ impl ConnectedElement<'_> { pub(super) fn base_type_name(&self) -> NamedType { match self { ConnectedElement::Field { field_def, .. } => field_def.ty.inner_named_type().clone(), - ConnectedElement::Type { type_def } => type_def.name.clone(), + ConnectedElement::Type { type_ref } => type_ref.name().clone(), } } @@ -201,7 +181,7 @@ impl ConnectedElement<'_> { .as_ref() .is_some_and(|query| match self { ConnectedElement::Field { .. } => false, - ConnectedElement::Type { type_def } => type_def.name == query.name, + ConnectedElement::Type { type_ref } => type_ref.name() == query.name.as_str(), }) } @@ -212,7 +192,7 @@ impl ConnectedElement<'_> { .as_ref() .is_some_and(|mutation| match self { ConnectedElement::Field { .. } => false, - ConnectedElement::Type { type_def } => type_def.name == mutation.name, + ConnectedElement::Type { type_ref } => type_ref.name() == mutation.name.as_str(), }) } } @@ -231,8 +211,8 @@ impl Display for ConnectedElement<'_> { parent_type, field_def, .. - } => write!(f, "{}.{}", parent_type.name, field_def.name), - Self::Type { type_def } => write!(f, "{}", type_def.name), + } => write!(f, "{}.{}", parent_type.name(), field_def.name), + Self::Type { type_ref } => write!(f, "{}", type_ref.name()), } } } diff --git a/apollo-federation/src/connectors/mod.rs b/apollo-federation/src/connectors/mod.rs index f23ebeb8fd..527590811b 100644 --- a/apollo-federation/src/connectors/mod.rs +++ b/apollo-federation/src/connectors/mod.rs @@ -30,6 +30,7 @@ pub mod header; mod id; mod json_selection; mod models; +mod schema_type_ref; pub use models::ProblemLocation; pub mod runtime; pub(crate) mod spec; diff --git a/apollo-federation/src/connectors/schema_type_ref.rs b/apollo-federation/src/connectors/schema_type_ref.rs new file mode 100644 index 0000000000..21a4acaee3 --- /dev/null +++ b/apollo-federation/src/connectors/schema_type_ref.rs @@ -0,0 +1,272 @@ +use apollo_compiler::Name; +use apollo_compiler::Node; +use apollo_compiler::Schema; +use apollo_compiler::ast::FieldDefinition; +use apollo_compiler::ast::Type; +use apollo_compiler::collections::IndexMap; +use apollo_compiler::collections::IndexSet; +use apollo_compiler::schema::Component; +use apollo_compiler::schema::ExtendedType; +use apollo_compiler::schema::ObjectType; +use shape::Shape; + +/// A [`SchemaTypeRef`] is a `Copy`able reference to a named type within a +/// [`Schema`]. Because [`SchemaTypeRef`] holds a `&'schema Schema` reference to +/// the schema in question, it can perform operations like finding all the +/// concrete types of an interface or union, which requires full-schema +/// awareness. Other reference-like types, such as [`ExtendedType`], only +/// provide access to a single element, not the rest of the schema. In fact, as +/// you can get an [`&ExtendedType`] by calling [`SchemaTypeRef::extended`], you +/// can pretty much always safely use a [`SchemaTypeRef`] where you would have +/// previously used an [`ExtendedType`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct SchemaTypeRef<'schema>(&'schema Schema, &'schema Name, &'schema ExtendedType); + +impl<'schema> SchemaTypeRef<'schema> { + pub(super) fn new(schema: &'schema Schema, name: &str) -> Option { + schema + .types + .get_full(name) + .map(|(_index, name, extended)| Self(schema, name, extended)) + } + + #[allow(dead_code)] + pub(super) fn from_node( + schema: &'schema Schema, + node: &'schema Node, + ) -> Option { + SchemaTypeRef::new(schema, node.name.as_str()) + } + + pub(super) fn as_object_node(&self) -> Option<&'schema Node> { + if let ExtendedType::Object(obj) = self.2 { + Some(obj) + } else { + None + } + } + + fn shape_with_visited(&self, visited: &mut IndexSet) -> Shape { + let type_name = self.name().to_string(); + if visited.contains(&type_name) { + return Shape::name(&type_name, []); + } + visited.insert(type_name.clone()); + + let result = match self.extended() { + ExtendedType::Object(o) => { + // Check if we're being called from an abstract type (interface or union) + let from_abstract_parent = visited + .iter() + .last() + .and_then(|parent_name| SchemaTypeRef::new(self.0, parent_name)) + .map(|parent_ref| parent_ref.is_abstract()) + .unwrap_or(false); + + // Generate __typename field based on context + let typename_shape = if from_abstract_parent { + // Required typename when accessed via interface/union + Shape::string_value(self.name().as_str(), []) + } else { + // Optional typename when accessed directly + Shape::one( + [Shape::string_value(self.name().as_str(), []), Shape::none()], + [], + ) + }; + + // Build fields map with __typename first + let mut fields = Shape::empty_map(); + fields.insert("__typename".to_string(), typename_shape); + + // Add all the object's declared fields + for (name, field) in &o.fields { + fields.insert( + name.to_string(), + self.shape_from_type_with_visited(&field.ty, visited), + ); + } + + Shape::record(fields, []) + } + ExtendedType::Scalar(_) => Shape::unknown([]), + + ExtendedType::Enum(e) => { + // Enums are unions of their string values + Shape::one( + e.values + .keys() + .map(|value| Shape::string_value(value.as_str(), [])), + [], + ) + } + ExtendedType::Interface(i) => Shape::one( + self.0.types.values().filter_map(|extended_type| { + if let ExtendedType::Object(object_type) = extended_type { + if object_type.implements_interfaces.contains(&i.name) { + SchemaTypeRef::new(self.0, object_type.name.as_str()) + .map(|type_ref| type_ref.shape_with_visited(visited)) + } else { + None + } + } else { + None + } + }), + [], + ), + ExtendedType::Union(u) => Shape::one( + u.members.iter().filter_map(|member_name| { + SchemaTypeRef::new(self.0, member_name.as_str()) + .map(|type_ref| type_ref.shape_with_visited(visited)) + }), + [], + ), + ExtendedType::InputObject(i) => Shape::record( + i.fields + .iter() + .map(|(name, field)| { + ( + name.to_string(), + self.shape_from_type_with_visited(&field.ty, visited), + ) + }) + .collect(), + [], + ), + }; + + visited.swap_remove(&type_name); + result + } + + /// Helper to make a shape nullable (can be null) + fn nullable(&self, shape: Shape) -> Shape { + Shape::one([shape, Shape::null([])], []) + } + + #[expect(dead_code)] + pub(super) fn shape_from_type(&self, ty: &Type) -> Shape { + self.shape_from_type_with_visited(ty, &mut IndexSet::default()) + } + + fn shape_from_type_with_visited(&self, ty: &Type, visited: &mut IndexSet) -> Shape { + let inner_type_name = ty.inner_named_type(); + let base_shape = if visited.contains(inner_type_name.as_str()) { + // Avoid infinite recursion for circular references + Shape::name(inner_type_name.as_str(), []) + } else if let Some(named_type) = SchemaTypeRef::new(self.0, inner_type_name.as_str()) { + named_type.shape_with_visited(visited) + } else { + Shape::name(inner_type_name.as_str(), []) + }; + + match ty { + Type::Named(_) => self.nullable(base_shape), + Type::NonNullNamed(_) => base_shape, + Type::List(inner) => self.nullable(Shape::list( + self.shape_from_type_with_visited(inner, visited), + [], + )), + Type::NonNullList(inner) => { + Shape::list(self.shape_from_type_with_visited(inner, visited), []) + } + } + } + + #[expect(dead_code)] + pub(super) fn schema(&self) -> &'schema Schema { + self.0 + } + + pub(super) fn name(&self) -> &'schema Name { + self.1 + } + + pub(super) fn extended(&self) -> &'schema ExtendedType { + self.2 + } + + pub(super) fn is_object(&self) -> bool { + self.2.is_object() + } + + pub(super) fn is_interface(&self) -> bool { + self.2.is_interface() + } + + pub(super) fn is_union(&self) -> bool { + self.2.is_union() + } + + pub(super) fn is_abstract(&self) -> bool { + self.is_interface() || self.is_union() + } + + #[expect(dead_code)] + pub(super) fn is_input_object(&self) -> bool { + self.2.is_input_object() + } + + #[expect(dead_code)] + pub(super) fn is_enum(&self) -> bool { + self.2.is_enum() + } + + #[expect(dead_code)] + pub(super) fn is_scalar(&self) -> bool { + self.2.is_scalar() + } + + #[expect(dead_code)] + pub(super) fn is_built_in(&self) -> bool { + self.2.is_built_in() + } + + pub(super) fn get_fields( + &self, + field_name: &str, + ) -> IndexMap> { + self.0 + .types + .get(self.1) + .into_iter() + .flat_map(|ty| match ty { + ExtendedType::Object(o) => o + .fields + .get(field_name) + .map(|field_def| { + std::iter::once((o.name.to_string(), field_def)).collect::>() + }) + .unwrap_or_default(), + + ExtendedType::Interface(i) => self + .0 + .implementers_map() + .get(i.name.as_str()) + .map(|implementers| { + implementers + .objects + .iter() + .chain(&implementers.interfaces) + .filter_map(|name| SchemaTypeRef::new(self.0, name.as_str())) + .flat_map(|type_ref| type_ref.get_fields(field_name)) + .collect::>() + }) + .unwrap_or_default(), + + ExtendedType::Union(u) => u + .members + .iter() + .flat_map(|m| { + SchemaTypeRef::new(self.0, m.name.as_str()) + .map(|type_ref| type_ref.get_fields(field_name)) + .unwrap_or_default() + }) + .collect(), + + _ => IndexMap::default(), + }) + .collect() + } +} diff --git a/apollo-federation/src/connectors/validation/connect.rs b/apollo-federation/src/connectors/validation/connect.rs index 836105188f..f8dee99a33 100644 --- a/apollo-federation/src/connectors/validation/connect.rs +++ b/apollo-federation/src/connectors/validation/connect.rs @@ -9,7 +9,6 @@ use apollo_compiler::Node; use apollo_compiler::ast::Value; use apollo_compiler::parser::LineColumn; use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; use hashbrown::HashSet; use itertools::Itertools; use multi_try::MultiTry; @@ -25,6 +24,7 @@ use crate::connectors::Namespace; use crate::connectors::SourceName; use crate::connectors::id::ConnectedElement; use crate::connectors::id::ObjectCategory; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::spec::connect::CONNECT_ID_ARGUMENT_NAME; use crate::connectors::spec::connect::CONNECT_SOURCE_ARGUMENT_NAME; use crate::connectors::spec::source::SOURCE_NAME_ARGUMENT_NAME; @@ -40,18 +40,21 @@ pub(super) fn fields_seen_by_all_connects( schema: &SchemaInfo, all_source_names: &[SourceName], ) -> Result, Vec> { - let mut messages = Vec::new(); - let mut connects = Vec::new(); - - for extended_type in schema.types.values().filter(|ty| !ty.is_built_in()) { - let ExtendedType::Object(node) = extended_type else { - continue; - }; - let (connects_for_type, messages_for_type) = - Connect::find_on_type(node, schema, all_source_names); - connects.extend(connects_for_type); - messages.extend(messages_for_type); - } + let (connects, messages): (Vec<_>, Vec<_>) = schema + .types + .iter() + .filter(|(_, ty)| !ty.is_built_in()) + .filter(|(_, ty)| { + matches!( + ty, + ExtendedType::Object(_) | ExtendedType::Interface(_) | ExtendedType::Union(_) + ) + }) + .filter_map(|(type_name, _)| SchemaTypeRef::new(schema.schema, type_name)) + .map(|type_ref| Connect::find_on_type(type_ref, schema, all_source_names)) + .unzip(); + let connects: Vec<_> = connects.into_iter().flatten().collect(); + let mut messages: Vec<_> = messages.into_iter().flatten().collect(); let mut seen_fields = Vec::new(); let mut valid_id_names: HashMap<_, Vec<_>> = HashMap::new(); @@ -136,7 +139,7 @@ struct Connect<'schema> { impl<'schema> Connect<'schema> { /// Find and parse any `@connect` directives on this type or its fields. fn find_on_type( - object: &'schema Node, + type_ref: SchemaTypeRef<'schema>, schema: &'schema SchemaInfo, source_names: &'schema [SourceName], ) -> (Vec, Vec) { @@ -144,43 +147,76 @@ impl<'schema> Connect<'schema> { .schema_definition .query .as_ref() - .is_some_and(|query| query.name == object.name) + .is_some_and(|query| query.name == *type_ref.name()) { ObjectCategory::Query } else if schema .schema_definition .mutation .as_ref() - .is_some_and(|mutation| mutation.name == object.name) + .is_some_and(|mutation| mutation.name == *type_ref.name()) { ObjectCategory::Mutation } else { ObjectCategory::Other }; - let directives_on_type = object - .directives - .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) - .map(|directive| ConnectDirectiveCoordinate { - directive, - element: ConnectedElement::Type { type_def: object }, - }); + let directives_on_type = match type_ref.extended() { + ExtendedType::Object(obj) => obj.directives.iter(), + ExtendedType::Interface(iface) => iface.directives.iter(), + ExtendedType::Union(union) => union.directives.iter(), + _ => [].iter(), // Other types (scalars, enums, etc.) don't support connectors + } + .filter(|directive| &directive.name == schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Type { type_ref }, + }); - let directives_on_fields = object.fields.values().flat_map(|field| { - field - .directives - .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) - .map(|directive| ConnectDirectiveCoordinate { - directive, - element: ConnectedElement::Field { - parent_type: object, - parent_category: object_category, - field_def: field, - }, + let directives_on_fields = match type_ref.extended() { + ExtendedType::Object(obj) => obj + .fields + .values() + .flat_map(|field| { + field + .directives + .iter() + .filter(|directive| &directive.name == schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Field { + parent_type: type_ref, + parent_category: object_category, + field_def: field, + }, + }) }) - }); + .collect::>(), + ExtendedType::Interface(iface) => iface + .fields + .values() + .flat_map(|field| { + field + .directives + .iter() + .filter(|directive| &directive.name == schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Field { + parent_type: type_ref, + parent_category: object_category, + field_def: field, + }, + }) + }) + .collect::>(), + ExtendedType::Union(_) => { + // Unions don't have fields, so no field-level directives + Vec::new() + } + _ => Vec::new(), + } + .into_iter(); let (connects, messages): (Vec, Vec>) = directives_on_type .chain(directives_on_fields) @@ -328,16 +364,16 @@ impl<'schema> Connect<'schema> { { // mark the field with a @connect directive as seen seen.push(ResolvedField { - object_name: parent_type.name.clone(), + object_name: parent_type.name().clone(), field_name: field_def.name.clone(), }); // direct recursion isn't allowed, like a connector on User.friends: [User] - if &parent_type.name == field_def.ty.inner_named_type() { + if parent_type.name() == field_def.ty.inner_named_type().as_str() { messages.push(Message { code: Code::CircularReference, message: format!( "Direct circular reference detected in `{}.{}: {}`. For more information, see https://go.apollo.dev/connectors/limitations#circular-references", - parent_type.name, + parent_type.name(), field_def.name, field_def.ty ), diff --git a/apollo-federation/src/connectors/validation/connect/entity.rs b/apollo-federation/src/connectors/validation/connect/entity.rs index 78875ef12b..f0c80a4b40 100644 --- a/apollo-federation/src/connectors/validation/connect/entity.rs +++ b/apollo-federation/src/connectors/validation/connect/entity.rs @@ -9,7 +9,6 @@ use apollo_compiler::ast::FieldDefinition; use apollo_compiler::ast::InputValueDefinition; use apollo_compiler::schema::ExtendedType; use apollo_compiler::schema::InputObjectType; -use apollo_compiler::schema::ObjectType; use super::Code; use super::Message; @@ -17,6 +16,7 @@ use super::ObjectCategory; use crate::connectors::expand::visitors::FieldVisitor; use crate::connectors::expand::visitors::GroupVisitor; use crate::connectors::id::ConnectedElement; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::spec::connect::CONNECT_ENTITY_ARGUMENT_NAME; use crate::connectors::validation::coordinates::ConnectDirectiveCoordinate; use crate::connectors::validation::graphql::SchemaInfo; @@ -84,7 +84,7 @@ pub(super) fn validate_entity_arg( }); } - let Some(object_type) = schema.get_object(field.ty.inner_named_type()) else { + let Some(object_type) = SchemaTypeRef::new(schema, field.ty.inner_named_type()) else { return Err(Message { code: Code::EntityTypeInvalid, message: format!( @@ -97,6 +97,21 @@ pub(super) fn validate_entity_arg( }); }; + // TODO: When abstract types (interfaces/unions) are supported for entity connectors, + // change this check to: !object_type.is_object() && !object_type.is_interface() && !object_type.is_union() + if !object_type.is_object() { + return Err(Message { + code: Code::EntityTypeInvalid, + message: format!( + "{coordinate} is invalid. Entity connectors must return object types.", + ), + locations: entity_arg + .line_column_range(&schema.sources) + .into_iter() + .collect(), + }); + } + if field.ty.is_list() || field.ty.is_non_null() { return Err(Message { code: Code::EntityTypeInvalid, @@ -141,12 +156,12 @@ enum Group<'schema> { /// The entity itself, we're matching argument names & types to these fields Root { field: &'schema Node, - entity_type: &'schema Node, + entity_type: SchemaTypeRef<'schema>, }, /// A child field of the entity we're matching against an input type. Child { input_type: &'schema Node, - entity_type: &'schema ExtendedType, + entity_type: SchemaTypeRef<'schema>, }, } @@ -154,11 +169,11 @@ enum Group<'schema> { struct Field<'schema> { node: &'schema Node, /// The object which has a field that we're comparing against - object_type: &'schema ObjectType, + object_type: SchemaTypeRef<'schema>, /// The field definition of the input that correlates to a field on the entity - input_field: &'schema ExtendedType, + input_field: SchemaTypeRef<'schema>, /// The field of the entity that we're comparing against, part of `object_type` - entity_field: &'schema ExtendedType, + entity_field: SchemaTypeRef<'schema>, } /// Visitor for entity resolver arguments. @@ -179,7 +194,7 @@ impl<'schema> GroupVisitor, Field<'schema>> for ArgumentVisitor<' ) -> Result>, Self::Error> { Ok( // Each input type within an argument to the entity field is another group to visit - if let ExtendedType::InputObject(input_object_type) = field.input_field { + if let ExtendedType::InputObject(input_object_type) = field.input_field.extended() { Some(Group::Child { input_type: input_object_type, entity_type: field.entity_field, @@ -194,12 +209,12 @@ impl<'schema> GroupVisitor, Field<'schema>> for ArgumentVisitor<' match group { Group::Root { field, entity_type, .. - } => self.enter_root_group(field, entity_type), + } => self.enter_root_group(field, *entity_type), Group::Child { input_type, entity_type, .. - } => self.enter_child_group(input_type, entity_type), + } => self.enter_child_group(input_type, *entity_type), } } @@ -212,7 +227,7 @@ impl<'schema> FieldVisitor> for ArgumentVisitor<'schema> { type Error = Message; fn visit(&mut self, field: Field<'schema>) -> Result<(), Self::Error> { - let ok = match field.input_field { + let ok = match field.input_field.extended() { ExtendedType::InputObject(_) => field.entity_field.is_object(), ExtendedType::Scalar(_) | ExtendedType::Enum(_) => { field.input_field == field.entity_field @@ -228,7 +243,7 @@ impl<'schema> FieldVisitor> for ArgumentVisitor<'schema> { "`{coordinate}({field_name}:)` is of type `{input_type}`, but must match `{object}.{field_name}` of type `{entity_type}` because `entity` is `true`.", coordinate = self.coordinate.connect.element, field_name = field.node.name.as_str(), - object = field.object_type.name, + object = field.object_type.name(), input_type = field.input_field.name(), entity_type = field.entity_field.name(), ), @@ -247,82 +262,128 @@ impl<'schema> ArgumentVisitor<'schema> { fn enter_root_group( &mut self, field: &'schema Node, - entity_type: &'schema Node, + entity_type: SchemaTypeRef<'schema>, ) -> Result>, >>::Error> { + let mut fields: Vec> = Vec::new(); + // At the root level, visit each argument to the entity field - field.arguments.iter().filter_map(|arg| { - if let Some(input_type) = self.schema.types.get(arg.ty.inner_named_type()) { - // Check that the argument has a corresponding field on the entity type - if let Some(entity_field) = entity_type.fields.get(&*arg.name) - .and_then(|entity_field| self.schema.types.get(entity_field.ty.inner_named_type())) { - Some(Ok(Field { - node: arg, - input_field: input_type, - entity_field, - object_type: entity_type, - })) - } else { - Some(Err(Message { + for arg in field.arguments.iter() { + // if let Some(input_type) = self.schema.types.get(arg.ty.inner_named_type()) { + if let Some(input_type) = SchemaTypeRef::new(self.schema, arg.ty.inner_named_type()) { + let fields_by_type_name = entity_type.get_fields(arg.name.as_str()); + if fields_by_type_name.is_empty() { + return Err(Message { code: Code::EntityResolverArgumentMismatch, message: format!( "`{coordinate}` has invalid arguments. Argument `{arg_name}` does not have a matching field `{arg_name}` on type `{entity_type}`.", coordinate = self.coordinate.connect.element, arg_name = &*arg.name, - entity_type = entity_type.name, + entity_type = entity_type.name(), ), locations: arg .line_column_range(&self.schema.sources) .into_iter() .chain(self.entity_arg.line_column_range(&self.schema.sources)) .collect(), - })) + }); } - } else { - // The input type is missing - this will be reported elsewhere, so just ignore - None + + fields.extend( + fields_by_type_name + .iter() + .flat_map(|(type_name, entity_field)| { + if let (Some(entity_type), Some(entity_field_type_ref)) = ( + // Look up concrete object type and use it instead + // of original entity_type. + SchemaTypeRef::new(self.schema, type_name.as_str()), + SchemaTypeRef::new(self.schema, entity_field.ty.inner_named_type()), + ) { + Some(Field { + node: arg, + input_field: input_type, + entity_field: entity_field_type_ref, + object_type: entity_type, + }) + } else { + None + } + }), + ); } - }).collect() + } + + Ok(fields) } fn enter_child_group( &mut self, child_input_type: &'schema Node, - entity_type: &'schema ExtendedType, + entity_type: SchemaTypeRef<'schema>, ) -> Result>, >>::Error> { + let mut fields = Vec::new(); + // At the child level, visit each field on the input type - let ExtendedType::Object(entity_object_type) = entity_type else { - // Entity type was not an object type - this will be reported by field visitor - return Ok(Vec::new()); - }; - child_input_type.fields.iter().filter_map(|(name, input_field)| { - if let Some(entity_field) = entity_object_type.fields.get(name) { - let entity_field_type = entity_field.ty.inner_named_type(); - let input_type = self.schema.types.get(input_field.ty.inner_named_type())?; - - self.schema.types.get(entity_field_type).map(|entity_type| Ok(Field { - node: input_field, - object_type: entity_object_type, - input_field: input_type, - entity_field: entity_type, - })) - } else { - // The input type field does not have a corresponding field on the entity type - Some(Err(Message { + for (name, input_field) in child_input_type.fields.iter() { + let field_type_name = input_field.ty.inner_named_type(); + let Some(input_type) = SchemaTypeRef::new(self.schema, field_type_name) else { + // Report an error if the input_field's type is not found in + // self.schema. + return Err(Message { + code: Code::MissingSchemaType, + message: format!( + "Input field `{name}` on `{child_input_type}` has unknown type {field_type_name}", + name = name, + child_input_type = child_input_type.name, + ), + locations: input_field + .line_column_range(&self.schema.sources) + .into_iter() + .collect(), + }); + }; + + let fields_by_type_name = entity_type.get_fields(name.as_str()); + if fields_by_type_name.is_empty() { + return Err(Message { code: Code::EntityResolverArgumentMismatch, message: format!( "`{coordinate}` has invalid arguments. Field `{name}` on `{input_type}` does not have a matching field `{name}` on `{entity_type}`.", coordinate = self.coordinate.connect.element, input_type = child_input_type.name, - entity_type = entity_object_type.name, + entity_type = entity_type.name(), ), locations: input_field .line_column_range(&self.schema.sources) .into_iter() .chain(self.entity_arg.line_column_range(&self.schema.sources)) .collect(), - })) + }); } - }).collect() + + fields.extend( + fields_by_type_name + .iter() + .flat_map(|(type_name, entity_field)| { + if let (Some(entity_type), Some(entity_field_type_ref)) = ( + // Look up concrete object type and use it instead of + // original entity_type. + SchemaTypeRef::new(self.schema, type_name.as_str()), + SchemaTypeRef::new(self.schema, entity_field.ty.inner_named_type()), + ) { + Some(Field { + node: input_field, + object_type: entity_type, + input_field: input_type, + entity_field: entity_field_type_ref, + }) + } else { + None + } + }), + ); + } + + Ok(fields) } } diff --git a/apollo-federation/src/connectors/validation/connect/selection.rs b/apollo-federation/src/connectors/validation/connect/selection.rs index 0da511bd2c..c0acb2e244 100644 --- a/apollo-federation/src/connectors/validation/connect/selection.rs +++ b/apollo-federation/src/connectors/validation/connect/selection.rs @@ -120,14 +120,19 @@ impl<'schema> Selection<'schema> { SelectionValidator::new( schema, - PathPart::Root(parent_type), + PathPart::Root(parent_type.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Parent type is not an object type".to_string(), + locations: vec![], + })?), &self.node, self.coordinate, ) .walk(group) .map(|validator| validator.seen_fields) } - ConnectedElement::Type { type_def } => { + + ConnectedElement::Type { type_ref } => { let Some(sub_selection) = self.parsed.next_subselection() else { // TODO: Validate scalar selections return Ok(Vec::new()); @@ -135,13 +140,21 @@ impl<'schema> Selection<'schema> { let group = Group { selection: sub_selection, - ty: type_def, + ty: type_ref.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Type definition is not an object type".to_string(), + locations: vec![], + })?, definition: None, }; SelectionValidator::new( schema, - PathPart::Root(type_def), + PathPart::Root(type_ref.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Type definition is not an object type".to_string(), + locations: vec![], + })?), &self.node, self.coordinate, ) diff --git a/apollo-federation/src/connectors/validation/connect/selection/variables.rs b/apollo-federation/src/connectors/validation/connect/selection/variables.rs index 2ef0af91fa..a0aff14190 100644 --- a/apollo-federation/src/connectors/validation/connect/selection/variables.rs +++ b/apollo-federation/src/connectors/validation/connect/selection/variables.rs @@ -9,11 +9,11 @@ use apollo_compiler::ast::Type; use apollo_compiler::ast::Value; use apollo_compiler::schema::Component; use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; use itertools::Itertools; use crate::connectors::id::ConnectedElement; use crate::connectors::json_selection::SelectionTrie; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::validation::Code; use crate::connectors::validation::Message; use crate::connectors::validation::graphql::SchemaInfo; @@ -40,7 +40,7 @@ impl<'a> VariableResolver<'a> { } => { resolvers.insert( Namespace::This, - Box::new(ThisResolver::new(parent_type, field_def)), + Box::new(ThisResolver::new(*parent_type, field_def)), ); resolvers.insert(Namespace::Args, Box::new(ArgsResolver::new(field_def))); } @@ -176,13 +176,16 @@ fn resolve_path( /// Resolves variables in the `$this` namespace pub(crate) struct ThisResolver<'a> { - object: &'a ObjectType, + object_type: SchemaTypeRef<'a>, field: &'a Component, } impl<'a> ThisResolver<'a> { - pub(crate) const fn new(object: &'a ObjectType, field: &'a Component) -> Self { - Self { object, field } + pub(crate) const fn new( + object_type: SchemaTypeRef<'a>, + field: &'a Component, + ) -> Self { + Self { object_type, field } } } @@ -194,25 +197,26 @@ impl NamespaceResolver for ThisResolver<'_> { schema: &SchemaInfo, ) -> Result<(), Message> { for (root, sub_trie) in reference.selection.iter() { - let fields = &self.object.fields; + let fields = self.object_type.get_fields(root); - let field_type = fields - .get(root) - .ok_or_else(|| Message { + if fields.is_empty() { + return Err(Message { code: Code::UndefinedField, message: format!( "`{object}` does not have a field named `{root}`", - object = self.object.name, + object = self.object_type.name(), ), locations: reference .selection .key_ranges(root) .flat_map(|range| subslice_location(node, range, schema)) .collect(), - }) - .map(|field| field.ty.clone())?; + }); + } - resolve_path(schema, sub_trie, node, &field_type, self.field)?; + for field_def in fields.values() { + resolve_path(schema, sub_trie, node, &field_def.ty, self.field)?; + } } Ok(()) diff --git a/apollo-federation/src/connectors/validation/expression.rs b/apollo-federation/src/connectors/validation/expression.rs index 21227a0d48..4249ccef7f 100644 --- a/apollo-federation/src/connectors/validation/expression.rs +++ b/apollo-federation/src/connectors/validation/expression.rs @@ -23,6 +23,7 @@ use crate::connectors::Namespace; use crate::connectors::id::ConnectedElement; use crate::connectors::id::ObjectCategory; use crate::connectors::json_selection::VarPaths; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::string_template::Expression; use crate::connectors::validation::Code; use crate::connectors::validation::Message; @@ -95,7 +96,7 @@ impl<'schema> Context<'schema> { .collect(); if matches!(parent_category, ObjectCategory::Other) { - var_lookup.insert(Namespace::This, Shape::from(parent_type)); + var_lookup.insert(Namespace::This, shape_from_schema_type_ref(parent_type)); } Self { @@ -106,10 +107,13 @@ impl<'schema> Context<'schema> { has_response_body: false, } } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { let var_lookup: IndexMap = [ - (Namespace::This, Shape::from(type_def)), - (Namespace::Batch, Shape::list(Shape::from(type_def), [])), + (Namespace::This, shape_from_schema_type_ref(type_ref)), + ( + Namespace::Batch, + Shape::list(shape_from_schema_type_ref(type_ref), []), + ), (Namespace::Config, Shape::unknown([])), (Namespace::Context, Shape::unknown([])), (Namespace::Request, REQUEST_SHAPE.clone()), @@ -156,7 +160,7 @@ impl<'schema> Context<'schema> { .collect(); if matches!(parent_category, ObjectCategory::Other) { - var_lookup.insert(Namespace::This, Shape::from(parent_type)); + var_lookup.insert(Namespace::This, shape_from_schema_type_ref(parent_type)); } Self { @@ -167,10 +171,13 @@ impl<'schema> Context<'schema> { has_response_body: true, } } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { let var_lookup: IndexMap = [ - (Namespace::This, Shape::from(type_def)), - (Namespace::Batch, Shape::list(Shape::from(type_def), [])), + (Namespace::This, shape_from_schema_type_ref(type_ref)), + ( + Namespace::Batch, + Shape::list(shape_from_schema_type_ref(type_ref), []), + ), (Namespace::Config, Shape::unknown([])), (Namespace::Context, Shape::unknown([])), (Namespace::Status, Shape::int([])), @@ -265,6 +272,11 @@ impl<'schema> Context<'schema> { } } +/// Convert a SchemaTypeRef to a Shape by using its ExtendedType +fn shape_from_schema_type_ref(type_ref: SchemaTypeRef<'_>) -> Shape { + Shape::from(type_ref.extended()) +} + pub(crate) fn scalars() -> Shape { Shape::one( vec![ @@ -679,6 +691,7 @@ mod tests { use super::*; use crate::connectors::ConnectSpec; use crate::connectors::JSONSelection; + use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::validation::ConnectLink; fn expression(selection: &str, spec: ConnectSpec) -> Expression { @@ -758,7 +771,7 @@ mod tests { .clone(); let coordinate = ConnectDirectiveCoordinate { element: ConnectedElement::Field { - parent_type: object, + parent_type: SchemaTypeRef::from_node(&schema, object).unwrap(), field_def: field, parent_category: ObjectCategory::Query, }, diff --git a/apollo-federation/src/connectors/validation/mod.rs b/apollo-federation/src/connectors/validation/mod.rs index 181a3e6001..fe9ebcd20b 100644 --- a/apollo-federation/src/connectors/validation/mod.rs +++ b/apollo-federation/src/connectors/validation/mod.rs @@ -282,6 +282,8 @@ pub enum Code { ConnectBatchAndThis, /// Invalid URL property InvalidUrlProperty, + /// Any named type not found in a GraphQL schema where expected + MissingSchemaType, } impl Code { @@ -306,6 +308,7 @@ pub enum Severity { mod test_validate_source { use std::fs::read_to_string; + use insta::assert_debug_snapshot; use insta::assert_snapshot; use insta::glob; use pretty_assertions::assert_str_eq; @@ -320,7 +323,7 @@ mod test_validate_source { let start_time = std::time::Instant::now(); let result = validate(schema.clone(), path.to_str().unwrap()); let end_time = std::time::Instant::now(); - assert_snapshot!(format!("{:#?}", result.errors)); + assert_debug_snapshot!(result.errors); if path.parent().is_some_and(|parent| parent.ends_with("transformed")) { assert_snapshot!(&diff::lines(&schema, &result.transformed).into_iter().filter_map(|res| match res { diff::Result::Left(line) => Some(format!("- {line}")), diff --git a/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap b/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap index 92b863c8a9..7fe78eed4d 100644 --- a/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap +++ b/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap @@ -1,6 +1,7 @@ --- source: apollo-federation/src/connectors/validation/mod.rs -expression: "format!(\"{:#?}\", errors)" +assertion_line: 325 +expression: "format!(\"{:#?}\", result.errors)" input_file: apollo-federation/src/connectors/validation/test_data/keys_and_entities/invalid/entity_true_returning_scalar.graphql --- [