Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/MediatR/MediatR.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<Authors>Jimmy Bogard</Authors>
<Description>Simple, unambitious mediator implementation in .NET</Description>
<Copyright>Copyright Jimmy Bogard</Copyright>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<Nullable>enable</Nullable>
<Features>strict</Features>
<PackageTags>mediator;request;response;queries;commands;notifications</PackageTags>
Expand Down Expand Up @@ -32,7 +32,7 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="MediatR.Contracts" Version="[2.0.1, 3.0.0)" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="6.0.0" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="6.0.0" Condition="'$(TargetFramework)' == 'netstandard2.0'" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="6.0.0" />
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.1.1" PrivateAssets="All" />
<PackageReference Include="MinVer" Version="4.3.0" PrivateAssets="All" />
Expand Down
118 changes: 54 additions & 64 deletions src/MediatR/Mediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ public Task<TResponse> Send<TResponse>(IRequest<TResponse> request, Cancellation
throw new ArgumentNullException(nameof(request));
}

var requestType = request.GetType();

var handler = (RequestHandlerWrapper<TResponse>)_requestHandlers.GetOrAdd(requestType,
static t => (RequestHandlerBase)(Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(t, typeof(TResponse)))
?? throw new InvalidOperationException($"Could not create wrapper type for {t}")));
var handler = (RequestHandlerWrapper<TResponse>)_requestHandlers.GetOrAdd(request.GetType(), static requestType =>
{
var wrapperType = typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestType, typeof(TResponse));
var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper type for {requestType}");
return (RequestHandlerBase)wrapper;
});

return handler.Handle(request, _serviceProvider, cancellationToken);
}
Expand All @@ -63,11 +64,12 @@ public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken
throw new ArgumentNullException(nameof(request));
}

var requestType = request.GetType();

var handler = (RequestHandlerWrapper)_requestHandlers.GetOrAdd(requestType,
static t => (RequestHandlerBase)(Activator.CreateInstance(typeof(RequestHandlerWrapperImpl<>).MakeGenericType(t))
?? throw new InvalidOperationException($"Could not create wrapper type for {t}")));
var handler = (RequestHandlerWrapper)_requestHandlers.GetOrAdd(request.GetType(), static requestType =>
{
var wrapperType = typeof(RequestHandlerWrapperImpl<>).MakeGenericType(requestType);
var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper type for {requestType}");
return (RequestHandlerBase)wrapper;
});

return handler.Handle(request, _serviceProvider, cancellationToken);
}
Expand All @@ -78,41 +80,31 @@ public Task Send<TRequest>(TRequest request, CancellationToken cancellationToken
{
throw new ArgumentNullException(nameof(request));
}
var requestType = request.GetType();
var handler = _requestHandlers.GetOrAdd(requestType,
static requestTypeKey =>
{
var requestInterfaceType = requestTypeKey
.GetInterfaces()
.FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>));

Type wrapperType;
var handler = _requestHandlers.GetOrAdd(request.GetType(), static requestType =>
{
Type wrapperType;

var requestInterfaceType = requestType.GetInterfaces().FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IRequest<>));
if (requestInterfaceType is null)
{
requestInterfaceType = requestType.GetInterfaces().FirstOrDefault(static i => i == typeof(IRequest));
if (requestInterfaceType is null)
{
requestInterfaceType = requestTypeKey
.GetInterfaces()
.FirstOrDefault(static i => i == typeof(IRequest));

if (requestInterfaceType is null)
{
throw new ArgumentException($"{requestTypeKey.Name} does not implement {nameof(IRequest)}",
nameof(request));
}

wrapperType =
typeof(RequestHandlerWrapperImpl<>).MakeGenericType(requestTypeKey);
}
else
{
var responseType = requestInterfaceType.GetGenericArguments()[0];
wrapperType =
typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType);
throw new ArgumentException($"{requestType.Name} does not implement {nameof(IRequest)}", nameof(request));
}

return (RequestHandlerBase)(Activator.CreateInstance(wrapperType)
?? throw new InvalidOperationException($"Could not create wrapper for type {wrapperType}"));
});
wrapperType = typeof(RequestHandlerWrapperImpl<>).MakeGenericType(requestType);
}
else
{
var responseType = requestInterfaceType.GetGenericArguments()[0];
wrapperType = typeof(RequestHandlerWrapperImpl<,>).MakeGenericType(requestType, responseType);
}

var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper for type {requestType}");
return (RequestHandlerBase)wrapper;
});

// call via dynamic dispatch to avoid calling through reflection for performance reasons
return handler.Handle(request, _serviceProvider, cancellationToken);
Expand Down Expand Up @@ -149,10 +141,12 @@ protected virtual Task PublishCore(IEnumerable<NotificationHandlerExecutor> hand

private Task PublishNotification(INotification notification, CancellationToken cancellationToken = default)
{
var notificationType = notification.GetType();
var handler = _notificationHandlers.GetOrAdd(notificationType,
static t => (NotificationHandlerWrapper) (Activator.CreateInstance(typeof(NotificationHandlerWrapperImpl<>).MakeGenericType(t))
?? throw new InvalidOperationException($"Could not create wrapper for type {t}")));
var handler = _notificationHandlers.GetOrAdd(notification.GetType(), static notificationType =>
{
var wrapperType = typeof(NotificationHandlerWrapperImpl<>).MakeGenericType(notificationType);
var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper for type {notificationType}");
return (NotificationHandlerWrapper)wrapper;
});

return handler.Handle(notification, _serviceProvider, PublishCore, cancellationToken);
}
Expand All @@ -165,10 +159,12 @@ public IAsyncEnumerable<TResponse> CreateStream<TResponse>(IStreamRequest<TRespo
throw new ArgumentNullException(nameof(request));
}

var requestType = request.GetType();

var streamHandler = (StreamRequestHandlerWrapper<TResponse>) _streamRequestHandlers.GetOrAdd(requestType,
t => (StreamRequestHandlerBase) Activator.CreateInstance(typeof(StreamRequestHandlerWrapperImpl<,>).MakeGenericType(requestType, typeof(TResponse))));
var streamHandler = (StreamRequestHandlerWrapper<TResponse>)_streamRequestHandlers.GetOrAdd(request.GetType(), static requestType =>
{
var wrapperType = typeof(StreamRequestHandlerWrapperImpl<,>).MakeGenericType(requestType, typeof(TResponse));
var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper for type {requestType}");
return (StreamRequestHandlerBase)wrapper;
});

var items = streamHandler.Handle(request, _serviceProvider, cancellationToken);

Expand All @@ -183,26 +179,20 @@ public IAsyncEnumerable<TResponse> CreateStream<TResponse>(IStreamRequest<TRespo
throw new ArgumentNullException(nameof(request));
}

var requestType = request.GetType();

var handler = _streamRequestHandlers.GetOrAdd(requestType,
requestTypeKey =>
var handler = _streamRequestHandlers.GetOrAdd(request.GetType(), static requestType =>
{
var requestInterfaceType = requestType.GetInterfaces().FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IStreamRequest<>));
if (requestInterfaceType is null)
{
var requestInterfaceType = requestTypeKey
.GetInterfaces()
.FirstOrDefault(static i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IStreamRequest<>));
var isValidRequest = requestInterfaceType != null;
throw new ArgumentException($"{requestType.Name} does not implement IStreamRequest<TResponse>", nameof(request));
}

if (!isValidRequest)
{
throw new ArgumentException($"{requestType.Name} does not implement IStreamRequest<TResponse>", nameof(requestTypeKey));
}
var responseType = requestInterfaceType.GetGenericArguments()[0];
var wrapperType = typeof(StreamRequestHandlerWrapperImpl<,>).MakeGenericType(requestType, responseType);
var wrapper = Activator.CreateInstance(wrapperType) ?? throw new InvalidOperationException($"Could not create wrapper for type {requestType}");
return (StreamRequestHandlerBase)wrapper;
});

var responseType = requestInterfaceType!.GetGenericArguments()[0];
return (StreamRequestHandlerBase) Activator.CreateInstance(typeof(StreamRequestHandlerWrapperImpl<,>).MakeGenericType(requestTypeKey, responseType));
});

// call via dynamic dispatch to avoid calling through reflection for performance reasons
var items = handler.Handle(request, _serviceProvider, cancellationToken);

return items;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
namespace MediatR.Pipeline;

using Internal;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -71,7 +72,7 @@ private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
var exceptionActionInterfaceType = typeof(IRequestExceptionAction<,>).MakeGenericType(typeof(TRequest), exceptionType);
var enumerableExceptionActionInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionActionInterfaceType);

var actionsForException = (IEnumerable<object>)_serviceProvider.GetService(enumerableExceptionActionInterfaceType);
var actionsForException = (IEnumerable<object>)_serviceProvider.GetRequiredService(enumerableExceptionActionInterfaceType);

return HandlersOrderer.Prioritize(actionsForException.ToList(), request)
.Select(action => (exceptionType, action));
Expand Down
3 changes: 2 additions & 1 deletion src/MediatR/Pipeline/RequestExceptionProcessorBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
namespace MediatR.Pipeline;

using Internal;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down Expand Up @@ -88,7 +89,7 @@ private static IEnumerable<Type> GetExceptionTypes(Type? exceptionType)
var exceptionHandlerInterfaceType = typeof(IRequestExceptionHandler<,,>).MakeGenericType(typeof(TRequest), typeof(TResponse), exceptionType);
var enumerableExceptionHandlerInterfaceType = typeof(IEnumerable<>).MakeGenericType(exceptionHandlerInterfaceType);

var exceptionHandlers = (IEnumerable<object>) _serviceProvider.GetService(enumerableExceptionHandlerInterfaceType);
var exceptionHandlers = (IEnumerable<object>) _serviceProvider.GetRequiredService(enumerableExceptionHandlerInterfaceType);

return HandlersOrderer.Prioritize(exceptionHandlers.ToList(), request)
.Select(handler => (exceptionType, action: handler));
Expand Down
10 changes: 5 additions & 5 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
}
}

private static bool IsMatchingWithInterface(Type handlerType, Type handlerInterface)
private static bool IsMatchingWithInterface(Type? handlerType, Type handlerInterface)
{
if (handlerType == null || handlerInterface == null)
{
Expand Down Expand Up @@ -186,15 +186,15 @@ var interfaceType in
yield return interfaceType;
}
}
else if (pluggedType.GetTypeInfo().BaseType.GetTypeInfo().IsGenericType &&
(pluggedType.GetTypeInfo().BaseType.GetGenericTypeDefinition() == templateType))
else if (pluggedType.GetTypeInfo().BaseType!.GetTypeInfo().IsGenericType &&
(pluggedType.GetTypeInfo().BaseType!.GetGenericTypeDefinition() == templateType))
{
yield return pluggedType.GetTypeInfo().BaseType;
yield return pluggedType.GetTypeInfo().BaseType!;
}

if (pluggedType.GetTypeInfo().BaseType == typeof(object)) yield break;

foreach (var interfaceType in FindInterfacesThatClosesCore(pluggedType.GetTypeInfo().BaseType, templateType))
foreach (var interfaceType in FindInterfacesThatClosesCore(pluggedType.GetTypeInfo().BaseType!, templateType))
{
yield return interfaceType;
}
Expand Down