Skip to content

Commit 85ac63f

Browse files
authored
Merge pull request #1004 from AArnott/fix999
Fixes early destruction of `IAsyncEnumerable<T>` sent as return value from RPC methods
2 parents 330c007 + 61901d8 commit 85ac63f

File tree

2 files changed

+65
-12
lines changed

2 files changed

+65
-12
lines changed

src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,15 @@ public class MessageFormatterEnumerableTracker
4040
private static readonly MethodInfo OnDisposeAsyncMethodInfo = typeof(MessageFormatterEnumerableTracker).GetMethod(nameof(OnDisposeAsync), BindingFlags.NonPublic | BindingFlags.Instance)!;
4141

4242
/// <summary>
43-
/// Dictionary used to map the outbound request id to their progress info so that the progress objects are cleaned after getting the final response.
43+
/// Dictionary used to map the outbound request id to the list of tokens that track <see cref="IAsyncEnumerable{T}"/> state machines it owns
44+
/// so that the state machines are cleaned after getting the final response.
4445
/// </summary>
46+
/// <remarks>
47+
/// Note that we only track OUTBOUND REQUESTS that carry enumerables here.
48+
/// OUTBOUND RESPONSES that carry enumerables are not tracked except in <see cref="generatorsByToken"/>.
49+
/// This means that responses that carry enumerables will not be cleaned up if the response is never processed by the client
50+
/// until the connection dies.
51+
/// </remarks>
4552
private readonly Dictionary<RequestId, ImmutableList<long>> generatorTokensByRequestId = new Dictionary<RequestId, ImmutableList<long>>();
4653

4754
private readonly Dictionary<long, IGeneratingEnumeratorTracker> generatorsByToken = new Dictionary<long, IGeneratingEnumeratorTracker>();
@@ -117,12 +124,20 @@ public long GetToken<T>(IAsyncEnumerable<T> enumerable)
117124
long handle = Interlocked.Increment(ref this.nextToken);
118125
lock (this.syncObject)
119126
{
120-
if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList<long>? tokens))
127+
// We only track the token if we are serializing a request, since per our documentation,
128+
// we forcibly terminate the enumerable at the client side when the request has been responded to.
129+
// Storing request IDs for outbound *responses* that carry enumerables would lead to them being disposed of
130+
// when an INBOUND response with the same ID is received.
131+
if (this.formatterState.SerializingRequest)
121132
{
122-
tokens = ImmutableList<long>.Empty;
133+
if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList<long>? tokens))
134+
{
135+
tokens = ImmutableList<long>.Empty;
136+
}
137+
138+
this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle);
123139
}
124140

125-
this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle);
126141
this.generatorsByToken.Add(handle, new GeneratingEnumeratorTracker<T>(this, handle, enumerable, settings: enumerable.GetJsonRpcSettings()));
127142
}
128143

@@ -174,18 +189,18 @@ private ValueTask OnDisposeAsync(long token)
174189
return generator.DisposeAsync();
175190
}
176191

177-
private void CleanUpResources(RequestId requestId)
192+
private void CleanUpResources(RequestId outboundRequestId)
178193
{
179194
lock (this.syncObject)
180195
{
181-
if (this.generatorTokensByRequestId.TryGetValue(requestId, out ImmutableList<long>? tokens))
196+
if (this.generatorTokensByRequestId.TryGetValue(outboundRequestId, out ImmutableList<long>? tokens))
182197
{
183198
foreach (var token in tokens)
184199
{
185200
this.generatorsByToken.Remove(token);
186201
}
187202

188-
this.generatorTokensByRequestId.Remove(requestId);
203+
this.generatorTokensByRequestId.Remove(outboundRequestId);
189204
}
190205
}
191206
}

test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
using Microsoft.VisualStudio.Threading;
1212
using Nerdbank.Streams;
1313
using Newtonsoft.Json;
14-
using StreamJsonRpc;
15-
using Xunit;
16-
using Xunit.Abstractions;
1714

1815
public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime
1916
{
20-
protected readonly Server server = new Server();
17+
protected readonly Server server = new();
18+
protected readonly Client client = new();
19+
2120
protected JsonRpc serverRpc;
2221
protected IJsonRpcMessageFormatter serverMessageFormatter;
2322

2423
protected Lazy<IServer> clientProxy;
24+
protected Lazy<IClient> serverProxy;
2525
protected JsonRpc clientRpc;
2626
protected IJsonRpcMessageFormatter clientMessageFormatter;
2727

@@ -73,6 +73,13 @@ protected interface IServer
7373
Task PassInNumbersAndIgnoreAsync(IAsyncEnumerable<int> numbers, CancellationToken cancellationToken);
7474

7575
Task PassInNumbersOnlyStartEnumerationAsync(IAsyncEnumerable<int> numbers, CancellationToken cancellationToken);
76+
77+
IAsyncEnumerable<string> CallbackClientAndYieldOneValueAsync(CancellationToken cancellationToken);
78+
}
79+
80+
protected interface IClient
81+
{
82+
Task DoSomethingAsync(CancellationToken cancellationToken);
7683
}
7784

7885
public Task InitializeAsync()
@@ -85,7 +92,7 @@ public Task InitializeAsync()
8592
var clientHandler = new LengthHeaderMessageHandler(streams.Item2.UsePipe(), this.clientMessageFormatter);
8693

8794
this.serverRpc = new JsonRpc(serverHandler, this.server);
88-
this.clientRpc = new JsonRpc(clientHandler);
95+
this.clientRpc = new JsonRpc(clientHandler, this.client);
8996

9097
this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose);
9198
this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose);
@@ -97,6 +104,7 @@ public Task InitializeAsync()
97104
this.clientRpc.StartListening();
98105

99106
this.clientProxy = new Lazy<IServer>(() => this.clientRpc.Attach<IServer>());
107+
this.serverProxy = new Lazy<IClient>(() => this.serverRpc.Attach<IClient>());
100108

101109
return Task.CompletedTask;
102110
}
@@ -530,6 +538,17 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr
530538
Assert.Equal(Server.FailByDesignExceptionMessage, ex.Message);
531539
}
532540

541+
[Fact]
542+
public async Task EnumerableIdDisposal()
543+
{
544+
// This test is specially arranged to create two RPC calls going opposite directions, with the same request ID.
545+
// By doing so, we can verify that the server doesn't dispose the enumerable until the full sequence is sent to the client.
546+
this.server.Client = this.serverProxy.Value;
547+
await foreach (string s in this.clientProxy.Value.CallbackClientAndYieldOneValueAsync(this.TimeoutToken))
548+
{
549+
}
550+
}
551+
533552
protected abstract void InitializeFormattersAndHandlers();
534553

535554
private static void AssertCollectedObject(WeakReference weakReference)
@@ -621,6 +640,8 @@ protected class Server : IServer
621640

622641
internal const string FailByDesignExceptionMessage = "Fail by design";
623642

643+
public IClient? Client { get; set; }
644+
624645
public AsyncManualResetEvent MethodEntered { get; } = new AsyncManualResetEvent();
625646

626647
public AsyncManualResetEvent MethodExited { get; } = new AsyncManualResetEvent();
@@ -745,6 +766,18 @@ public Task<CompoundEnumerableResult> GetNumbersAndMetadataAsync(CancellationTok
745766
});
746767
}
747768

769+
public async IAsyncEnumerable<string> CallbackClientAndYieldOneValueAsync([EnumeratorCancellation] CancellationToken cancellationToken)
770+
{
771+
if (this.Client is null)
772+
{
773+
throw new InvalidOperationException("Client must be set before calling this method.");
774+
}
775+
776+
// We deliberately make a callback right away such that the request ID for it collides with the request ID that served THIS request.
777+
await this.Client.DoSomethingAsync(cancellationToken);
778+
yield return "Hello";
779+
}
780+
748781
private async IAsyncEnumerable<int> GetNumbersAsync(int totalCount, bool endWithException, [EnumeratorCancellation] CancellationToken cancellationToken)
749782
{
750783
for (int i = 1; i <= totalCount; i++)
@@ -763,6 +796,11 @@ private async IAsyncEnumerable<int> GetNumbersAsync(int totalCount, bool endWith
763796
}
764797
}
765798

799+
protected class Client : IClient
800+
{
801+
public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask;
802+
}
803+
766804
[DataContract]
767805
protected class CompoundEnumerableResult
768806
{

0 commit comments

Comments
 (0)