Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/MediatR.Contracts/MediatR.Contracts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<Authors>Jimmy Bogard</Authors>
<Description>Contracts package for requests, responses, and notifications</Description>
<Copyright>Copyright Jimmy Bogard</Copyright>
<TargetFrameworks>netstandard2.0;net461;</TargetFrameworks>
<TargetFramework>netstandard2.0</TargetFramework>
<Nullable>enable</Nullable>
<Features>strict</Features>
<PackageTags>mediator;request;response;queries;commands;notifications</PackageTags>
Expand Down
62 changes: 43 additions & 19 deletions src/MediatR/Pipeline/RequestExceptionActionProcessorBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace MediatR.Pipeline;

using MediatR.Internal;
using Internal;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -31,38 +31,62 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
}
catch (Exception exception)
{
for (Type exceptionType = exception.GetType(); exceptionType != typeof(object); exceptionType = exceptionType.BaseType)
{
var actionsForException = GetActionsForException(exceptionType, request, out MethodInfo actionMethod);
var exceptionTypes = GetExceptionTypes(exception.GetType());

var actionsForException = exceptionTypes
.SelectMany(exceptionType => GetActionsForException(exceptionType, request))
.GroupBy(actionForException => actionForException.Action.GetType())
.Select(actionForException => actionForException.First())
.Select(actionForException => (MethodInfo: GetMethodInfoForAction(actionForException.ExceptionType), actionForException.Action))
.ToList();

foreach (var actionForException in actionsForException)
foreach (var actionForException in actionsForException)
{
try
{
await ((Task)(actionForException.MethodInfo.Invoke(actionForException.Action, new object[] { request, exception, cancellationToken })
?? throw new InvalidOperationException($"Could not create task for action method {actionForException.MethodInfo}."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
try
{
await ((Task)(actionMethod.Invoke(actionForException, new object[] { request, exception, cancellationToken })
?? throw new InvalidOperationException($"Could not create task for action method {actionMethod}."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}
}

throw;
}
}

private IList<object> GetActionsForException(Type exceptionType, TRequest request, out MethodInfo actionMethodInfo)
private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
{
while (exceptionType != null && exceptionType != typeof(object))
{
yield return exceptionType;
exceptionType = exceptionType.BaseType;
}
}

private IEnumerable<(Type ExceptionType, object Action)> GetActionsForException(Type exceptionType, TRequest request)
{
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);
var enumerableExceptionActionInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionActionInterfaceType);
actionMethodInfo = exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");

var actionsForException = (IEnumerable<object>)_serviceFactory(enumerableExceptionActionInterfaceType);

return HandlersOrderer.Prioritize(actionsForException.ToList(), request);
return HandlersOrderer.Prioritize(actionsForException.ToList(), request)
.Select(action => (exceptionType, action));
}

private static MethodInfo GetMethodInfoForAction(Type exceptionType)
{
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);

var actionMethodInfo =
exceptionActionInterfaceType.GetMethod(nameof(IRequestExceptionAction<TRequest, Exception>.Execute))
?? throw new InvalidOperationException(
$"Could not find method {nameof(IRequestExceptionAction<TRequest, Exception>.Execute)} on type {exceptionActionInterfaceType}");

return actionMethodInfo;
}
}
70 changes: 44 additions & 26 deletions src/MediatR/Pipeline/RequestExceptionProcessorBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,32 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
catch (Exception exception)
{
var state = new RequestExceptionHandlerState<TResponse>();
Type? exceptionType = null;

while (!state.Handled && exceptionType != typeof(Exception))
var exceptionTypes = GetExceptionTypes(exception.GetType());

var handlersForException = exceptionTypes
.SelectMany(exceptionType => GetHandlersForException(exceptionType, request))
.GroupBy(handlerForException => handlerForException.Handler.GetType())
.Select(handlerForException => handlerForException.First())
.Select(handlerForException => (MethodInfo: GetMethodInfoForHandler(handlerForException.ExceptionType), handlerForException.Handler))
.ToList();

foreach (var handlerForException in handlersForException)
{
exceptionType = exceptionType == null ? exception.GetType() : exceptionType.BaseType
?? throw new InvalidOperationException("Could not determine exception base type.");
var exceptionHandlers = GetExceptionHandlers(request, exceptionType, out MethodInfo handleMethod);
try
{
await ((Task) (handlerForException.MethodInfo.Invoke(handlerForException.Handler, new object[] { request, exception, state, cancellationToken })
?? throw new InvalidOperationException("Did not return a Task from the exception handler."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}

foreach (var exceptionHandler in exceptionHandlers)
if (state.Handled)
{
try
{
await ((Task)(handleMethod.Invoke(exceptionHandler, new object[] { request, exception, state, cancellationToken })
?? throw new InvalidOperationException("Did not return a Task from the exception handler."))).ConfigureAwait(false);
}
catch (TargetInvocationException invocationException) when (invocationException.InnerException != null)
{
// Unwrap invocation exception to throw the actual error
ExceptionDispatchInfo.Capture(invocationException.InnerException).Throw();
}

if (state.Handled)
{
break;
}
break;
}
}

Expand All @@ -73,16 +74,33 @@ public async Task<TResponse> Handle(TRequest request, CancellationToken cancella
return state.Response; //cannot be null if Handled
}
}
private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
{
while (exceptionType != null && exceptionType != typeof(object))
{
yield return exceptionType;
exceptionType = exceptionType.BaseType;
}
}

private IList<object> GetExceptionHandlers(TRequest request, Type exceptionType, out MethodInfo handleMethodInfo)
private IEnumerable<(Type ExceptionType, object Handler)> GetHandlersForException(Type exceptionType, TRequest request)
{
var exceptionHandlerInterfaceType = typeof(IRequestExceptionHandler<,,>).MakeGenericType(typeof(TRequest), typeof(TResponse), exceptionType);
var enumerableExceptionHandlerInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionHandlerInterfaceType);
handleMethodInfo = exceptionHandlerInterfaceType.GetMethod(nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle))
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle)} on type {exceptionHandlerInterfaceType}");

var exceptionHandlers = (IEnumerable<object>)_serviceFactory.Invoke(enumerableExceptionHandlerInterfaceType);
var exceptionHandlers = (IEnumerable<object>) _serviceFactory(enumerableExceptionHandlerInterfaceType);

return HandlersOrderer.Prioritize(exceptionHandlers.ToList(), request)
.Select(handler => (exceptionType, action: handler));
}

private static MethodInfo GetMethodInfoForHandler(Type exceptionType)
{
var exceptionHandlerInterfaceType = typeof(IRequestExceptionHandler<,,>).MakeGenericType(typeof(TRequest), typeof(TResponse), exceptionType);

var handleMethodInfo = exceptionHandlerInterfaceType.GetMethod(nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle))
?? throw new InvalidOperationException($"Could not find method {nameof(IRequestExceptionHandler<TRequest, TResponse, Exception>.Handle)} on type {exceptionHandlerInterfaceType}");

return HandlersOrderer.Prioritize(exceptionHandlers.ToList(), request);
return handleMethodInfo;
}
}
34 changes: 33 additions & 1 deletion test/MediatR.Tests/Pipeline/RequestExceptionActionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ public Task<Pong> Handle(Ping request, CancellationToken cancellationToken)
}
}

public class GenericExceptionAction<TRequest> : IRequestExceptionAction<TRequest>
{
public int ExecutionCount { get; private set; }

public Task Execute(TRequest request, Exception exception, CancellationToken cancellationToken)
{
ExecutionCount++;
return Task.CompletedTask;
}
}

public class PingPongExceptionAction<TRequest> : IRequestExceptionAction<TRequest, PingPongException>
{
public bool Executed { get; private set; }
Expand Down Expand Up @@ -83,7 +94,7 @@ public Task Execute(Ping request, PongException exception, CancellationToken can
}

[Fact]
public async Task Should_run_all_exception_handlers_that_match_base_type()
public async Task Should_run_all_exception_actions_that_match_base_type()
{
var pingExceptionAction = new PingExceptionAction();
var pongExceptionAction = new PongExceptionAction();
Expand All @@ -108,4 +119,25 @@ public async Task Should_run_all_exception_handlers_that_match_base_type()
pingPongExceptionAction.Executed.ShouldBeTrue();
pongExceptionAction.Executed.ShouldBeFalse();
}

[Fact]
public async Task Should_run_matching_exception_actions_only_once()
{
var genericExceptionAction = new GenericExceptionAction<Ping>();
var container = new Container(cfg =>
{
cfg.For<IRequestHandler<Ping, Pong>>().Use<PingHandler>();
cfg.For<IRequestExceptionAction<Ping>>().Use(_ => genericExceptionAction);
cfg.For(typeof(IPipelineBehavior<,>)).Add(typeof(RequestExceptionActionProcessorBehavior<,>));
cfg.For<ServiceFactory>().Use<ServiceFactory>(ctx => t => ctx.GetInstance(t));
cfg.For<IMediator>().Use<Mediator>();
});

var mediator = container.GetInstance<IMediator>();

var request = new Ping { Message = "Ping!" };
await Assert.ThrowsAsync<PingException>(() => mediator.Send(request));

genericExceptionAction.ExecutionCount.ShouldBe(1);
}
}
35 changes: 35 additions & 0 deletions test/MediatR.Tests/Pipeline/RequestExceptionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ public Task<Pong> Handle(Ping request, CancellationToken cancellationToken)
}
}

public class GenericPingExceptionHandler : IRequestExceptionHandler<Ping, Pong>
{
public int ExecutionCount { get; private set; }

public Task Handle(Ping request, Exception exception, RequestExceptionHandlerState<Pong> state, CancellationToken cancellationToken)
{
ExecutionCount++;
return Task.CompletedTask;
}
}

public class PingPongExceptionHandlerForType : IRequestExceptionHandler<Ping, Pong, PingException>
{
public Task Handle(Ping request, PingException exception, RequestExceptionHandlerState<Pong> state, CancellationToken cancellationToken)
Expand Down Expand Up @@ -133,4 +144,28 @@ await Should.ThrowAsync<ApplicationException>(async () =>
});
}

[Fact]
public async Task Should_run_matching_exception_handlers_only_once()
{
var genericPingExceptionHandler = new GenericPingExceptionHandler();
var container = new Container(cfg =>
{
cfg.For<IRequestHandler<Ping, Pong>>().Use<PingHandler>();
cfg.For<IRequestExceptionHandler<Ping, Pong>>().Use(genericPingExceptionHandler);
cfg.For(typeof(IPipelineBehavior<,>)).Add(typeof(RequestExceptionProcessorBehavior<,>));
cfg.For<ServiceFactory>().Use<ServiceFactory>(ctx => t => ctx.GetInstance(t));
cfg.For<IMediator>().Use<Mediator>();
});

var mediator = container.GetInstance<IMediator>();

var request = new Ping { Message = "Ping" };
await Should.ThrowAsync<PingException>(async () =>
{
await mediator.Send(request);
});

genericPingExceptionHandler.ExecutionCount.ShouldBe(1);
}

}