diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs
index defcf4164..bb38835d9 100644
--- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs
+++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs
@@ -39,8 +39,15 @@ public class MessageFormatterEnumerableTracker
private static readonly MethodInfo OnDisposeAsyncMethodInfo = typeof(MessageFormatterEnumerableTracker).GetMethod(nameof(OnDisposeAsync), BindingFlags.NonPublic | BindingFlags.Instance)!;
///
- /// Dictionary used to map the outbound request id to their progress info so that the progress objects are cleaned after getting the final response.
+ /// Dictionary used to map the outbound request id to the list of tokens that track state machines it owns
+ /// so that the state machines are cleaned after getting the final response.
///
+ ///
+ /// Note that we only track OUTBOUND REQUESTS that carry enumerables here.
+ /// OUTBOUND RESPONSES that carry enumerables are not tracked except in .
+ /// This means that responses that carry enumerables will not be cleaned up if the response is never processed by the client
+ /// until the connection dies.
+ ///
private readonly Dictionary> generatorTokensByRequestId = new Dictionary>();
private readonly Dictionary generatorsByToken = new Dictionary();
@@ -116,12 +123,20 @@ public long GetToken(IAsyncEnumerable enumerable)
long handle = Interlocked.Increment(ref this.nextToken);
lock (this.syncObject)
{
- if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList? tokens))
+ // We only track the token if we are serializing a request, since per our documentation,
+ // we forcibly terminate the enumerable at the client side when the request has been responded to.
+ // Storing request IDs for outbound *responses* that carry enumerables would lead to them being disposed of
+ // when an INBOUND response with the same ID is received.
+ if (this.formatterState.SerializingRequest)
{
- tokens = ImmutableList.Empty;
+ if (!this.generatorTokensByRequestId.TryGetValue(this.formatterState.SerializingMessageWithId, out ImmutableList? tokens))
+ {
+ tokens = ImmutableList.Empty;
+ }
+
+ this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle);
}
- this.generatorTokensByRequestId[this.formatterState.SerializingMessageWithId] = tokens.Add(handle);
this.generatorsByToken.Add(handle, new GeneratingEnumeratorTracker(this, handle, enumerable, settings: enumerable.GetJsonRpcSettings()));
}
@@ -173,18 +188,18 @@ private ValueTask OnDisposeAsync(long token)
return generator.DisposeAsync();
}
- private void CleanUpResources(RequestId requestId)
+ private void CleanUpResources(RequestId outboundRequestId)
{
lock (this.syncObject)
{
- if (this.generatorTokensByRequestId.TryGetValue(requestId, out ImmutableList? tokens))
+ if (this.generatorTokensByRequestId.TryGetValue(outboundRequestId, out ImmutableList? tokens))
{
foreach (var token in tokens)
{
this.generatorsByToken.Remove(token);
}
- this.generatorTokensByRequestId.Remove(requestId);
+ this.generatorTokensByRequestId.Remove(outboundRequestId);
}
}
}
diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs
index 6cbe8a40c..7d884573e 100644
--- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs
+++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs
@@ -11,17 +11,17 @@
using Microsoft.VisualStudio.Threading;
using Nerdbank.Streams;
using Newtonsoft.Json;
-using StreamJsonRpc;
-using Xunit;
-using Xunit.Abstractions;
public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime
{
- protected readonly Server server = new Server();
+ protected readonly Server server = new();
+ protected readonly Client client = new();
+
protected JsonRpc serverRpc;
protected IJsonRpcMessageFormatter serverMessageFormatter;
protected Lazy clientProxy;
+ protected Lazy serverProxy;
protected JsonRpc clientRpc;
protected IJsonRpcMessageFormatter clientMessageFormatter;
@@ -73,6 +73,13 @@ protected interface IServer
Task PassInNumbersAndIgnoreAsync(IAsyncEnumerable numbers, CancellationToken cancellationToken);
Task PassInNumbersOnlyStartEnumerationAsync(IAsyncEnumerable numbers, CancellationToken cancellationToken);
+
+ IAsyncEnumerable CallbackClientAndYieldOneValueAsync(CancellationToken cancellationToken);
+ }
+
+ protected interface IClient
+ {
+ Task DoSomethingAsync(CancellationToken cancellationToken);
}
public Task InitializeAsync()
@@ -85,7 +92,7 @@ public Task InitializeAsync()
var clientHandler = new LengthHeaderMessageHandler(streams.Item2.UsePipe(), this.clientMessageFormatter);
this.serverRpc = new JsonRpc(serverHandler, this.server);
- this.clientRpc = new JsonRpc(clientHandler);
+ this.clientRpc = new JsonRpc(clientHandler, this.client);
this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose);
this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose);
@@ -97,6 +104,7 @@ public Task InitializeAsync()
this.clientRpc.StartListening();
this.clientProxy = new Lazy(() => this.clientRpc.Attach());
+ this.serverProxy = new Lazy(() => this.serverRpc.Attach());
return Task.CompletedTask;
}
@@ -530,6 +538,17 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr
Assert.Equal(Server.FailByDesignExceptionMessage, ex.Message);
}
+ [Fact]
+ public async Task EnumerableIdDisposal()
+ {
+ // This test is specially arranged to create two RPC calls going opposite directions, with the same request ID.
+ // By doing so, we can verify that the server doesn't dispose the enumerable until the full sequence is sent to the client.
+ this.server.Client = this.serverProxy.Value;
+ await foreach (string s in this.clientProxy.Value.CallbackClientAndYieldOneValueAsync(this.TimeoutToken))
+ {
+ }
+ }
+
protected abstract void InitializeFormattersAndHandlers();
private static void AssertCollectedObject(WeakReference weakReference)
@@ -621,6 +640,8 @@ protected class Server : IServer
internal const string FailByDesignExceptionMessage = "Fail by design";
+ public IClient? Client { get; set; }
+
public AsyncManualResetEvent MethodEntered { get; } = new AsyncManualResetEvent();
public AsyncManualResetEvent MethodExited { get; } = new AsyncManualResetEvent();
@@ -745,6 +766,18 @@ public Task GetNumbersAndMetadataAsync(CancellationTok
});
}
+ public async IAsyncEnumerable CallbackClientAndYieldOneValueAsync([EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ if (this.Client is null)
+ {
+ throw new InvalidOperationException("Client must be set before calling this method.");
+ }
+
+ // We deliberately make a callback right away such that the request ID for it collides with the request ID that served THIS request.
+ await this.Client.DoSomethingAsync(cancellationToken);
+ yield return "Hello";
+ }
+
private async IAsyncEnumerable GetNumbersAsync(int totalCount, bool endWithException, [EnumeratorCancellation] CancellationToken cancellationToken)
{
for (int i = 1; i <= totalCount; i++)
@@ -763,6 +796,11 @@ private async IAsyncEnumerable GetNumbersAsync(int totalCount, bool endWith
}
}
+ protected class Client : IClient
+ {
+ public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+ }
+
[DataContract]
protected class CompoundEnumerableResult
{