diff --git a/doc/rpc_marshalable_objects.md b/doc/rpc_marshalable_objects.md index fe34c8a42..ad78651a2 100644 --- a/doc/rpc_marshalable_objects.md +++ b/doc/rpc_marshalable_objects.md @@ -7,12 +7,45 @@ StreamJsonRpc allows transmitting marshalable objects (i.e., objects implementin Marshalable interfaces must: -1. Extend `IDisposable`. +1. Extend `IDisposable` (unless interface is call-scoped). 1. Not include any properties. 1. Not include any events. The object that implements a marshalable interface may include properties and events as well as other additional members but only the methods defined by the marshalable interface will be available on the proxy, and the data will not be serialized. +The `RpcMarshalableAttribute` must be applied directly to the interface used as the return type, parameter type, or member type within a return type or parameter type's object graph. +The attribute is not inherited. +In fact different interfaces in a type hierarchy can have this attribute applied with distinct settings, and only the settings on the attribute applied directly to the interface used will apply. + +## Call-scoped vs. explicitly scoped + +### Explicit lifetime + +An RPC marshalable interface has an explicit lifetime by default. +This means that the receiver of a marshaled object owns its lifetime, which may extend beyond an individual RPC call. +Memory for the marshaled object and its proxy are not released until the receiver either disposes of its proxy or the JSON-RPC connection is closed. + +### Call-scoped lifetime + +A call-scoped interface produces a proxy that is valid only during the RPC call that delivered it. +It may only be used as part of a method request as or within an argument. +Using it as or within a return value or exception will result in an error. + +This is the preferred model when an interface is expected to only be used within request arguments because it mitigates the risk of a memory leak due to the receiver failing to dispose of the proxy. +This model also allows the sender to retain control over the lifetime of the marshaled object. + +Special allowance is made for `IAsyncEnumerable`-returning RPC methods so that the lifetime of the marshaled object is extended to the lifetime of the enumeration. +An `IAsyncEnumerable` in an exception thrown from the method will *not* have access to the call-scoped marshaled object because exceptions thrown by the server always cause termination of objects marshaled by the request. + +Opt into call-scoped lifetimes by setting the `CallScopedLifetime` property to `true` on the attribute applied to the interface: + +```css +[RpcMarshalable(CallScopedLifetime = true)] +``` + +It is not possible to customize the lifetime of an RPC marshaled object except on its own interface. +For example, applying this attribute to the parameter that uses the interface is not allowed. + ## Use cases In all cases, the special handling of a marshalable object only occurs if the container of that value is typed as the corresponding marshalable interface. @@ -104,6 +137,8 @@ class RpcServer : IRpcServer } ``` +Call-scoped marshalable interfaces may not be used as a return type or member of its object graph. + ### Method argument In this use case the RPC *client* provides the marshalable object to the server: @@ -119,6 +154,8 @@ var counter = new Counter(); await client.ProvideCounterAsync(counter); ``` +Call-scoped marshalable interfaces may only appear as a method parameter or a part of its object graph. + ### Value within a single argument's object graph In this use case the RPC client again provides the marshalable object to the server, @@ -144,6 +181,7 @@ await client.ProvideClassAsync(arg); ``` ⚠️ While this use case is supported, be very wary of this pattern because it becomes less obvious to the receiver that an `IDisposable` value is tucked into the object tree of an argument somewhere that *must* be disposed to avoid a resource leak. +This risk can be mitigated by using call-scoped marshalable interfaces. ### As an argument without a proxy for an RPC interface diff --git a/test/StreamJsonRpc.Tests/DisposableAction.cs b/src/StreamJsonRpc/DisposableAction.cs similarity index 52% rename from test/StreamJsonRpc.Tests/DisposableAction.cs rename to src/StreamJsonRpc/DisposableAction.cs index 436e77b60..78a909605 100644 --- a/test/StreamJsonRpc.Tests/DisposableAction.cs +++ b/src/StreamJsonRpc/DisposableAction.cs @@ -1,22 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using Microsoft; +namespace StreamJsonRpc; internal class DisposableAction : IDisposableObservable { - private readonly Action? disposeAction; + private static readonly Action EmptyAction = () => { }; + private Action? disposeAction; internal DisposableAction(Action? disposeAction) { - this.disposeAction = disposeAction; + this.disposeAction = disposeAction ?? EmptyAction; } - public bool IsDisposed { get; private set; } + public bool IsDisposed => this.disposeAction is null; public void Dispose() { - this.IsDisposed = true; - this.disposeAction?.Invoke(); + Interlocked.Exchange(ref this.disposeAction, null)?.Invoke(); } } diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index 6791ba0b4..7dd4c6479 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -89,9 +89,9 @@ JsonRpc IJsonRpcInstanceContainer.Rpc this.rpc = value; this.formatterProgressTracker = new MessageFormatterProgressTracker(value, this); - this.enumerableTracker = new MessageFormatterEnumerableTracker(value, this); - this.duplexPipeTracker = new MessageFormatterDuplexPipeTracker(value, this) { MultiplexingStream = this.MultiplexingStream }; this.rpcMarshaledContextTracker = new MessageFormatterRpcMarshaledContextTracker(value, this); + this.enumerableTracker = new MessageFormatterEnumerableTracker(value, this, this.rpcMarshaledContextTracker); + this.duplexPipeTracker = new MessageFormatterDuplexPipeTracker(value, this) { MultiplexingStream = this.MultiplexingStream }; } } } diff --git a/src/StreamJsonRpc/JsonMessageFormatter.cs b/src/StreamJsonRpc/JsonMessageFormatter.cs index d51b2409a..f7ec01d2a 100644 --- a/src/StreamJsonRpc/JsonMessageFormatter.cs +++ b/src/StreamJsonRpc/JsonMessageFormatter.cs @@ -351,7 +351,7 @@ public JToken Serialize(JsonRpcMessage message) Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.JsonSerializer); /// - Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this.JsonSerializer); + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this); /// protected override void Dispose(bool disposing) @@ -570,9 +570,9 @@ private JTokenWriter CreateJTokenWriter() private bool TryGetMarshaledJsonConverter(Type type, [NotNullWhen(true)] out RpcMarshalableConverter? converter) { - if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(type, out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions)) + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(type, out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions, out RpcMarshalableAttribute? rpcMarshalableAttribute)) { - converter = new RpcMarshalableConverter(type, this, proxyOptions, targetOptions); + converter = new RpcMarshalableConverter(type, this, proxyOptions, targetOptions, rpcMarshalableAttribute); return true; } @@ -616,7 +616,7 @@ private JsonRpcResult ReadResult(JToken json) RequestId id = this.ExtractRequestId(json); JToken? result = json["result"]; - return new JsonRpcResult(this.JsonSerializer) + return new JsonRpcResult(this) { RequestId = id, Result = result, @@ -837,15 +837,27 @@ public override bool TryGetArgumentByNameOrIndex(string? name, int position, Typ [DataContract] private class JsonRpcResult : JsonRpcResultBase { - private readonly JsonSerializer jsonSerializer; + private readonly JsonMessageFormatter formatter; + private bool resultDeserialized; + private JsonSerializationException? resultDeserializationException; - internal JsonRpcResult(JsonSerializer jsonSerializer) + internal JsonRpcResult(JsonMessageFormatter formatter) { - this.jsonSerializer = jsonSerializer ?? throw new ArgumentNullException(nameof(jsonSerializer)); + this.formatter = formatter; } public override T GetResult() { + if (this.resultDeserializationException is not null) + { + ExceptionDispatchInfo.Capture(this.resultDeserializationException).Throw(); + } + + if (this.resultDeserialized) + { + return (T)this.Result!; + } + Verify.Operation(this.Result is not null, "This instance hasn't been initialized with a result yet."); var result = (JToken)this.Result; if (result.Type == JTokenType.Null) @@ -856,7 +868,10 @@ public override T GetResult() try { - return result.ToObject(this.jsonSerializer)!; + using (this.formatter.TrackDeserialization(this)) + { + return result.ToObject(this.formatter.JsonSerializer)!; + } } catch (Exception exception) { @@ -864,7 +879,34 @@ public override T GetResult() } } - protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.jsonSerializer); + protected internal override void SetExpectedResultType(Type resultType) + { + Verify.Operation(this.Result is not null, "This instance hasn't been initialized with a result yet."); + Verify.Operation(!this.resultDeserialized, "Result is no longer available or has already been deserialized."); + + var result = (JToken)this.Result; + if (result.Type == JTokenType.Null) + { + Verify.Operation(!resultType.GetTypeInfo().IsValueType || Nullable.GetUnderlyingType(resultType) is not null, "null result is not assignable to a value type."); + return; + } + + try + { + using (this.formatter.TrackDeserialization(this)) + { + this.Result = result.ToObject(resultType, this.formatter.JsonSerializer)!; + this.resultDeserialized = true; + } + } + catch (Exception exception) + { + // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. + this.resultDeserializationException = new JsonSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.FailureDeserializingRpcResult, resultType.Name, exception.GetType().Name, exception.Message), exception); + } + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter.JsonSerializer); } private class JsonRpcError : JsonRpcErrorBase @@ -1207,29 +1249,16 @@ public override void WriteJson(JsonWriter writer, Stream? value, JsonSerializer } [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] - private class RpcMarshalableConverter : JsonConverter + private class RpcMarshalableConverter(Type interfaceType, JsonMessageFormatter jsonMessageFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : JsonConverter { - private readonly Type interfaceType; - private readonly JsonMessageFormatter jsonMessageFormatter; - private readonly JsonRpcProxyOptions proxyOptions; - private readonly JsonRpcTargetOptions targetOptions; - - public RpcMarshalableConverter(Type interfaceType, JsonMessageFormatter jsonMessageFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions) - { - this.interfaceType = interfaceType; - this.jsonMessageFormatter = jsonMessageFormatter; - this.proxyOptions = proxyOptions; - this.targetOptions = targetOptions; - } - - private string DebuggerDisplay => $"Converter for marshalable objects of type {this.interfaceType.FullName}"; + private string DebuggerDisplay => $"Converter for marshalable objects of type {interfaceType.FullName}"; - public override bool CanConvert(Type objectType) => objectType == this.interfaceType; + public override bool CanConvert(Type objectType) => objectType == interfaceType; public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { var token = (MessageFormatterRpcMarshaledContextTracker.MarshalToken?)JToken.Load(reader).ToObject(typeof(MessageFormatterRpcMarshaledContextTracker.MarshalToken), serializer); - return this.jsonMessageFormatter.RpcMarshaledContextTracker.GetObject(objectType, token, this.proxyOptions); + return jsonMessageFormatter.RpcMarshaledContextTracker.GetObject(objectType, token, proxyOptions); } public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) @@ -1238,13 +1267,13 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer { writer.WriteNull(); } - else if (!this.interfaceType.IsAssignableFrom(value.GetType())) + else if (!interfaceType.IsAssignableFrom(value.GetType())) { - throw new InvalidOperationException($"Type {value.GetType().FullName} doesn't implement {this.interfaceType.FullName}"); + throw new InvalidOperationException($"Type {value.GetType().FullName} doesn't implement {interfaceType.FullName}"); } else { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = this.jsonMessageFormatter.RpcMarshaledContextTracker.GetToken(value, this.targetOptions, this.interfaceType); + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = jsonMessageFormatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, interfaceType, rpcMarshalableAttribute); serializer.Serialize(writer, token); } } diff --git a/src/StreamJsonRpc/JsonRpc.cs b/src/StreamJsonRpc/JsonRpc.cs index 15512bbb3..fee321b53 100644 --- a/src/StreamJsonRpc/JsonRpc.cs +++ b/src/StreamJsonRpc/JsonRpc.cs @@ -2569,56 +2569,66 @@ private async Task HandleRpcAsync(JsonRpcMessage rpc) } else if (rpc is IJsonRpcMessageWithId resultOrError) { - this.OnResponseReceived(rpc); - JsonRpcResult? result = resultOrError as JsonRpcResult; - JsonRpcError? error = resultOrError as JsonRpcError; - - lock (this.dispatcherMapLock) + try { - if (this.resultDispatcherMap.TryGetValue(resultOrError.RequestId, out data)) - { - this.resultDispatcherMap.Remove(resultOrError.RequestId); - } - } + JsonRpcResult? result = resultOrError as JsonRpcResult; + JsonRpcError? error = resultOrError as JsonRpcError; - if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information)) - { - if (result is not null) - { - this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEvents.ReceivedResult, "Received result for request \"{0}\".", result.RequestId); - } - else if (error?.Error is object) + lock (this.dispatcherMapLock) { - this.TraceSource.TraceEvent(TraceEventType.Warning, (int)TraceEvents.ReceivedError, "Received error response for request {0}: {1} \"{2}\": ", error.RequestId, error.Error.Code, error.Error.Message); + if (this.resultDispatcherMap.TryGetValue(resultOrError.RequestId, out data)) + { + this.resultDispatcherMap.Remove(resultOrError.RequestId); + } } - } - if (data is object) - { - if (data.ExpectedResultType is not null && rpc is JsonRpcResult resultMessage) + if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information)) { - resultMessage.SetExpectedResultType(data.ExpectedResultType); + if (result is not null) + { + this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEvents.ReceivedResult, "Received result for request \"{0}\".", result.RequestId); + } + else if (error?.Error is object) + { + this.TraceSource.TraceEvent(TraceEventType.Warning, (int)TraceEvents.ReceivedError, "Received error response for request {0}: {1} \"{2}\": ", error.RequestId, error.Error.Code, error.Error.Message); + } } - else if (rpc is JsonRpcError errorMessage && errorMessage.Error is not null) + + if (data is object) { - Type? errorType = this.GetErrorDetailsDataType(errorMessage); - if (errorType is not null) + if (data.ExpectedResultType is not null && rpc is JsonRpcResult resultMessage) { - errorMessage.Error.SetExpectedDataType(errorType); + resultMessage.SetExpectedResultType(data.ExpectedResultType); + } + else if (rpc is JsonRpcError errorMessage && errorMessage.Error is not null) + { + Type? errorType = this.GetErrorDetailsDataType(errorMessage); + if (errorType is not null) + { + errorMessage.Error.SetExpectedDataType(errorType); + } } + + this.OnResponseReceived(rpc); + + // Complete the caller's request with the response asynchronously so it doesn't delay handling of other JsonRpc messages. + await TaskScheduler.Default.SwitchTo(alwaysYield: true); + data.CompletionHandler(rpc); + data = null; // avoid invoking again if we throw later } + else + { + this.OnResponseReceived(rpc); - // Complete the caller's request with the response asynchronously so it doesn't delay handling of other JsonRpc messages. - await TaskScheduler.Default.SwitchTo(alwaysYield: true); - data.CompletionHandler(rpc); - data = null; // avoid invoking again if we throw later + // Unexpected "response" to no request we have a record of. Raise disconnected event. + this.OnJsonRpcDisconnected(new JsonRpcDisconnectedEventArgs( + Resources.UnexpectedResponseWithNoMatchingRequest, + DisconnectedReason.RemoteProtocolViolation)); + } } - else + catch { - // Unexpected "response" to no request we have a record of. Raise disconnected event. - this.OnJsonRpcDisconnected(new JsonRpcDisconnectedEventArgs( - Resources.UnexpectedResponseWithNoMatchingRequest, - DisconnectedReason.RemoteProtocolViolation)); + this.OnResponseReceived(rpc); } } else diff --git a/src/StreamJsonRpc/MessagePackFormatter.cs b/src/StreamJsonRpc/MessagePackFormatter.cs index 8b43c0ebd..fbf5f93bf 100644 --- a/src/StreamJsonRpc/MessagePackFormatter.cs +++ b/src/StreamJsonRpc/MessagePackFormatter.cs @@ -233,7 +233,7 @@ public void Serialize(IBufferWriter contentBuffer, JsonRpcMessage message) Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this.userDataSerializationOptions); /// - Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this.messageSerializationOptions); + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this, this.messageSerializationOptions); void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) { @@ -1271,13 +1271,14 @@ internal RpcMarshalableResolver(MessagePackFormatter formatter) } } - if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeof(T), out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions)) + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeof(T), out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions, out RpcMarshalableAttribute? attribute)) { object formatter = Activator.CreateInstance( typeof(RpcMarshalableFormatter<>).MakeGenericType(typeof(T)), this.formatter, proxyOptions, - targetOptions)!; + targetOptions, + attribute)!; lock (this.formatters) { @@ -1295,25 +1296,14 @@ internal RpcMarshalableResolver(MessagePackFormatter formatter) } #pragma warning disable CA1812 - private class RpcMarshalableFormatter : IMessagePackFormatter + private class RpcMarshalableFormatter(MessagePackFormatter messagePackFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : IMessagePackFormatter where T : class #pragma warning restore CA1812 { - private MessagePackFormatter messagePackFormatter; - private JsonRpcProxyOptions proxyOptions; - private JsonRpcTargetOptions targetOptions; - - public RpcMarshalableFormatter(MessagePackFormatter messagePackFormatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions) - { - this.messagePackFormatter = messagePackFormatter; - this.proxyOptions = proxyOptions; - this.targetOptions = targetOptions; - } - public T? Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) { MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = MessagePackSerializer.Deserialize(ref reader, options); - return token.HasValue ? (T?)this.messagePackFormatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, this.proxyOptions) : null; + return token.HasValue ? (T?)messagePackFormatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) : null; } public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSerializerOptions options) @@ -1324,7 +1314,7 @@ public void Serialize(ref MessagePackWriter writer, T? value, MessagePackSeriali } else { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = this.messagePackFormatter.RpcMarshaledContextTracker.GetToken(value, this.targetOptions, typeof(T)); + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = messagePackFormatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); MessagePackSerializer.Serialize(ref writer, token, options); } } @@ -1802,7 +1792,7 @@ internal JsonRpcResultFormatter(MessagePackFormatter formatter) public Protocol.JsonRpcResult Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) { - var result = new JsonRpcResult(this.formatter.userDataSerializationOptions) + var result = new JsonRpcResult(this.formatter, this.formatter.userDataSerializationOptions) { OriginalMessagePack = reader.Sequence, }; @@ -2276,11 +2266,13 @@ protected override void ReleaseBuffers() private class JsonRpcResult : JsonRpcResultBase, IJsonRpcMessagePackRetention { private readonly MessagePackSerializerOptions serializerOptions; + private readonly MessagePackFormatter formatter; private Exception? resultDeserializationException; - internal JsonRpcResult(MessagePackSerializerOptions serializerOptions) + internal JsonRpcResult(MessagePackFormatter formatter, MessagePackSerializerOptions serializerOptions) { + this.formatter = formatter; this.serializerOptions = serializerOptions; } @@ -2307,7 +2299,11 @@ protected internal override void SetExpectedResultType(Type resultType) var reader = new MessagePackReader(this.MsgPackResult); try { - this.Result = MessagePackSerializer.Deserialize(resultType, ref reader, this.serializerOptions); + using (this.formatter.TrackDeserialization(this)) + { + this.Result = MessagePackSerializer.Deserialize(resultType, ref reader, this.serializerOptions); + } + this.MsgPackResult = default; } catch (MessagePackSerializationException ex) diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs index aa62ec81d..ad5ba89c6 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs @@ -55,23 +55,30 @@ public class MessageFormatterEnumerableTracker private readonly JsonRpc jsonRpc; private readonly IJsonRpcFormatterState formatterState; - + private readonly MessageFormatterRpcMarshaledContextTracker? rpcTracker; private readonly object syncObject = new object(); private long nextToken; + /// + public MessageFormatterEnumerableTracker(JsonRpc jsonRpc, IJsonRpcFormatterState formatterState) + : this(jsonRpc, formatterState, null) + { + } + /// /// Initializes a new instance of the class. /// /// The instance that may be used to send or receive RPC messages related to . /// The formatter that owns this tracker. - public MessageFormatterEnumerableTracker(JsonRpc jsonRpc, IJsonRpcFormatterState formatterState) + /// The RPC marshalable object support used by the formatter, if applicable. + internal MessageFormatterEnumerableTracker(JsonRpc jsonRpc, IJsonRpcFormatterState formatterState, MessageFormatterRpcMarshaledContextTracker? rpcTracker) { Requires.NotNull(jsonRpc, nameof(jsonRpc)); Requires.NotNull(formatterState, nameof(formatterState)); this.jsonRpc = jsonRpc; this.formatterState = formatterState; - + this.rpcTracker = rpcTracker; jsonRpc.AddLocalRpcMethod(NextMethodName, OnNextAsyncMethodInfo, this); jsonRpc.AddLocalRpcMethod(DisposeMethodName, OnDisposeAsyncMethodInfo, this); this.formatterState = formatterState; @@ -156,7 +163,13 @@ public long GetToken(IAsyncEnumerable enumerable) public IAsyncEnumerable CreateEnumerableProxy(object? handle, IReadOnlyList? prefetchedItems) #pragma warning restore VSTHRD200 // Use "Async" suffix in names of methods that return an awaitable type. { - return new AsyncEnumerableProxy(this.jsonRpc, handle, prefetchedItems); + IDisposable? requestResourcesDeferral = null; + if (handle is not null && this.rpcTracker is not null && !this.formatterState.DeserializingMessageWithId.IsEmpty) + { + requestResourcesDeferral = this.rpcTracker.OutboundCleanupDeferral(this.formatterState.DeserializingMessageWithId); + } + + return new AsyncEnumerableProxy(this.jsonRpc, handle, prefetchedItems, requestResourcesDeferral); } private ValueTask OnNextAsync(long token, CancellationToken cancellationToken) @@ -344,15 +357,17 @@ private class AsyncEnumerableProxy : IAsyncEnumerable { private readonly JsonRpc jsonRpc; private readonly bool finished; + private readonly IDisposable? requestResourcesDeferral; private object? handle; private bool enumeratorAcquired; private IReadOnlyList? prefetchedItems; - internal AsyncEnumerableProxy(JsonRpc jsonRpc, object? handle, IReadOnlyList? prefetchedItems) + internal AsyncEnumerableProxy(JsonRpc jsonRpc, object? handle, IReadOnlyList? prefetchedItems, IDisposable? requestResourcesDeferral) { this.jsonRpc = jsonRpc; this.handle = handle; this.prefetchedItems = prefetchedItems; + this.requestResourcesDeferral = requestResourcesDeferral; this.finished = handle is null; } @@ -360,7 +375,7 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToke { Verify.Operation(!this.enumeratorAcquired, Resources.CannotBeCalledAfterGetAsyncEnumerator); this.enumeratorAcquired = true; - var result = new AsyncEnumeratorProxy(this, this.handle, this.prefetchedItems, this.finished, cancellationToken); + var result = new AsyncEnumeratorProxy(this, this.handle, this.prefetchedItems, this.finished, this.requestResourcesDeferral, cancellationToken); this.prefetchedItems = null; return result; } @@ -373,6 +388,7 @@ private class AsyncEnumeratorProxy : IAsyncEnumerator private readonly AsyncEnumerableProxy owner; private readonly CancellationToken cancellationToken; private readonly object[]? nextOrDisposeArguments; + private readonly IDisposable? requestResourcesDeferral; /// /// A sequence of values that have already been received from the generator but not yet consumed. @@ -388,7 +404,7 @@ private class AsyncEnumeratorProxy : IAsyncEnumerator private bool disposed; - internal AsyncEnumeratorProxy(AsyncEnumerableProxy owner, object? handle, IReadOnlyList? prefetchedItems, bool finished, CancellationToken cancellationToken) + internal AsyncEnumeratorProxy(AsyncEnumerableProxy owner, object? handle, IReadOnlyList? prefetchedItems, bool finished, IDisposable? requestResourcesDeferral, CancellationToken cancellationToken) { this.owner = owner; this.nextOrDisposeArguments = handle is not null ? new object[] { handle } : null; @@ -400,6 +416,7 @@ internal AsyncEnumeratorProxy(AsyncEnumerableProxy owner, object? handle, IRe } this.generatorReportsFinished = finished; + this.requestResourcesDeferral = requestResourcesDeferral; } public T Current @@ -430,6 +447,9 @@ public async ValueTask DisposeAsync() { await this.owner.jsonRpc.NotifyAsync(DisposeMethodName, this.nextOrDisposeArguments).ConfigureAwait(false); } + + // Clean up any local resources that were held open for the remote source of the enumeration. + this.requestResourcesDeferral?.Dispose(); } } diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index d087a4909..0688c2647 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -18,9 +18,9 @@ namespace StreamJsonRpc.Reflection; /// internal class MessageFormatterRpcMarshaledContextTracker { - private static readonly IReadOnlyCollection<(Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions)> ImplicitlyMarshaledTypes = new (Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions)[] + private static readonly IReadOnlyCollection<(Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions, RpcMarshalableAttribute Attribute)> ImplicitlyMarshaledTypes = new (Type, JsonRpcProxyOptions, JsonRpcTargetOptions, RpcMarshalableAttribute)[] { - (typeof(IDisposable), new JsonRpcProxyOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }, new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }), + (typeof(IDisposable), new JsonRpcProxyOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }, new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }, new RpcMarshalableAttribute()), // IObserver support requires special recognition of OnCompleted and OnError be considered terminating calls. ( @@ -45,10 +45,11 @@ internal class MessageFormatterRpcMarshaledContextTracker }; }, }, - new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }), + new JsonRpcTargetOptions { MethodNameTransform = CommonMethodNameTransforms.CamelCase }, + new RpcMarshalableAttribute()), }; - private static readonly ConcurrentDictionary MarshaledTypes = new ConcurrentDictionary(); + private static readonly ConcurrentDictionary MarshaledTypes = new(); private static readonly (JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions) RpcMarshalableInterfaceDefaultOptions = (new JsonRpcProxyOptions(), new JsonRpcTargetOptions { NotifyClientOfEvents = false, DisposeOnDisconnect = true }); private static readonly MethodInfo ReleaseMarshaledObjectMethodInfo = typeof(MessageFormatterRpcMarshaledContextTracker).GetMethod(nameof(ReleaseMarshaledObject), BindingFlags.NonPublic | BindingFlags.Instance)!; private static readonly ConcurrentDictionary MarshalableOptionalInterfaces = new ConcurrentDictionary(); @@ -66,7 +67,7 @@ internal class MessageFormatterRpcMarshaledContextTracker /// and the request ends up not being transmitted for any reason. /// It will only contain the data until the request is either aborted or a response is received. /// - private ImmutableDictionary> outboundRequestIdMarshalMap = ImmutableDictionary>.Empty; + private ImmutableDictionary> outboundRequestIdMarshalMap = ImmutableDictionary>.Empty; internal MessageFormatterRpcMarshaledContextTracker(JsonRpc jsonRpc, IJsonRpcFormatterState formatterState) { @@ -90,23 +91,25 @@ private enum MarshalMode MarshallingRealObject = 1, } - internal static bool TryGetMarshalOptionsForType(Type type, [NotNullWhen(true)] out JsonRpcProxyOptions? proxyOptions, [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions) + internal static bool TryGetMarshalOptionsForType(Type type, [NotNullWhen(true)] out JsonRpcProxyOptions? proxyOptions, [NotNullWhen(true)] out JsonRpcTargetOptions? targetOptions, [NotNullWhen(true)] out RpcMarshalableAttribute? rpcMarshalableAttribute) { proxyOptions = null; targetOptions = null; + rpcMarshalableAttribute = null; if (type.IsInterface is false) { return false; } - if (MarshaledTypes.TryGetValue(type, out (JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions) options)) + if (MarshaledTypes.TryGetValue(type, out (JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions, RpcMarshalableAttribute Attribute) options)) { proxyOptions = options.ProxyOptions; targetOptions = options.TargetOptions; + rpcMarshalableAttribute = options.Attribute; return true; } - foreach ((Type implicitlyMarshaledType, JsonRpcProxyOptions typeProxyOptions, JsonRpcTargetOptions typeTargetOptions) in ImplicitlyMarshaledTypes) + foreach ((Type implicitlyMarshaledType, JsonRpcProxyOptions typeProxyOptions, JsonRpcTargetOptions typeTargetOptions, RpcMarshalableAttribute attribute) in ImplicitlyMarshaledTypes) { if (implicitlyMarshaledType == type || (implicitlyMarshaledType.IsGenericTypeDefinition && @@ -115,18 +118,20 @@ internal static bool TryGetMarshalOptionsForType(Type type, [NotNullWhen(true)] { proxyOptions = typeProxyOptions; targetOptions = typeTargetOptions; - MarshaledTypes.TryAdd(type, (proxyOptions, targetOptions)); + rpcMarshalableAttribute = attribute; + MarshaledTypes.TryAdd(type, (proxyOptions, targetOptions, rpcMarshalableAttribute)); return true; } } - if (type.GetCustomAttribute() is not null) + if (type.GetCustomAttribute() is RpcMarshalableAttribute marshalableAttribute) { - ValidateMarshalableInterface(type); + ValidateMarshalableInterface(type, marshalableAttribute); proxyOptions = RpcMarshalableInterfaceDefaultOptions.ProxyOptions; targetOptions = RpcMarshalableInterfaceDefaultOptions.TargetOptions; - MarshaledTypes.TryAdd(type, (proxyOptions, targetOptions)); + rpcMarshalableAttribute = marshalableAttribute; + MarshaledTypes.TryAdd(type, (proxyOptions, targetOptions, rpcMarshalableAttribute)); return true; } @@ -138,6 +143,7 @@ internal static bool TryGetMarshalOptionsForType(Type type, [NotNullWhen(true)] /// . /// /// The type to get attributes from. + /// The attribute that appears on the declared type. /// The list of applied to /// . /// If an invalid set of @@ -147,7 +153,7 @@ internal static bool TryGetMarshalOptionsForType(Type type, [NotNullWhen(true)] /// values are duplicated, or if an /// optional interface is not marked with or it is not a valid marshalable /// interface. - internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptionalInterfaces(Type declaredType) + internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptionalInterfaces(Type declaredType, RpcMarshalableAttribute rpcMarshalableAttribute) { return MarshalableOptionalInterfaces.GetOrAdd(declaredType, declaredType => { @@ -169,7 +175,9 @@ internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptiona throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.RpcMarshalableOptionalInterfaceMustBeMarshalable, attribute.OptionalInterface.FullName)); } - ValidateMarshalableInterface(attribute.OptionalInterface); + // We pass in the declared interface's own attribute rather than the attribute that appears on the optional interface. + // Only one attribute can control the policy for this marshaled object. + ValidateMarshalableInterface(attribute.OptionalInterface, rpcMarshalableAttribute); } return attributes; @@ -182,14 +190,20 @@ internal static RpcMarshalableOptionalInterfaceAttribute[] GetMarshalableOptiona /// The object to be exposed over RPC. /// /// The marshalable interface type of as declared in the RPC contract. + /// The attribute that defines certain options that control which marshaling rules will be followed. /// A token to be serialized so the remote party can invoke methods on the marshaled object. - internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions options, Type declaredType) + internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions options, Type declaredType, RpcMarshalableAttribute rpcMarshalableAttribute) { if (this.formatterState.SerializingMessageWithId.IsEmpty) { throw new NotSupportedException(Resources.MarshaledObjectInNotificationError); } + if (rpcMarshalableAttribute.CallScopedLifetime && !this.formatterState.SerializingRequest) + { + throw new NotSupportedException(Resources.CallScopedMarshaledObjectInReturnValueNotAllowed); + } + long handle = this.nextUniqueHandle++; IRpcMarshaledContext context = JsonRpc.MarshalWithControlledLifetime(declaredType, marshaledObject, options); @@ -201,13 +215,14 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti { NotifyClientOfEvents = false, // We don't support this yet. MethodNameTransform = mn => Invariant($"$/invokeProxy/{handle}/{context.JsonRpcTargetOptions.MethodNameTransform?.Invoke(mn) ?? mn}"), + DisposeOnDisconnect = !rpcMarshalableAttribute.CallScopedLifetime, }, requestRevertOption: true); Assumes.NotNull(revert); Type objectType = marshaledObject.GetType(); List? optionalInterfacesCodes = null; - foreach (RpcMarshalableOptionalInterfaceAttribute attribute in GetMarshalableOptionalInterfaces(declaredType)) + foreach (RpcMarshalableOptionalInterfaceAttribute attribute in GetMarshalableOptionalInterfaces(declaredType, rpcMarshalableAttribute)) { if (attribute.OptionalInterface.IsAssignableFrom(objectType)) { @@ -236,11 +251,12 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti ImmutableInterlocked.AddOrUpdate( ref this.outboundRequestIdMarshalMap, this.formatterState.SerializingMessageWithId, - ImmutableList.Create(handle), - (key, value) => value.Add(handle)); + ImmutableList.Create((handle, rpcMarshalableAttribute.CallScopedLifetime)), + (key, value) => value.Add((handle, rpcMarshalableAttribute.CallScopedLifetime))); } - return new MarshalToken((int)MarshalMode.MarshallingRealObject, handle, lifetime: null, optionalInterfacesCodes?.ToArray()); + string? lifetime = rpcMarshalableAttribute.CallScopedLifetime ? MarshalLifetime.Call : null; + return new MarshalToken((int)MarshalMode.MarshallingRealObject, handle, lifetime, optionalInterfacesCodes?.ToArray()); } /// @@ -263,18 +279,17 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti throw new NotSupportedException("Receiving marshaled objects back to the owner is not yet supported."); } - if (token.Value.Lifetime == MarshalLifetime.Call) + RpcMarshalableAttribute synthesizedAttribute = new() { - throw new NotSupportedException("Receiving marshaled objects scoped to the lifetime of a single RPC request is not yet supported."); - } - + CallScopedLifetime = token.Value.Lifetime == MarshalLifetime.Call, + }; List<(TypeInfo Type, int Code)>? optionalInterfaces = null; if (token.Value.OptionalInterfacesCodes?.Length > 0) { // We ignore unknown optional interface codes foreach (int optionalInterfacesCode in token.Value.OptionalInterfacesCodes.Distinct()) { - foreach (RpcMarshalableOptionalInterfaceAttribute attribute in GetMarshalableOptionalInterfaces(interfaceType)) + foreach (RpcMarshalableOptionalInterfaceAttribute attribute in GetMarshalableOptionalInterfaces(interfaceType, synthesizedAttribute)) { if (attribute.OptionalInterfaceCode == optionalInterfacesCode) { @@ -293,7 +308,7 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti new JsonRpcProxyOptions(options) { MethodNameTransform = mn => Invariant($"$/invokeProxy/{token.Value.Handle}/{options.MethodNameTransform(mn)}"), - OnDispose = delegate + OnDispose = token.Value.Lifetime == MarshalLifetime.Call ? null : delegate { // Only forward the Dispose call if the marshaled interface derives from IDisposable. if (typeof(IDisposable).IsAssignableFrom(interfaceType)) @@ -312,17 +327,49 @@ internal MarshalToken GetToken(object marshaledObject, JsonRpcTargetOptions opti return result; } + /// + /// Called near the conclusion of a successful outbound request (i.e. when processing the received response) + /// to extend the lifetime of call-scoped marshaled objects. + /// + /// The ID of the request to extend. + /// A value that may be disposed of to finally release the resources bound up with the request. + /// + /// This is useful to keep call-scoped arguments alive while the server's result + /// is still active, suggesting the server may still need access to the arguments passed to it. + /// + internal IDisposable? OutboundCleanupDeferral(RequestId requestId) + { + // Remove the handles from the map so that they don't get cleaned up when the request is completed. + if (ImmutableInterlocked.TryRemove(ref this.outboundRequestIdMarshalMap, requestId, out ImmutableList<(long Handle, bool CallScoped)>? handles)) + { + return new DisposableAction(delegate + { + // Add the handles back to the map so that they get cleaned up in the normal way, which we then invoke immediately, + // since the time to clean them up normally has presumably already passed. + Assumes.True(ImmutableInterlocked.TryAdd(ref this.outboundRequestIdMarshalMap, requestId, handles)); + this.CleanUpOutboundResources(requestId, successful: true); + }); + } + else + { + // Nothing to defer. + return null; + } + } + /// /// Throws if is not a valid marshalable interface. /// This method doesn't validate that has the /// attribute. /// /// The interface to validate. + /// The attribute that appears on the interface. /// When is not a valid marshalable interface: this /// can happen if has properties, events or it is not disposable. - private static void ValidateMarshalableInterface(Type type) + private static void ValidateMarshalableInterface(Type type, RpcMarshalableAttribute attribute) { - if (typeof(IDisposable).IsAssignableFrom(type) is false) + // We only require marshalable interfaces to derive from IDisposable when they are not call-scoped. + if (!attribute.CallScopedLifetime && !typeof(IDisposable).IsAssignableFrom(type)) { throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, Resources.MarshalableInterfaceNotDisposable, type.FullName)); } @@ -341,7 +388,7 @@ private static void ValidateMarshalableInterface(Type type) /// /// Releases memory associated with marshaled objects. /// - /// The handle to the object as created by the method. + /// The handle to the object as created by the method. /// if the was created by (and thus the original object owned by) the remote party; if the token and object was created locally. private void ReleaseMarshaledObject(long handle, bool ownedBySender) { @@ -360,13 +407,14 @@ private void ReleaseMarshaledObject(long handle, bool ownedBySender) private void CleanUpOutboundResources(RequestId requestId, bool successful) { - if (ImmutableInterlocked.TryRemove(ref this.outboundRequestIdMarshalMap, requestId, out ImmutableList? handles)) + if (ImmutableInterlocked.TryRemove(ref this.outboundRequestIdMarshalMap, requestId, out ImmutableList<(long Handle, bool CallScoped)>? handles)) { - // Only kill the marshaled objects if the server threw an error. - // Successful responses make it the responsibility of the client/server to terminate the marshaled connection. - if (!successful) + foreach ((long handle, bool callScoped) in handles) { - foreach (long handle in handles) + // For explicit lifetime objects, we only kill the marshaled objects if the server threw an error. + // Successful responses make it the responsibility of the client/server to terminate the marshaled connection. + // But for call-scoped objects, we always release them when the outbound request is complete, by error or result. + if (callScoped || !successful) { // We use "ownedBySender: false" because the method we're calling is accustomed to the perspective being the "other" party. this.ReleaseMarshaledObject(handle, ownedBySender: false); diff --git a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs index 68b1e7879..07e5f29d6 100644 --- a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs +++ b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs @@ -12,4 +12,16 @@ namespace StreamJsonRpc; [AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] public class RpcMarshalableAttribute : Attribute { + /// + /// Gets a value indicating whether the marshaled object is only allowed in requests + /// and may only be invoked by the receiver until the response is sent. + /// + /// + /// Objects marshaled via an interface attributed with this property set to true may only be used as RPC method parameters. + /// They will not be allowed as return values from RPC methods. + /// While the receiver may dispose of the proxy they receive, this disposal will not propagate to the sender, + /// and their originating object will not be disposed of. + /// The original object owner retains ownership of the lifetime of the object after the RPC call. + /// + public bool CallScopedLifetime { get; init; } } diff --git a/src/StreamJsonRpc/Resources.resx b/src/StreamJsonRpc/Resources.resx index 4eaf998dc..f61ffd63e 100644 --- a/src/StreamJsonRpc/Resources.resx +++ b/src/StreamJsonRpc/Resources.resx @@ -123,6 +123,9 @@ Both readable and writable are null. + + A call-scoped marshaled object was included in a return value, which is not allowed. + A CancellationToken is only allowed as the last parameter. diff --git a/src/StreamJsonRpc/SystemTextJsonFormatter.cs b/src/StreamJsonRpc/SystemTextJsonFormatter.cs index 6dacc292d..5e41a5fc9 100644 --- a/src/StreamJsonRpc/SystemTextJsonFormatter.cs +++ b/src/StreamJsonRpc/SystemTextJsonFormatter.cs @@ -625,7 +625,11 @@ protected internal override void SetExpectedResultType(Type resultType) try { - this.Result = this.JsonResult.Value.Deserialize(resultType, this.formatter.massagedUserDataSerializerOptions); + using (this.formatter.TrackDeserialization(this)) + { + this.Result = this.JsonResult.Value.Deserialize(resultType, this.formatter.massagedUserDataSerializerOptions); + } + this.JsonResult = default; } catch (Exception ex) @@ -885,42 +889,32 @@ public RpcMarshalableConverterFactory(SystemTextJsonFormatter formatter) public override bool CanConvert(Type typeToConvert) { - return MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeToConvert, out _, out _); + return MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeToConvert, out _, out _, out _); } public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - Assumes.True(MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeToConvert, out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions)); + Assumes.True(MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeToConvert, out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions, out RpcMarshalableAttribute? attribute)); return (JsonConverter)Activator.CreateInstance( typeof(Converter<>).MakeGenericType(typeToConvert), this.formatter, proxyOptions, - targetOptions)!; + targetOptions, + attribute)!; } - private class Converter : JsonConverter + private class Converter(SystemTextJsonFormatter formatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions, RpcMarshalableAttribute rpcMarshalableAttribute) : JsonConverter where T : class { - private readonly SystemTextJsonFormatter formatter; - private readonly JsonRpcProxyOptions proxyOptions; - private readonly JsonRpcTargetOptions targetOptions; - - public Converter(SystemTextJsonFormatter formatter, JsonRpcProxyOptions proxyOptions, JsonRpcTargetOptions targetOptions) - { - this.formatter = formatter; - this.proxyOptions = proxyOptions; - this.targetOptions = targetOptions; - } - public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { MessageFormatterRpcMarshaledContextTracker.MarshalToken token = JsonSerializer.Deserialize(ref reader, options); - return (T)this.formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, this.proxyOptions); + return (T)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions); } public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) { - MessageFormatterRpcMarshaledContextTracker.MarshalToken token = this.formatter.RpcMarshaledContextTracker.GetToken(value, this.targetOptions, typeof(T)); + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); JsonSerializer.Serialize(writer, token, options); } } diff --git a/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt b/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt index fe3ee13c6..3e65b072f 100644 --- a/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt @@ -2,3 +2,5 @@ StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker.JoinableTaskTokenTracker() -> void StreamJsonRpc.JsonRpc.JoinableTaskTracker.get -> StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker! StreamJsonRpc.JsonRpc.JoinableTaskTracker.set -> void +StreamJsonRpc.RpcMarshalableAttribute.CallScopedLifetime.get -> bool +StreamJsonRpc.RpcMarshalableAttribute.CallScopedLifetime.init -> void diff --git a/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt b/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt index fe3ee13c6..3e65b072f 100644 --- a/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt +++ b/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt @@ -2,3 +2,5 @@ StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker.JoinableTaskTokenTracker() -> void StreamJsonRpc.JsonRpc.JoinableTaskTracker.get -> StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker! StreamJsonRpc.JsonRpc.JoinableTaskTracker.set -> void +StreamJsonRpc.RpcMarshalableAttribute.CallScopedLifetime.get -> bool +StreamJsonRpc.RpcMarshalableAttribute.CallScopedLifetime.init -> void diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index ffb574c21..e201d8a33 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -456,6 +456,7 @@ public async Task NotifyAsync_ThrowsIfAsyncEnumerableSent() public async Task ArgumentEnumerable_ReleasedOnErrorResponse() { WeakReference enumerable = await this.ArgumentEnumerable_ReleasedOnErrorResponse_Helper(); + await Task.Yield(); // get off the helper's inline continuation stack. AssertCollectedObject(enumerable); } @@ -465,6 +466,7 @@ public async Task ArgumentEnumerable_ReleasedOnErrorResponse() public async Task ArgumentEnumerable_ReleasedOnErrorInSubsequentArgumentSerialization() { WeakReference enumerable = await this.ArgumentEnumerable_ReleasedOnErrorInSubsequentArgumentSerialization_Helper(); + await Task.Yield(); // get off the helper's inline continuation stack. AssertCollectedObject(enumerable); } @@ -474,6 +476,7 @@ public async Task ArgumentEnumerable_ReleasedOnErrorInSubsequentArgumentSerializ public async Task ArgumentEnumerable_ReleasedWhenIgnoredBySuccessfulRpcCall() { WeakReference enumerable = await this.ArgumentEnumerable_ReleasedWhenIgnoredBySuccessfulRpcCall_Helper(); + await Task.Yield(); // get off the helper's inline continuation stack. AssertCollectedObject(enumerable); } @@ -483,6 +486,7 @@ public async Task ArgumentEnumerable_ReleasedWhenIgnoredBySuccessfulRpcCall() public async Task ArgumentEnumerable_ForciblyDisposedAndReleasedWhenNotDisposedWithinRpcCall() { WeakReference enumerable = await this.ArgumentEnumerable_ForciblyDisposedAndReleasedWhenNotDisposedWithinRpcCall_Helper(); + await Task.Yield(); // get off the helper's inline continuation stack. AssertCollectedObject(enumerable); // Assert that if the RPC server tries to enumerate more values after it returns that it gets the right exception. @@ -495,6 +499,7 @@ public async Task ArgumentEnumerable_ForciblyDisposedAndReleasedWhenNotDisposedW public async Task ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod() { WeakReference enumerable = await this.ReturnEnumerable_AutomaticallyReleasedOnErrorFromIteratorMethod_Helper(); + await Task.Yield(); // get off the helper's inline continuation stack. AssertCollectedObject(enumerable); } @@ -551,19 +556,6 @@ public async Task EnumerableIdDisposal() protected abstract void InitializeFormattersAndHandlers(); - private static void AssertCollectedObject(WeakReference weakReference) - { - GC.Collect(); - - // For some reason the assertion tends to be sketchy when running on Azure Pipelines. - if (IsTestRunOnAzurePipelines) - { - Skip.If(weakReference.IsAlive); - } - - Assert.False(weakReference.IsAlive); - } - [MethodImpl(MethodImplOptions.NoInlining)] private async Task ArgumentEnumerable_ReleasedOnErrorResponse_Helper() { diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs index ea0f1757e..a12fb17f3 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs @@ -93,6 +93,11 @@ public interface IMarshalable : INonMarshalable { } + [RpcMarshalable(CallScopedLifetime = true)] + public interface IMarshalableWithCallScopedLifetime : IMarshalable + { + } + [RpcMarshalable] public interface IGenericMarshalable : IMarshalable { @@ -264,6 +269,14 @@ public interface IServer Task AcceptNonMarshalableAsync(INonMarshalable nonMarshalable); Task AcceptNonMarshalableDerivedFromMarshalablesAsync(INonMarshalableDerivedFromMarshalable nonMarshalable); + + Task CallScopedMarshalableAsync(IMarshalableWithCallScopedLifetime marshalable); + + Task ReturnCallScopedObjectAsync(); + + IAsyncEnumerable CallScopedMarshalableReturnsAsyncEnumerable(IMarshalableWithCallScopedLifetime marshalable); + + Task CallScopedMarshalableThrowsWithAsyncEnumerable(IMarshalableWithCallScopedLifetime marshalable); } protected abstract Type FormatterExceptionType { get; } @@ -898,6 +911,90 @@ public async Task RpcMarshalableOptionalInterface_MultipleImplementationsCombine Assert.Equal(6, await ((IMarshalableSubTypesCombined)proxy).GetPlusFiveAsync(1)); } + [Fact] + public async Task RpcMarshalable_CallScopedLifetime() + { + MarshalableAndSerializable marshaled = new(); + await this.client.CallScopedMarshalableAsync(marshaled); + Assert.True(marshaled.DoSomethingCalled); + Assert.False(marshaled.IsDisposed); + this.clientRpc.Dispose(); + Assert.False(marshaled.IsDisposed); + } + + [Fact] + public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableReturned() + { + MarshalableAndSerializable marshaled = new(); + await foreach (int number in this.client.CallScopedMarshalableReturnsAsyncEnumerable(marshaled)) + { + Assert.Equal(42, number); + Assert.True(marshaled.DoSomethingCalled); + } + + // Verify that after enumeration conclusion, the proxy is zombied. + this.server.AllowContinuation.Set(); + Assert.NotNull(this.server.ContinuationResult); + await Assert.ThrowsAsync(() => this.server.ContinuationResult).WithCancellation(this.TimeoutToken); + } + + [Fact] + public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableThrown() + { + this.clientRpc.AllowModificationWhileListening = true; + this.serverRpc.AllowModificationWhileListening = true; + this.clientRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; + this.serverRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; + + MarshalableAndSerializable marshaled = new(); + var outerException = await Assert.ThrowsAsync(() => this.client.CallScopedMarshalableThrowsWithAsyncEnumerable(marshaled)); + var ex = Assert.IsType(outerException.InnerException); + + // Verify that the proxy is zombied immediately because the request failed. + // Successfully enumerating is sufficient for this because the enumerator method has the assertion of zombie inside it. + Assert.NotNull(ex.Enumerable); + await foreach (int number in ex.Enumerable) + { + Assert.Equal(42, number); + } + } + + [Fact] + public async Task RpcMarshalable_CallScopedLifetime_InvokedAfterReturn() + { + MarshalableAndSerializable marshaled = new(); + await this.client.CallScopedMarshalableAsync(marshaled); + this.server.AllowContinuation.Set(); + Assert.NotNull(this.server.ContinuationResult); + await Assert.ThrowsAsync(() => this.server.ContinuationResult).WithCancellation(this.TimeoutToken); + } + + [SkippableFact] + [Trait("GC", "")] + public async Task RpcMarshalable_CallScopedLifetime_ObjectCollected() + { + WeakReference weakRef = await HelperAsync(this.client); + await Task.Yield(); // get off the helper's inline continuation stack. + AssertCollectedObject(weakRef); + + [MethodImpl(MethodImplOptions.NoInlining)] + static async Task HelperAsync(IServer client) + { + MarshalableAndSerializable? marshaled = new(); + await client.CallScopedMarshalableAsync(marshaled); + WeakReference result = new(marshaled); + marshaled = null; + return result; + } + } + + [Fact] + public async Task RpcMarshalable_CallScopedLifetime_ObjectReturned() + { + var ex = await Assert.ThrowsAsync(this.client.ReturnCallScopedObjectAsync); + this.Logger.WriteLine(ex.ToString()); + } + protected abstract IJsonRpcMessageFormatter CreateFormatter(); private static void AssertIsNot(object obj, Type type) @@ -907,6 +1004,8 @@ private static void AssertIsNot(object obj, Type type) public class Server : IServer { + internal AsyncAutoResetEvent AllowContinuation { get; } = new(); + internal AsyncManualResetEvent ReturnedMarshalableDisposed { get; } = new AsyncManualResetEvent(); internal IMarshalable? ReturnedMarshalable { get; set; } @@ -915,6 +1014,8 @@ public class Server : IServer internal IMarshalableWithOptionalInterfaces? ReturnedMarshalableWithOptionalInterfaces { get; set; } + internal Task? ContinuationResult { get; private set; } + public Task GetMarshalableAsync(bool returnNull) { // The OneObjectMarshalledTwiceHasIndependentLifetimes test depends on us returning the same instance each time. @@ -1040,6 +1141,57 @@ public Task AcceptGenericProxyContainerAsync(ProxyContainer Task.CompletedTask; public Task AcceptNonMarshalableDerivedFromMarshalablesAsync(INonMarshalableDerivedFromMarshalable nonMarshalable) => Task.CompletedTask; + + public async Task CallScopedMarshalableAsync(IMarshalableWithCallScopedLifetime marshalable) + { + await marshalable.DoSomethingAsync(); + + this.ContinuationResult = Task.Run(async delegate + { + await this.AllowContinuation.WaitAsync(); + await marshalable.DoSomethingAsync(); + }); + } + + public Task ReturnCallScopedObjectAsync() + { + // Returning a call-scoped object as a return type is illegal. + // This method is used in a test that verifies the failure mode for this case. + return Task.FromResult(new MarshalableAndSerializable()); + } + + public async IAsyncEnumerable CallScopedMarshalableReturnsAsyncEnumerable(IMarshalableWithCallScopedLifetime marshalable) + { + // Yield before using the marshalable since we want to test that the call-scoped argument + // is available for the whole async enumerable state machine. + await Task.Yield(); + + await marshalable.DoSomethingAsync(); + yield return 42; + + // This allows a test to optionally verify that the call-scoped object + // quits working after the enumeration has completed. + this.ContinuationResult = Task.Run(async delegate + { + await this.AllowContinuation.WaitAsync(); + await marshalable.DoSomethingAsync(); + }); + } + + public Task CallScopedMarshalableThrowsWithAsyncEnumerable(IMarshalableWithCallScopedLifetime marshalable) + { + throw new ExceptionWithAsyncEnumerable(Helper()); + + async IAsyncEnumerable Helper() + { + await Task.Yield(); + + // By the time this runs, the original request has failed and using the call-scoped argument is expected to fail too. + await Assert.ThrowsAsync(() => marshalable.DoSomethingAsync()); + + yield return 42; + } + } } [DataContract] @@ -1112,12 +1264,15 @@ public Task DoSomethingAsync() } } - public class MarshalableAndSerializable : IMarshalableAndSerializable + public class MarshalableAndSerializable : IMarshalableAndSerializable, IMarshalableWithCallScopedLifetime { public bool DoSomethingCalled { get; private set; } + public bool IsDisposed { get; private set; } + public void Dispose() { + this.IsDisposed = true; } public Task DoSomethingAsync() @@ -1336,4 +1491,27 @@ public void Dispose() { } } + + [Serializable] + public class ExceptionWithAsyncEnumerable : Exception + { + public ExceptionWithAsyncEnumerable(IAsyncEnumerable enumeration) + { + this.Enumerable = enumeration; + } + + protected ExceptionWithAsyncEnumerable(SerializationInfo info, StreamingContext context) + : base(info, context) + { + this.Enumerable = (IAsyncEnumerable?)info.GetValue(nameof(this.Enumerable), typeof(IAsyncEnumerable)); + } + + internal IAsyncEnumerable? Enumerable { get; set; } + + public override void GetObjectData(SerializationInfo info, StreamingContext context) + { + base.GetObjectData(info, context); + info.AddValue(nameof(this.Enumerable), this.Enumerable); + } + } } diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index 74042bc01..bee5edab3 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -7,6 +7,7 @@ + diff --git a/test/StreamJsonRpc.Tests/TestBase.cs b/test/StreamJsonRpc.Tests/TestBase.cs index f1cd06741..e805214f8 100644 --- a/test/StreamJsonRpc.Tests/TestBase.cs +++ b/test/StreamJsonRpc.Tests/TestBase.cs @@ -46,6 +46,19 @@ public void Dispose() GC.SuppressFinalize(this); } + protected static void AssertCollectedObject(WeakReference weakReference) + { + GC.Collect(); + + // For some reason the assertion tends to be sketchy when running on Azure Pipelines. + if (IsTestRunOnAzurePipelines) + { + Skip.If(weakReference.IsAlive); + } + + Assert.False(weakReference.IsAlive); + } + /// /// Checks whether a given exception or any transitive inner exception has a given type. ///