diff --git a/README.md b/README.md
index 3ea29e1..b08ef30 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,10 @@ Note that GitHub requires authentication to consume the feed. See [here](https:/
# Limitations
+`@skip` and `@include` directives are ignored; all selected fields of the selected operation will
+be checked for authentication requirements, including referenced fragments. (Other operations
+in the same document will correctly be skipped.)
+
This authorization framework only supports policy-based authorization. It does not support role-based authorization, or the
`[AllowAnonymous]` attribute/extension, or the `[Authorize]` attribute/extension indicating authorization is required
but without specifying a policy. It also does not integrate with ASP.NET Core's authorization framework.
@@ -84,5 +88,3 @@ public class MutationType
# Known Issues
- It is currently not possible to add a policy to Input objects using Schema first approach.
-
-- :warning: Authorization checks are skipped on fragments that are referenced by other fragments :warning:
diff --git a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs
index c568189..3ddfb5e 100644
--- a/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs
+++ b/src/GraphQL.Authorization.Tests/AuthorizationValidationRuleTests.cs
@@ -148,7 +148,7 @@ public void issue5_with_fragment_should_fail()
});
}
- [Fact(Skip = "This needs to be fixed")]
+ [Fact]
public void nested_fragment_should_fail()
{
Settings.AddPolicy("AdminPolicy", builder => builder.RequireClaim("admin"));
diff --git a/src/GraphQL.Authorization/AuthorizationValidationRule.cs b/src/GraphQL.Authorization/AuthorizationValidationRule.cs
index edc49f6..16aceee 100644
--- a/src/GraphQL.Authorization/AuthorizationValidationRule.cs
+++ b/src/GraphQL.Authorization/AuthorizationValidationRule.cs
@@ -1,8 +1,6 @@
using GraphQL.Types;
using GraphQL.Validation;
-using GraphQLParser;
using GraphQLParser.AST;
-using GraphQLParser.Visitors;
namespace GraphQL.Authorization;
@@ -12,7 +10,7 @@ namespace GraphQL.Authorization;
///
public class AuthorizationValidationRule : IValidationRule
{
- private readonly Visitor _visitor;
+ private readonly IAuthorizationEvaluator _evaluator;
///
/// Creates an instance of with
@@ -20,87 +18,28 @@ public class AuthorizationValidationRule : IValidationRule
///
public AuthorizationValidationRule(IAuthorizationEvaluator evaluator)
{
- _visitor = new(evaluator);
- }
-
- private static async ValueTask ShouldBeSkippedAsync(GraphQLOperationDefinition actualOperation, ValidationContext context)
- {
- if (context.Document.OperationsCount() <= 1)
- {
- return false;
- }
-
- int i = 0;
- while (true)
- {
- var ancestor = context.TypeInfo.GetAncestor(i++);
-
- if (ancestor == actualOperation)
- {
- return false;
- }
-
- if (ancestor == context.Document)
- {
- return true;
- }
-
- if (ancestor is GraphQLFragmentDefinition fragment)
- {
- //TODO: may be rewritten completely later
- var c = new FragmentBelongsToOperationVisitorContext(fragment);
- await _fragmentBelongsToOperationVisitor.VisitAsync(actualOperation, c).ConfigureAwait(false);
- return !c.Found;
- }
- }
- }
-
- private sealed class FragmentBelongsToOperationVisitorContext : IASTVisitorContext
- {
- public FragmentBelongsToOperationVisitorContext(GraphQLFragmentDefinition fragment)
- {
- Fragment = fragment;
- }
-
- public GraphQLFragmentDefinition Fragment { get; }
-
- public bool Found { get; set; }
-
- public CancellationToken CancellationToken => default;
- }
-
- private static readonly FragmentBelongsToOperationVisitor _fragmentBelongsToOperationVisitor = new();
-
- private sealed class FragmentBelongsToOperationVisitor : ASTVisitor
- {
- protected override ValueTask VisitFragmentSpreadAsync(GraphQLFragmentSpread fragmentSpread, FragmentBelongsToOperationVisitorContext context)
- {
- context.Found = context.Fragment.FragmentName.Name == fragmentSpread.FragmentName.Name;
- return default;
- }
-
- public override ValueTask VisitAsync(ASTNode? node, FragmentBelongsToOperationVisitorContext context)
- {
- return context.Found ? default : base.VisitAsync(node, context);
- }
+ _evaluator = evaluator;
}
///
public async ValueTask ValidateAsync(ValidationContext context)
{
- await _visitor.AuthorizeAsync(null, context.Schema, context).ConfigureAwait(false);
+ var visitor = new Visitor(_evaluator);
+
+ await visitor.AuthorizeAsync(null, context.Schema, context).ConfigureAwait(false);
// this could leak info about hidden fields or types in error messages
// it would be better to implement a filter on the Schema so it
// acts as if they just don't exist vs. an auth denied error
// - filtering the Schema is not currently supported
// TODO: apply ISchemaFilter - context.Schema.Filter.AllowXXX
- return _visitor;
+ return visitor;
}
private class Visitor : INodeVisitor
{
private readonly IAuthorizationEvaluator _evaluator;
+ private bool _validate;
public Visitor(IAuthorizationEvaluator evaluator)
{
@@ -109,15 +48,19 @@ public Visitor(IAuthorizationEvaluator evaluator)
public async ValueTask EnterAsync(ASTNode node, ValidationContext context)
{
- if (node is GraphQLOperationDefinition astType && astType == context.Operation)
+ if ((node is GraphQLOperationDefinition astType && astType == context.Operation) ||
+ (node is GraphQLFragmentDefinition fragment && (context.GetRecursivelyReferencedFragments(context.Operation)?.Contains(fragment) ?? false)))
{
var type = context.TypeInfo.GetLastType();
- await AuthorizeAsync(astType, type, context).ConfigureAwait(false);
+ await AuthorizeAsync(node, type, context).ConfigureAwait(false);
+ _validate = true;
}
+ if (!_validate)
+ return;
+
if (node is GraphQLObjectField objectFieldAst &&
- context.TypeInfo.GetArgument()?.ResolvedType?.GetNamedType() is IComplexGraphType argumentType &&
- !await ShouldBeSkippedAsync(context.Operation, context).ConfigureAwait(false))
+ context.TypeInfo.GetArgument()?.ResolvedType?.GetNamedType() is IComplexGraphType argumentType)
{
var fieldType = argumentType.GetField(objectFieldAst.Name);
await AuthorizeAsync(objectFieldAst, fieldType, context).ConfigureAwait(false);
@@ -127,7 +70,7 @@ public async ValueTask EnterAsync(ASTNode node, ValidationContext context)
{
var fieldDef = context.TypeInfo.GetFieldDef();
- if (fieldDef == null || await ShouldBeSkippedAsync(context.Operation, context).ConfigureAwait(false))
+ if (fieldDef == null)
return;
// check target field
@@ -138,8 +81,7 @@ public async ValueTask EnterAsync(ASTNode node, ValidationContext context)
if (node is GraphQLVariable variableRef)
{
- if (context.TypeInfo.GetArgument()?.ResolvedType?.GetNamedType() is not IComplexGraphType variableType ||
- await ShouldBeSkippedAsync(context.Operation, context).ConfigureAwait(false))
+ if (context.TypeInfo.GetArgument()?.ResolvedType?.GetNamedType() is not IComplexGraphType variableType)
return;
await AuthorizeAsync(variableRef, variableType, context).ConfigureAwait(false);
@@ -163,7 +105,13 @@ await ShouldBeSkippedAsync(context.Operation, context).ConfigureAwait(false))
}
}
- public ValueTask LeaveAsync(ASTNode node, ValidationContext context) => default;
+ public ValueTask LeaveAsync(ASTNode node, ValidationContext context)
+ {
+ if (node is GraphQLOperationDefinition || node is GraphQLFragmentDefinition)
+ _validate = false;
+
+ return default;
+ }
public async ValueTask AuthorizeAsync(ASTNode? node, IProvideMetadata? provider, ValidationContext context)
{