From c1329695b086e9c9cf1401f9dac4c03b68ff0515 Mon Sep 17 00:00:00 2001 From: Ben Newman Date: Tue, 19 Aug 2025 14:43:15 -0400 Subject: [PATCH 1/2] Use `SchemaTypeRef` instead of `Node` in `ConnectedElement`. A SchemaTypeRef is a reference to a &Schema as well as the &Name and &ExtendedType of some named type within the schema, all with the same 'schema lifetime. The SchemaTypeRef::new method returns an Option to guarantee that SchemaTypeRef references are only created for actual schema types. Because SchemaTypeRef has access to the Schema (not just an individual element), it can perform operations that are aware of other types within the same schema, like finding all the implementing/member types of an interface or union. Because SchemaTypeRef can represent any named type in the schema, not just ObjectType, these changes should lay some of the necessary groundwork for working with abstract types, where the type in question might be an InterfaceType or UnionType. After some earlier attempts, I've decided it's important to use a type like SchemaTypeRef that still implements the Copy trait, rather than (say) ExtendedType, so we don't have to update all the data structures that assume they can implement Copy while using ConnectedElement. --- .../src/connectors/expand/mod.rs | 16 +- apollo-federation/src/connectors/id.rs | 40 +-- apollo-federation/src/connectors/mod.rs | 1 + .../src/connectors/schema_type_ref.rs | 280 ++++++++++++++++++ .../src/connectors/validation/connect.rs | 109 ++++--- .../connectors/validation/connect/entity.rs | 169 +++++++---- .../validation/connect/selection.rs | 21 +- .../validation/connect/selection/variables.rs | 30 +- .../src/connectors/validation/expression.rs | 31 +- .../src/connectors/validation/mod.rs | 2 + ..._entity_true_returning_scalar.graphql.snap | 3 +- 11 files changed, 547 insertions(+), 155 deletions(-) create mode 100644 apollo-federation/src/connectors/schema_type_ref.rs diff --git a/apollo-federation/src/connectors/expand/mod.rs b/apollo-federation/src/connectors/expand/mod.rs index b312302b2d..79165d90d1 100644 --- a/apollo-federation/src/connectors/expand/mod.rs +++ b/apollo-federation/src/connectors/expand/mod.rs @@ -409,20 +409,20 @@ mod helpers { // without at least a root-level Query) let parent_pos = ObjectTypeDefinitionPosition { - type_name: parent_type.name.clone(), + type_name: parent_type.name().clone(), }; self.insert_object_and_field(&mut schema, &parent_pos, field_def)?; self.ensure_query_root_type( &mut schema, &query_alias, - Some(&parent_type.name), + Some(parent_type.name()), )?; if let Some(mutation_alias) = mutation_alias { self.ensure_mutation_root_type( &mut schema, &mutation_alias, - &parent_type.name, + parent_type.name(), )?; } @@ -430,11 +430,11 @@ mod helpers { self.process_outputs( &mut schema, connector, - parent_type.name.clone(), + parent_type.name().clone(), field_def.ty.inner_named_type().clone(), )?; } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { SchemaVisitor::new( self.original_schema, &mut schema, @@ -442,7 +442,7 @@ mod helpers { ) .walk(( ObjectTypeDefinitionPosition { - type_name: type_def.name.clone(), + type_name: type_ref.name().clone(), }, connector .selection @@ -460,8 +460,8 @@ mod helpers { self.process_outputs( &mut schema, connector, - type_def.name.clone(), - type_def.name.clone(), + type_ref.name().clone(), + type_ref.name().clone(), )?; } } diff --git a/apollo-federation/src/connectors/id.rs b/apollo-federation/src/connectors/id.rs index 169039618e..daf7e92461 100644 --- a/apollo-federation/src/connectors/id.rs +++ b/apollo-federation/src/connectors/id.rs @@ -4,14 +4,12 @@ use std::fmt::Formatter; use std::hash::Hash; use apollo_compiler::Name; -use apollo_compiler::Node; use apollo_compiler::Schema; use apollo_compiler::ast::FieldDefinition; use apollo_compiler::ast::NamedType; use apollo_compiler::schema::Component; -use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::error::FederationError; use crate::schema::position::ObjectOrInterfaceFieldDirectivePosition; @@ -37,16 +35,7 @@ impl ConnectorPosition { ) -> Result, FederationError> { match self { Self::Field(pos) => Ok(ConnectedElement::Field { - parent_type: schema - .types - .get(pos.field.parent().type_name()) - .and_then(|ty| { - if let ExtendedType::Object(obj) = ty { - Some(obj) - } else { - None - } - }) + parent_type: SchemaTypeRef::new(schema, pos.field.parent().type_name()) .ok_or_else(|| { FederationError::internal("Parent type for connector not found") })?, @@ -62,16 +51,7 @@ impl ConnectorPosition { }, }), Self::Type(pos) => Ok(ConnectedElement::Type { - type_def: schema - .types - .get(&pos.type_name) - .and_then(|ty| { - if let ExtendedType::Object(obj) = ty { - Some(obj) - } else { - None - } - }) + type_ref: SchemaTypeRef::new(schema, &pos.type_name) .ok_or_else(|| FederationError::internal("Type for connector not found"))?, }), } @@ -173,12 +153,12 @@ impl ConnectorPosition { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ConnectedElement<'schema> { Field { - parent_type: &'schema Node, + parent_type: SchemaTypeRef<'schema>, field_def: &'schema Component, parent_category: ObjectCategory, }, Type { - type_def: &'schema Node, + type_ref: SchemaTypeRef<'schema>, }, } @@ -186,7 +166,7 @@ impl ConnectedElement<'_> { pub(super) fn base_type_name(&self) -> NamedType { match self { ConnectedElement::Field { field_def, .. } => field_def.ty.inner_named_type().clone(), - ConnectedElement::Type { type_def } => type_def.name.clone(), + ConnectedElement::Type { type_ref } => type_ref.name().clone(), } } @@ -201,7 +181,7 @@ impl ConnectedElement<'_> { .as_ref() .is_some_and(|query| match self { ConnectedElement::Field { .. } => false, - ConnectedElement::Type { type_def } => type_def.name == query.name, + ConnectedElement::Type { type_ref } => type_ref.name() == query.name.as_str(), }) } @@ -212,7 +192,7 @@ impl ConnectedElement<'_> { .as_ref() .is_some_and(|mutation| match self { ConnectedElement::Field { .. } => false, - ConnectedElement::Type { type_def } => type_def.name == mutation.name, + ConnectedElement::Type { type_ref } => type_ref.name() == mutation.name.as_str(), }) } } @@ -231,8 +211,8 @@ impl Display for ConnectedElement<'_> { parent_type, field_def, .. - } => write!(f, "{}.{}", parent_type.name, field_def.name), - Self::Type { type_def } => write!(f, "{}", type_def.name), + } => write!(f, "{}.{}", parent_type.name(), field_def.name), + Self::Type { type_ref } => write!(f, "{}", type_ref.name()), } } } diff --git a/apollo-federation/src/connectors/mod.rs b/apollo-federation/src/connectors/mod.rs index f23ebeb8fd..527590811b 100644 --- a/apollo-federation/src/connectors/mod.rs +++ b/apollo-federation/src/connectors/mod.rs @@ -30,6 +30,7 @@ pub mod header; mod id; mod json_selection; mod models; +mod schema_type_ref; pub use models::ProblemLocation; pub mod runtime; pub(crate) mod spec; diff --git a/apollo-federation/src/connectors/schema_type_ref.rs b/apollo-federation/src/connectors/schema_type_ref.rs new file mode 100644 index 0000000000..e0f1946d4b --- /dev/null +++ b/apollo-federation/src/connectors/schema_type_ref.rs @@ -0,0 +1,280 @@ +use apollo_compiler::Name; +use apollo_compiler::Node; +use apollo_compiler::Schema; +use apollo_compiler::ast::FieldDefinition; +use apollo_compiler::ast::Type; +use apollo_compiler::collections::IndexMap; +use apollo_compiler::collections::IndexSet; +use apollo_compiler::schema::Component; +use apollo_compiler::schema::ExtendedType; +use apollo_compiler::schema::ObjectType; +use shape::Shape; + +/// A [`SchemaTypeRef`] is a `Copy`able reference to a named type within a +/// [`Schema`]. Because [`SchemaTypeRef`] holds a `&'schema Schema` reference to +/// the schema in question, it can perform operations like finding all the +/// concrete types of an interface or union, which requires full-schema +/// awareness. Other reference-like types, such as [`ExtendedType`], only +/// provide access to a single element, not the rest of the schema. In fact, as +/// you can get an [`&ExtendedType`] by calling [`SchemaTypeRef::extended`], you +/// can pretty much always safely use a [`SchemaTypeRef`] where you would have +/// previously used an [`ExtendedType`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct SchemaTypeRef<'schema>(&'schema Schema, &'schema Name, &'schema ExtendedType); + +impl<'schema> SchemaTypeRef<'schema> { + pub(super) fn new(schema: &'schema Schema, name: &str) -> Option { + schema + .types + .get_full(name) + .map(|(_index, name, extended)| Self(schema, name, extended)) + } + + #[allow(dead_code)] + pub(super) fn from_node( + schema: &'schema Schema, + node: &'schema Node, + ) -> Option { + SchemaTypeRef::new(schema, node.name.as_str()) + } + + pub(super) fn as_object_node(&self) -> Option<&'schema Node> { + if let ExtendedType::Object(obj) = self.2 { + Some(obj) + } else { + None + } + } + + fn shape_with_visited(&self, visited: &mut IndexSet) -> Shape { + let type_name = self.name().to_string(); + if visited.contains(&type_name) { + return Shape::name(&type_name, []); + } + visited.insert(type_name.clone()); + + let result = match self.extended() { + ExtendedType::Object(o) => { + // Check if we're being called from an abstract type (interface or union) + let from_abstract_parent = visited + .iter() + .last() + .and_then(|parent_name| SchemaTypeRef::new(self.0, parent_name)) + .map(|parent_ref| parent_ref.is_abstract()) + .unwrap_or(false); + + // Generate __typename field based on context + let typename_shape = if from_abstract_parent { + // Required typename when accessed via interface/union + Shape::string_value(self.name().as_str(), []) + } else { + // Optional typename when accessed directly + Shape::one( + [Shape::string_value(self.name().as_str(), []), Shape::none()], + [], + ) + }; + + // Build fields map with __typename first + let mut fields = Shape::empty_map(); + fields.insert("__typename".to_string(), typename_shape); + + // Add all the object's declared fields + for (name, field) in &o.fields { + fields.insert( + name.to_string(), + self.shape_from_type_with_visited(&field.ty, visited), + ); + } + + Shape::record(fields, []) + } + ExtendedType::Scalar(_) => Shape::unknown([]), + + ExtendedType::Enum(e) => { + // Enums are unions of their string values + Shape::one( + e.values + .keys() + .map(|value| Shape::string_value(value.as_str(), [])), + [], + ) + } + ExtendedType::Interface(i) => Shape::one( + self.0.types.values().filter_map(|extended_type| { + if let ExtendedType::Object(object_type) = extended_type { + if object_type.implements_interfaces.contains(&i.name) { + SchemaTypeRef::new(self.0, object_type.name.as_str()) + .map(|type_ref| type_ref.shape_with_visited(visited)) + } else { + None + } + } else { + None + } + }), + [], + ), + ExtendedType::Union(u) => Shape::one( + u.members.iter().filter_map(|member_name| { + SchemaTypeRef::new(self.0, member_name.as_str()) + .map(|type_ref| type_ref.shape_with_visited(visited)) + }), + [], + ), + ExtendedType::InputObject(i) => Shape::record( + i.fields + .iter() + .map(|(name, field)| { + ( + name.to_string(), + self.shape_from_type_with_visited(&field.ty, visited), + ) + }) + .collect(), + [], + ), + }; + + visited.swap_remove(&type_name); + result + } + + /// Helper to make a shape nullable (can be null) + fn nullable(&self, shape: Shape) -> Shape { + Shape::one([shape, Shape::null([])], []) + } + + #[allow(dead_code)] + pub(super) fn shape_from_type(&self, ty: &Type) -> Shape { + self.shape_from_type_with_visited(ty, &mut IndexSet::default()) + } + + fn shape_from_type_with_visited(&self, ty: &Type, visited: &mut IndexSet) -> Shape { + let inner_type_name = ty.inner_named_type(); + let base_shape = if visited.contains(inner_type_name.as_str()) { + // Avoid infinite recursion for circular references + Shape::name(inner_type_name.as_str(), []) + } else if let Some(named_type) = SchemaTypeRef::new(self.0, inner_type_name.as_str()) { + named_type.shape_with_visited(visited) + } else { + Shape::name(inner_type_name.as_str(), []) + }; + + match ty { + Type::Named(_) => self.nullable(base_shape), + Type::NonNullNamed(_) => base_shape, + Type::List(inner) => self.nullable(Shape::list( + self.shape_from_type_with_visited(inner, visited), + [], + )), + Type::NonNullList(inner) => { + Shape::list(self.shape_from_type_with_visited(inner, visited), []) + } + } + } + + #[allow(dead_code)] + pub(super) fn schema(&self) -> &'schema Schema { + self.0 + } + + pub(super) fn name(&self) -> &'schema Name { + self.1 + } + + pub(super) fn extended(&self) -> &'schema ExtendedType { + self.2 + } + + #[allow(dead_code)] + pub(super) fn is_object(&self) -> bool { + self.2.is_object() + } + + #[allow(dead_code)] + pub(super) fn is_interface(&self) -> bool { + self.2.is_interface() + } + + #[allow(dead_code)] + pub(super) fn is_union(&self) -> bool { + self.2.is_union() + } + + #[allow(dead_code)] + pub(super) fn is_abstract(&self) -> bool { + self.is_interface() || self.is_union() + } + + #[allow(dead_code)] + pub(super) fn is_input_object(&self) -> bool { + self.2.is_input_object() + } + + #[allow(dead_code)] + pub(super) fn is_enum(&self) -> bool { + self.2.is_enum() + } + + #[allow(dead_code)] + pub(super) fn is_scalar(&self) -> bool { + self.2.is_scalar() + } + + #[allow(dead_code)] + pub(super) fn is_built_in(&self) -> bool { + self.2.is_built_in() + } + + pub(super) fn get_fields( + &self, + field_name: &str, + ) -> IndexMap> { + self.0 + .types + .get(self.1) + .into_iter() + .flat_map(|ty| match ty { + ExtendedType::Object(o) => { + let mut map = IndexMap::default(); + if let Some(field_def) = o.fields.get(field_name) { + map.insert(o.name.to_string(), field_def); + } + map + } + + ExtendedType::Interface(i) => { + let mut map = IndexMap::default(); + if let Some(implementers) = self.0.implementers_map().get(i.name.as_str()) { + for obj_name in &implementers.objects { + if let Some(impl_obj) = SchemaTypeRef::new(self.0, obj_name.as_str()) { + map.extend(impl_obj.get_fields(field_name).into_iter()); + } + } + for iface_name in &implementers.interfaces { + if let Some(impl_iface) = + SchemaTypeRef::new(self.0, iface_name.as_str()) + { + map.extend(impl_iface.get_fields(field_name).into_iter()); + } + } + } + map + } + + ExtendedType::Union(u) => u + .members + .iter() + .flat_map(|m| { + SchemaTypeRef::new(self.0, m.name.as_str()) + .map(|type_ref| type_ref.get_fields(field_name)) + .unwrap_or_default() + }) + .collect(), + + _ => IndexMap::default(), + }) + .collect() + } +} diff --git a/apollo-federation/src/connectors/validation/connect.rs b/apollo-federation/src/connectors/validation/connect.rs index 836105188f..0ecc9ca56a 100644 --- a/apollo-federation/src/connectors/validation/connect.rs +++ b/apollo-federation/src/connectors/validation/connect.rs @@ -9,7 +9,6 @@ use apollo_compiler::Node; use apollo_compiler::ast::Value; use apollo_compiler::parser::LineColumn; use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; use hashbrown::HashSet; use itertools::Itertools; use multi_try::MultiTry; @@ -25,6 +24,7 @@ use crate::connectors::Namespace; use crate::connectors::SourceName; use crate::connectors::id::ConnectedElement; use crate::connectors::id::ObjectCategory; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::spec::connect::CONNECT_ID_ARGUMENT_NAME; use crate::connectors::spec::connect::CONNECT_SOURCE_ARGUMENT_NAME; use crate::connectors::spec::source::SOURCE_NAME_ARGUMENT_NAME; @@ -43,14 +43,18 @@ pub(super) fn fields_seen_by_all_connects( let mut messages = Vec::new(); let mut connects = Vec::new(); - for extended_type in schema.types.values().filter(|ty| !ty.is_built_in()) { - let ExtendedType::Object(node) = extended_type else { - continue; - }; - let (connects_for_type, messages_for_type) = - Connect::find_on_type(node, schema, all_source_names); - connects.extend(connects_for_type); - messages.extend(messages_for_type); + for (type_name, extended_type) in schema.types.iter().filter(|(_, ty)| !ty.is_built_in()) { + // Only check types that can have connectors (objects, interfaces, unions) + if matches!( + extended_type, + ExtendedType::Object(_) | ExtendedType::Interface(_) | ExtendedType::Union(_) + ) && let Some(type_ref) = SchemaTypeRef::new(schema.schema, type_name) + { + let (connects_for_type, messages_for_type) = + Connect::find_on_type(type_ref, schema, all_source_names); + connects.extend(connects_for_type); + messages.extend(messages_for_type); + } } let mut seen_fields = Vec::new(); @@ -136,7 +140,7 @@ struct Connect<'schema> { impl<'schema> Connect<'schema> { /// Find and parse any `@connect` directives on this type or its fields. fn find_on_type( - object: &'schema Node, + type_ref: SchemaTypeRef<'schema>, schema: &'schema SchemaInfo, source_names: &'schema [SourceName], ) -> (Vec, Vec) { @@ -144,43 +148,76 @@ impl<'schema> Connect<'schema> { .schema_definition .query .as_ref() - .is_some_and(|query| query.name == object.name) + .is_some_and(|query| query.name == *type_ref.name()) { ObjectCategory::Query } else if schema .schema_definition .mutation .as_ref() - .is_some_and(|mutation| mutation.name == object.name) + .is_some_and(|mutation| mutation.name == *type_ref.name()) { ObjectCategory::Mutation } else { ObjectCategory::Other }; - let directives_on_type = object - .directives - .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) - .map(|directive| ConnectDirectiveCoordinate { - directive, - element: ConnectedElement::Type { type_def: object }, - }); + let directives_on_type = match type_ref.extended() { + ExtendedType::Object(obj) => obj.directives.iter(), + ExtendedType::Interface(iface) => iface.directives.iter(), + ExtendedType::Union(union) => union.directives.iter(), + _ => [].iter(), // Other types (scalars, enums, etc.) don't support connectors + } + .filter(|directive| directive.name == *schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Type { type_ref }, + }); - let directives_on_fields = object.fields.values().flat_map(|field| { - field - .directives - .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) - .map(|directive| ConnectDirectiveCoordinate { - directive, - element: ConnectedElement::Field { - parent_type: object, - parent_category: object_category, - field_def: field, - }, + let directives_on_fields = match type_ref.extended() { + ExtendedType::Object(obj) => obj + .fields + .values() + .flat_map(|field| { + field + .directives + .iter() + .filter(|directive| directive.name == *schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Field { + parent_type: type_ref, + parent_category: object_category, + field_def: field, + }, + }) }) - }); + .collect::>(), + ExtendedType::Interface(iface) => iface + .fields + .values() + .flat_map(|field| { + field + .directives + .iter() + .filter(|directive| directive.name == *schema.connect_directive_name()) + .map(|directive| ConnectDirectiveCoordinate { + directive, + element: ConnectedElement::Field { + parent_type: type_ref, + parent_category: object_category, + field_def: field, + }, + }) + }) + .collect::>(), + ExtendedType::Union(_) => { + // Unions don't have fields, so no field-level directives + Vec::new() + } + _ => Vec::new(), + } + .into_iter(); let (connects, messages): (Vec, Vec>) = directives_on_type .chain(directives_on_fields) @@ -328,16 +365,16 @@ impl<'schema> Connect<'schema> { { // mark the field with a @connect directive as seen seen.push(ResolvedField { - object_name: parent_type.name.clone(), + object_name: parent_type.name().clone(), field_name: field_def.name.clone(), }); // direct recursion isn't allowed, like a connector on User.friends: [User] - if &parent_type.name == field_def.ty.inner_named_type() { + if parent_type.name() == field_def.ty.inner_named_type().as_str() { messages.push(Message { code: Code::CircularReference, message: format!( "Direct circular reference detected in `{}.{}: {}`. For more information, see https://go.apollo.dev/connectors/limitations#circular-references", - parent_type.name, + parent_type.name().clone(), field_def.name, field_def.ty ), diff --git a/apollo-federation/src/connectors/validation/connect/entity.rs b/apollo-federation/src/connectors/validation/connect/entity.rs index 78875ef12b..f0c80a4b40 100644 --- a/apollo-federation/src/connectors/validation/connect/entity.rs +++ b/apollo-federation/src/connectors/validation/connect/entity.rs @@ -9,7 +9,6 @@ use apollo_compiler::ast::FieldDefinition; use apollo_compiler::ast::InputValueDefinition; use apollo_compiler::schema::ExtendedType; use apollo_compiler::schema::InputObjectType; -use apollo_compiler::schema::ObjectType; use super::Code; use super::Message; @@ -17,6 +16,7 @@ use super::ObjectCategory; use crate::connectors::expand::visitors::FieldVisitor; use crate::connectors::expand::visitors::GroupVisitor; use crate::connectors::id::ConnectedElement; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::spec::connect::CONNECT_ENTITY_ARGUMENT_NAME; use crate::connectors::validation::coordinates::ConnectDirectiveCoordinate; use crate::connectors::validation::graphql::SchemaInfo; @@ -84,7 +84,7 @@ pub(super) fn validate_entity_arg( }); } - let Some(object_type) = schema.get_object(field.ty.inner_named_type()) else { + let Some(object_type) = SchemaTypeRef::new(schema, field.ty.inner_named_type()) else { return Err(Message { code: Code::EntityTypeInvalid, message: format!( @@ -97,6 +97,21 @@ pub(super) fn validate_entity_arg( }); }; + // TODO: When abstract types (interfaces/unions) are supported for entity connectors, + // change this check to: !object_type.is_object() && !object_type.is_interface() && !object_type.is_union() + if !object_type.is_object() { + return Err(Message { + code: Code::EntityTypeInvalid, + message: format!( + "{coordinate} is invalid. Entity connectors must return object types.", + ), + locations: entity_arg + .line_column_range(&schema.sources) + .into_iter() + .collect(), + }); + } + if field.ty.is_list() || field.ty.is_non_null() { return Err(Message { code: Code::EntityTypeInvalid, @@ -141,12 +156,12 @@ enum Group<'schema> { /// The entity itself, we're matching argument names & types to these fields Root { field: &'schema Node, - entity_type: &'schema Node, + entity_type: SchemaTypeRef<'schema>, }, /// A child field of the entity we're matching against an input type. Child { input_type: &'schema Node, - entity_type: &'schema ExtendedType, + entity_type: SchemaTypeRef<'schema>, }, } @@ -154,11 +169,11 @@ enum Group<'schema> { struct Field<'schema> { node: &'schema Node, /// The object which has a field that we're comparing against - object_type: &'schema ObjectType, + object_type: SchemaTypeRef<'schema>, /// The field definition of the input that correlates to a field on the entity - input_field: &'schema ExtendedType, + input_field: SchemaTypeRef<'schema>, /// The field of the entity that we're comparing against, part of `object_type` - entity_field: &'schema ExtendedType, + entity_field: SchemaTypeRef<'schema>, } /// Visitor for entity resolver arguments. @@ -179,7 +194,7 @@ impl<'schema> GroupVisitor, Field<'schema>> for ArgumentVisitor<' ) -> Result>, Self::Error> { Ok( // Each input type within an argument to the entity field is another group to visit - if let ExtendedType::InputObject(input_object_type) = field.input_field { + if let ExtendedType::InputObject(input_object_type) = field.input_field.extended() { Some(Group::Child { input_type: input_object_type, entity_type: field.entity_field, @@ -194,12 +209,12 @@ impl<'schema> GroupVisitor, Field<'schema>> for ArgumentVisitor<' match group { Group::Root { field, entity_type, .. - } => self.enter_root_group(field, entity_type), + } => self.enter_root_group(field, *entity_type), Group::Child { input_type, entity_type, .. - } => self.enter_child_group(input_type, entity_type), + } => self.enter_child_group(input_type, *entity_type), } } @@ -212,7 +227,7 @@ impl<'schema> FieldVisitor> for ArgumentVisitor<'schema> { type Error = Message; fn visit(&mut self, field: Field<'schema>) -> Result<(), Self::Error> { - let ok = match field.input_field { + let ok = match field.input_field.extended() { ExtendedType::InputObject(_) => field.entity_field.is_object(), ExtendedType::Scalar(_) | ExtendedType::Enum(_) => { field.input_field == field.entity_field @@ -228,7 +243,7 @@ impl<'schema> FieldVisitor> for ArgumentVisitor<'schema> { "`{coordinate}({field_name}:)` is of type `{input_type}`, but must match `{object}.{field_name}` of type `{entity_type}` because `entity` is `true`.", coordinate = self.coordinate.connect.element, field_name = field.node.name.as_str(), - object = field.object_type.name, + object = field.object_type.name(), input_type = field.input_field.name(), entity_type = field.entity_field.name(), ), @@ -247,82 +262,128 @@ impl<'schema> ArgumentVisitor<'schema> { fn enter_root_group( &mut self, field: &'schema Node, - entity_type: &'schema Node, + entity_type: SchemaTypeRef<'schema>, ) -> Result>, >>::Error> { + let mut fields: Vec> = Vec::new(); + // At the root level, visit each argument to the entity field - field.arguments.iter().filter_map(|arg| { - if let Some(input_type) = self.schema.types.get(arg.ty.inner_named_type()) { - // Check that the argument has a corresponding field on the entity type - if let Some(entity_field) = entity_type.fields.get(&*arg.name) - .and_then(|entity_field| self.schema.types.get(entity_field.ty.inner_named_type())) { - Some(Ok(Field { - node: arg, - input_field: input_type, - entity_field, - object_type: entity_type, - })) - } else { - Some(Err(Message { + for arg in field.arguments.iter() { + // if let Some(input_type) = self.schema.types.get(arg.ty.inner_named_type()) { + if let Some(input_type) = SchemaTypeRef::new(self.schema, arg.ty.inner_named_type()) { + let fields_by_type_name = entity_type.get_fields(arg.name.as_str()); + if fields_by_type_name.is_empty() { + return Err(Message { code: Code::EntityResolverArgumentMismatch, message: format!( "`{coordinate}` has invalid arguments. Argument `{arg_name}` does not have a matching field `{arg_name}` on type `{entity_type}`.", coordinate = self.coordinate.connect.element, arg_name = &*arg.name, - entity_type = entity_type.name, + entity_type = entity_type.name(), ), locations: arg .line_column_range(&self.schema.sources) .into_iter() .chain(self.entity_arg.line_column_range(&self.schema.sources)) .collect(), - })) + }); } - } else { - // The input type is missing - this will be reported elsewhere, so just ignore - None + + fields.extend( + fields_by_type_name + .iter() + .flat_map(|(type_name, entity_field)| { + if let (Some(entity_type), Some(entity_field_type_ref)) = ( + // Look up concrete object type and use it instead + // of original entity_type. + SchemaTypeRef::new(self.schema, type_name.as_str()), + SchemaTypeRef::new(self.schema, entity_field.ty.inner_named_type()), + ) { + Some(Field { + node: arg, + input_field: input_type, + entity_field: entity_field_type_ref, + object_type: entity_type, + }) + } else { + None + } + }), + ); } - }).collect() + } + + Ok(fields) } fn enter_child_group( &mut self, child_input_type: &'schema Node, - entity_type: &'schema ExtendedType, + entity_type: SchemaTypeRef<'schema>, ) -> Result>, >>::Error> { + let mut fields = Vec::new(); + // At the child level, visit each field on the input type - let ExtendedType::Object(entity_object_type) = entity_type else { - // Entity type was not an object type - this will be reported by field visitor - return Ok(Vec::new()); - }; - child_input_type.fields.iter().filter_map(|(name, input_field)| { - if let Some(entity_field) = entity_object_type.fields.get(name) { - let entity_field_type = entity_field.ty.inner_named_type(); - let input_type = self.schema.types.get(input_field.ty.inner_named_type())?; - - self.schema.types.get(entity_field_type).map(|entity_type| Ok(Field { - node: input_field, - object_type: entity_object_type, - input_field: input_type, - entity_field: entity_type, - })) - } else { - // The input type field does not have a corresponding field on the entity type - Some(Err(Message { + for (name, input_field) in child_input_type.fields.iter() { + let field_type_name = input_field.ty.inner_named_type(); + let Some(input_type) = SchemaTypeRef::new(self.schema, field_type_name) else { + // Report an error if the input_field's type is not found in + // self.schema. + return Err(Message { + code: Code::MissingSchemaType, + message: format!( + "Input field `{name}` on `{child_input_type}` has unknown type {field_type_name}", + name = name, + child_input_type = child_input_type.name, + ), + locations: input_field + .line_column_range(&self.schema.sources) + .into_iter() + .collect(), + }); + }; + + let fields_by_type_name = entity_type.get_fields(name.as_str()); + if fields_by_type_name.is_empty() { + return Err(Message { code: Code::EntityResolverArgumentMismatch, message: format!( "`{coordinate}` has invalid arguments. Field `{name}` on `{input_type}` does not have a matching field `{name}` on `{entity_type}`.", coordinate = self.coordinate.connect.element, input_type = child_input_type.name, - entity_type = entity_object_type.name, + entity_type = entity_type.name(), ), locations: input_field .line_column_range(&self.schema.sources) .into_iter() .chain(self.entity_arg.line_column_range(&self.schema.sources)) .collect(), - })) + }); } - }).collect() + + fields.extend( + fields_by_type_name + .iter() + .flat_map(|(type_name, entity_field)| { + if let (Some(entity_type), Some(entity_field_type_ref)) = ( + // Look up concrete object type and use it instead of + // original entity_type. + SchemaTypeRef::new(self.schema, type_name.as_str()), + SchemaTypeRef::new(self.schema, entity_field.ty.inner_named_type()), + ) { + Some(Field { + node: input_field, + object_type: entity_type, + input_field: input_type, + entity_field: entity_field_type_ref, + }) + } else { + None + } + }), + ); + } + + Ok(fields) } } diff --git a/apollo-federation/src/connectors/validation/connect/selection.rs b/apollo-federation/src/connectors/validation/connect/selection.rs index 0da511bd2c..c0acb2e244 100644 --- a/apollo-federation/src/connectors/validation/connect/selection.rs +++ b/apollo-federation/src/connectors/validation/connect/selection.rs @@ -120,14 +120,19 @@ impl<'schema> Selection<'schema> { SelectionValidator::new( schema, - PathPart::Root(parent_type), + PathPart::Root(parent_type.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Parent type is not an object type".to_string(), + locations: vec![], + })?), &self.node, self.coordinate, ) .walk(group) .map(|validator| validator.seen_fields) } - ConnectedElement::Type { type_def } => { + + ConnectedElement::Type { type_ref } => { let Some(sub_selection) = self.parsed.next_subselection() else { // TODO: Validate scalar selections return Ok(Vec::new()); @@ -135,13 +140,21 @@ impl<'schema> Selection<'schema> { let group = Group { selection: sub_selection, - ty: type_def, + ty: type_ref.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Type definition is not an object type".to_string(), + locations: vec![], + })?, definition: None, }; SelectionValidator::new( schema, - PathPart::Root(type_def), + PathPart::Root(type_ref.as_object_node().ok_or_else(|| Message { + code: Code::GraphQLError, + message: "Type definition is not an object type".to_string(), + locations: vec![], + })?), &self.node, self.coordinate, ) diff --git a/apollo-federation/src/connectors/validation/connect/selection/variables.rs b/apollo-federation/src/connectors/validation/connect/selection/variables.rs index 2ef0af91fa..a0aff14190 100644 --- a/apollo-federation/src/connectors/validation/connect/selection/variables.rs +++ b/apollo-federation/src/connectors/validation/connect/selection/variables.rs @@ -9,11 +9,11 @@ use apollo_compiler::ast::Type; use apollo_compiler::ast::Value; use apollo_compiler::schema::Component; use apollo_compiler::schema::ExtendedType; -use apollo_compiler::schema::ObjectType; use itertools::Itertools; use crate::connectors::id::ConnectedElement; use crate::connectors::json_selection::SelectionTrie; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::validation::Code; use crate::connectors::validation::Message; use crate::connectors::validation::graphql::SchemaInfo; @@ -40,7 +40,7 @@ impl<'a> VariableResolver<'a> { } => { resolvers.insert( Namespace::This, - Box::new(ThisResolver::new(parent_type, field_def)), + Box::new(ThisResolver::new(*parent_type, field_def)), ); resolvers.insert(Namespace::Args, Box::new(ArgsResolver::new(field_def))); } @@ -176,13 +176,16 @@ fn resolve_path( /// Resolves variables in the `$this` namespace pub(crate) struct ThisResolver<'a> { - object: &'a ObjectType, + object_type: SchemaTypeRef<'a>, field: &'a Component, } impl<'a> ThisResolver<'a> { - pub(crate) const fn new(object: &'a ObjectType, field: &'a Component) -> Self { - Self { object, field } + pub(crate) const fn new( + object_type: SchemaTypeRef<'a>, + field: &'a Component, + ) -> Self { + Self { object_type, field } } } @@ -194,25 +197,26 @@ impl NamespaceResolver for ThisResolver<'_> { schema: &SchemaInfo, ) -> Result<(), Message> { for (root, sub_trie) in reference.selection.iter() { - let fields = &self.object.fields; + let fields = self.object_type.get_fields(root); - let field_type = fields - .get(root) - .ok_or_else(|| Message { + if fields.is_empty() { + return Err(Message { code: Code::UndefinedField, message: format!( "`{object}` does not have a field named `{root}`", - object = self.object.name, + object = self.object_type.name(), ), locations: reference .selection .key_ranges(root) .flat_map(|range| subslice_location(node, range, schema)) .collect(), - }) - .map(|field| field.ty.clone())?; + }); + } - resolve_path(schema, sub_trie, node, &field_type, self.field)?; + for field_def in fields.values() { + resolve_path(schema, sub_trie, node, &field_def.ty, self.field)?; + } } Ok(()) diff --git a/apollo-federation/src/connectors/validation/expression.rs b/apollo-federation/src/connectors/validation/expression.rs index 21227a0d48..4249ccef7f 100644 --- a/apollo-federation/src/connectors/validation/expression.rs +++ b/apollo-federation/src/connectors/validation/expression.rs @@ -23,6 +23,7 @@ use crate::connectors::Namespace; use crate::connectors::id::ConnectedElement; use crate::connectors::id::ObjectCategory; use crate::connectors::json_selection::VarPaths; +use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::string_template::Expression; use crate::connectors::validation::Code; use crate::connectors::validation::Message; @@ -95,7 +96,7 @@ impl<'schema> Context<'schema> { .collect(); if matches!(parent_category, ObjectCategory::Other) { - var_lookup.insert(Namespace::This, Shape::from(parent_type)); + var_lookup.insert(Namespace::This, shape_from_schema_type_ref(parent_type)); } Self { @@ -106,10 +107,13 @@ impl<'schema> Context<'schema> { has_response_body: false, } } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { let var_lookup: IndexMap = [ - (Namespace::This, Shape::from(type_def)), - (Namespace::Batch, Shape::list(Shape::from(type_def), [])), + (Namespace::This, shape_from_schema_type_ref(type_ref)), + ( + Namespace::Batch, + Shape::list(shape_from_schema_type_ref(type_ref), []), + ), (Namespace::Config, Shape::unknown([])), (Namespace::Context, Shape::unknown([])), (Namespace::Request, REQUEST_SHAPE.clone()), @@ -156,7 +160,7 @@ impl<'schema> Context<'schema> { .collect(); if matches!(parent_category, ObjectCategory::Other) { - var_lookup.insert(Namespace::This, Shape::from(parent_type)); + var_lookup.insert(Namespace::This, shape_from_schema_type_ref(parent_type)); } Self { @@ -167,10 +171,13 @@ impl<'schema> Context<'schema> { has_response_body: true, } } - ConnectedElement::Type { type_def } => { + ConnectedElement::Type { type_ref } => { let var_lookup: IndexMap = [ - (Namespace::This, Shape::from(type_def)), - (Namespace::Batch, Shape::list(Shape::from(type_def), [])), + (Namespace::This, shape_from_schema_type_ref(type_ref)), + ( + Namespace::Batch, + Shape::list(shape_from_schema_type_ref(type_ref), []), + ), (Namespace::Config, Shape::unknown([])), (Namespace::Context, Shape::unknown([])), (Namespace::Status, Shape::int([])), @@ -265,6 +272,11 @@ impl<'schema> Context<'schema> { } } +/// Convert a SchemaTypeRef to a Shape by using its ExtendedType +fn shape_from_schema_type_ref(type_ref: SchemaTypeRef<'_>) -> Shape { + Shape::from(type_ref.extended()) +} + pub(crate) fn scalars() -> Shape { Shape::one( vec![ @@ -679,6 +691,7 @@ mod tests { use super::*; use crate::connectors::ConnectSpec; use crate::connectors::JSONSelection; + use crate::connectors::schema_type_ref::SchemaTypeRef; use crate::connectors::validation::ConnectLink; fn expression(selection: &str, spec: ConnectSpec) -> Expression { @@ -758,7 +771,7 @@ mod tests { .clone(); let coordinate = ConnectDirectiveCoordinate { element: ConnectedElement::Field { - parent_type: object, + parent_type: SchemaTypeRef::from_node(&schema, object).unwrap(), field_def: field, parent_category: ObjectCategory::Query, }, diff --git a/apollo-federation/src/connectors/validation/mod.rs b/apollo-federation/src/connectors/validation/mod.rs index 181a3e6001..1a91ca8a85 100644 --- a/apollo-federation/src/connectors/validation/mod.rs +++ b/apollo-federation/src/connectors/validation/mod.rs @@ -282,6 +282,8 @@ pub enum Code { ConnectBatchAndThis, /// Invalid URL property InvalidUrlProperty, + /// Any named type not found in a GraphQL schema where expected + MissingSchemaType, } impl Code { diff --git a/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap b/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap index 92b863c8a9..7fe78eed4d 100644 --- a/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap +++ b/apollo-federation/src/connectors/validation/snapshots/validation_tests@keys_and_entities__invalid__entity_true_returning_scalar.graphql.snap @@ -1,6 +1,7 @@ --- source: apollo-federation/src/connectors/validation/mod.rs -expression: "format!(\"{:#?}\", errors)" +assertion_line: 325 +expression: "format!(\"{:#?}\", result.errors)" input_file: apollo-federation/src/connectors/validation/test_data/keys_and_entities/invalid/entity_true_returning_scalar.graphql --- [ From 56ab2fc41b04bbda190b3da6b755fb7d6bb5c2d5 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 2 Dec 2025 12:37:20 -0500 Subject: [PATCH 2/2] Address review feedback for `SchemaTypeRef` changes (#8676) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../src/connectors/schema_type_ref.rs | 62 ++++++++----------- .../src/connectors/validation/connect.rs | 39 ++++++------ .../src/connectors/validation/mod.rs | 3 +- 3 files changed, 48 insertions(+), 56 deletions(-) diff --git a/apollo-federation/src/connectors/schema_type_ref.rs b/apollo-federation/src/connectors/schema_type_ref.rs index e0f1946d4b..21a4acaee3 100644 --- a/apollo-federation/src/connectors/schema_type_ref.rs +++ b/apollo-federation/src/connectors/schema_type_ref.rs @@ -145,7 +145,7 @@ impl<'schema> SchemaTypeRef<'schema> { Shape::one([shape, Shape::null([])], []) } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn shape_from_type(&self, ty: &Type) -> Shape { self.shape_from_type_with_visited(ty, &mut IndexSet::default()) } @@ -174,7 +174,7 @@ impl<'schema> SchemaTypeRef<'schema> { } } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn schema(&self) -> &'schema Schema { self.0 } @@ -187,42 +187,38 @@ impl<'schema> SchemaTypeRef<'schema> { self.2 } - #[allow(dead_code)] pub(super) fn is_object(&self) -> bool { self.2.is_object() } - #[allow(dead_code)] pub(super) fn is_interface(&self) -> bool { self.2.is_interface() } - #[allow(dead_code)] pub(super) fn is_union(&self) -> bool { self.2.is_union() } - #[allow(dead_code)] pub(super) fn is_abstract(&self) -> bool { self.is_interface() || self.is_union() } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn is_input_object(&self) -> bool { self.2.is_input_object() } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn is_enum(&self) -> bool { self.2.is_enum() } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn is_scalar(&self) -> bool { self.2.is_scalar() } - #[allow(dead_code)] + #[expect(dead_code)] pub(super) fn is_built_in(&self) -> bool { self.2.is_built_in() } @@ -236,32 +232,28 @@ impl<'schema> SchemaTypeRef<'schema> { .get(self.1) .into_iter() .flat_map(|ty| match ty { - ExtendedType::Object(o) => { - let mut map = IndexMap::default(); - if let Some(field_def) = o.fields.get(field_name) { - map.insert(o.name.to_string(), field_def); - } - map - } + ExtendedType::Object(o) => o + .fields + .get(field_name) + .map(|field_def| { + std::iter::once((o.name.to_string(), field_def)).collect::>() + }) + .unwrap_or_default(), - ExtendedType::Interface(i) => { - let mut map = IndexMap::default(); - if let Some(implementers) = self.0.implementers_map().get(i.name.as_str()) { - for obj_name in &implementers.objects { - if let Some(impl_obj) = SchemaTypeRef::new(self.0, obj_name.as_str()) { - map.extend(impl_obj.get_fields(field_name).into_iter()); - } - } - for iface_name in &implementers.interfaces { - if let Some(impl_iface) = - SchemaTypeRef::new(self.0, iface_name.as_str()) - { - map.extend(impl_iface.get_fields(field_name).into_iter()); - } - } - } - map - } + ExtendedType::Interface(i) => self + .0 + .implementers_map() + .get(i.name.as_str()) + .map(|implementers| { + implementers + .objects + .iter() + .chain(&implementers.interfaces) + .filter_map(|name| SchemaTypeRef::new(self.0, name.as_str())) + .flat_map(|type_ref| type_ref.get_fields(field_name)) + .collect::>() + }) + .unwrap_or_default(), ExtendedType::Union(u) => u .members diff --git a/apollo-federation/src/connectors/validation/connect.rs b/apollo-federation/src/connectors/validation/connect.rs index 0ecc9ca56a..f8dee99a33 100644 --- a/apollo-federation/src/connectors/validation/connect.rs +++ b/apollo-federation/src/connectors/validation/connect.rs @@ -40,22 +40,21 @@ pub(super) fn fields_seen_by_all_connects( schema: &SchemaInfo, all_source_names: &[SourceName], ) -> Result, Vec> { - let mut messages = Vec::new(); - let mut connects = Vec::new(); - - for (type_name, extended_type) in schema.types.iter().filter(|(_, ty)| !ty.is_built_in()) { - // Only check types that can have connectors (objects, interfaces, unions) - if matches!( - extended_type, - ExtendedType::Object(_) | ExtendedType::Interface(_) | ExtendedType::Union(_) - ) && let Some(type_ref) = SchemaTypeRef::new(schema.schema, type_name) - { - let (connects_for_type, messages_for_type) = - Connect::find_on_type(type_ref, schema, all_source_names); - connects.extend(connects_for_type); - messages.extend(messages_for_type); - } - } + let (connects, messages): (Vec<_>, Vec<_>) = schema + .types + .iter() + .filter(|(_, ty)| !ty.is_built_in()) + .filter(|(_, ty)| { + matches!( + ty, + ExtendedType::Object(_) | ExtendedType::Interface(_) | ExtendedType::Union(_) + ) + }) + .filter_map(|(type_name, _)| SchemaTypeRef::new(schema.schema, type_name)) + .map(|type_ref| Connect::find_on_type(type_ref, schema, all_source_names)) + .unzip(); + let connects: Vec<_> = connects.into_iter().flatten().collect(); + let mut messages: Vec<_> = messages.into_iter().flatten().collect(); let mut seen_fields = Vec::new(); let mut valid_id_names: HashMap<_, Vec<_>> = HashMap::new(); @@ -168,7 +167,7 @@ impl<'schema> Connect<'schema> { ExtendedType::Union(union) => union.directives.iter(), _ => [].iter(), // Other types (scalars, enums, etc.) don't support connectors } - .filter(|directive| directive.name == *schema.connect_directive_name()) + .filter(|directive| &directive.name == schema.connect_directive_name()) .map(|directive| ConnectDirectiveCoordinate { directive, element: ConnectedElement::Type { type_ref }, @@ -182,7 +181,7 @@ impl<'schema> Connect<'schema> { field .directives .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) + .filter(|directive| &directive.name == schema.connect_directive_name()) .map(|directive| ConnectDirectiveCoordinate { directive, element: ConnectedElement::Field { @@ -200,7 +199,7 @@ impl<'schema> Connect<'schema> { field .directives .iter() - .filter(|directive| directive.name == *schema.connect_directive_name()) + .filter(|directive| &directive.name == schema.connect_directive_name()) .map(|directive| ConnectDirectiveCoordinate { directive, element: ConnectedElement::Field { @@ -374,7 +373,7 @@ impl<'schema> Connect<'schema> { code: Code::CircularReference, message: format!( "Direct circular reference detected in `{}.{}: {}`. For more information, see https://go.apollo.dev/connectors/limitations#circular-references", - parent_type.name().clone(), + parent_type.name(), field_def.name, field_def.ty ), diff --git a/apollo-federation/src/connectors/validation/mod.rs b/apollo-federation/src/connectors/validation/mod.rs index 1a91ca8a85..fe9ebcd20b 100644 --- a/apollo-federation/src/connectors/validation/mod.rs +++ b/apollo-federation/src/connectors/validation/mod.rs @@ -308,6 +308,7 @@ pub enum Severity { mod test_validate_source { use std::fs::read_to_string; + use insta::assert_debug_snapshot; use insta::assert_snapshot; use insta::glob; use pretty_assertions::assert_str_eq; @@ -322,7 +323,7 @@ mod test_validate_source { let start_time = std::time::Instant::now(); let result = validate(schema.clone(), path.to_str().unwrap()); let end_time = std::time::Instant::now(); - assert_snapshot!(format!("{:#?}", result.errors)); + assert_debug_snapshot!(result.errors); if path.parent().is_some_and(|parent| parent.ends_with("transformed")) { assert_snapshot!(&diff::lines(&schema, &result.transformed).into_iter().filter_map(|res| match res { diff::Result::Left(line) => Some(format!("- {line}")),