1111using Microsoft . VisualStudio . Threading ;
1212using Nerdbank . Streams ;
1313using Newtonsoft . Json ;
14- using StreamJsonRpc ;
15- using Xunit ;
16- using Xunit . Abstractions ;
1714
1815public abstract class AsyncEnumerableTests : TestBase , IAsyncLifetime
1916{
20- protected readonly Server server = new Server ( ) ;
17+ protected readonly Server server = new ( ) ;
18+ protected readonly Client client = new ( ) ;
19+
2120 protected JsonRpc serverRpc ;
2221 protected IJsonRpcMessageFormatter serverMessageFormatter ;
2322
2423 protected Lazy < IServer > clientProxy ;
24+ protected Lazy < IClient > serverProxy ;
2525 protected JsonRpc clientRpc ;
2626 protected IJsonRpcMessageFormatter clientMessageFormatter ;
2727
@@ -73,6 +73,13 @@ protected interface IServer
7373 Task PassInNumbersAndIgnoreAsync ( IAsyncEnumerable < int > numbers , CancellationToken cancellationToken ) ;
7474
7575 Task PassInNumbersOnlyStartEnumerationAsync ( IAsyncEnumerable < int > numbers , CancellationToken cancellationToken ) ;
76+
77+ IAsyncEnumerable < string > CallbackClientAndYieldOneValueAsync ( CancellationToken cancellationToken ) ;
78+ }
79+
80+ protected interface IClient
81+ {
82+ Task DoSomethingAsync ( CancellationToken cancellationToken ) ;
7683 }
7784
7885 public Task InitializeAsync ( )
@@ -85,7 +92,7 @@ public Task InitializeAsync()
8592 var clientHandler = new LengthHeaderMessageHandler ( streams . Item2 . UsePipe ( ) , this . clientMessageFormatter ) ;
8693
8794 this . serverRpc = new JsonRpc ( serverHandler , this . server ) ;
88- this . clientRpc = new JsonRpc ( clientHandler ) ;
95+ this . clientRpc = new JsonRpc ( clientHandler , this . client ) ;
8996
9097 this . serverRpc . TraceSource = new TraceSource ( "Server" , SourceLevels . Verbose ) ;
9198 this . clientRpc . TraceSource = new TraceSource ( "Client" , SourceLevels . Verbose ) ;
@@ -97,6 +104,7 @@ public Task InitializeAsync()
97104 this . clientRpc . StartListening ( ) ;
98105
99106 this . clientProxy = new Lazy < IServer > ( ( ) => this . clientRpc . Attach < IServer > ( ) ) ;
107+ this . serverProxy = new Lazy < IClient > ( ( ) => this . serverRpc . Attach < IClient > ( ) ) ;
100108
101109 return Task . CompletedTask ;
102110 }
@@ -530,6 +538,17 @@ public async Task AsyncIteratorThrows(int minBatchSize, int maxReadAhead, int pr
530538 Assert . Equal ( Server . FailByDesignExceptionMessage , ex . Message ) ;
531539 }
532540
541+ [ Fact ]
542+ public async Task EnumerableIdDisposal ( )
543+ {
544+ // This test is specially arranged to create two RPC calls going opposite directions, with the same request ID.
545+ // By doing so, we can verify that the server doesn't dispose the enumerable until the full sequence is sent to the client.
546+ this . server . Client = this . serverProxy . Value ;
547+ await foreach ( string s in this . clientProxy . Value . CallbackClientAndYieldOneValueAsync ( this . TimeoutToken ) )
548+ {
549+ }
550+ }
551+
533552 protected abstract void InitializeFormattersAndHandlers ( ) ;
534553
535554 private static void AssertCollectedObject ( WeakReference weakReference )
@@ -621,6 +640,8 @@ protected class Server : IServer
621640
622641 internal const string FailByDesignExceptionMessage = "Fail by design" ;
623642
643+ public IClient ? Client { get ; set ; }
644+
624645 public AsyncManualResetEvent MethodEntered { get ; } = new AsyncManualResetEvent ( ) ;
625646
626647 public AsyncManualResetEvent MethodExited { get ; } = new AsyncManualResetEvent ( ) ;
@@ -745,6 +766,18 @@ public Task<CompoundEnumerableResult> GetNumbersAndMetadataAsync(CancellationTok
745766 } ) ;
746767 }
747768
769+ public async IAsyncEnumerable < string > CallbackClientAndYieldOneValueAsync ( [ EnumeratorCancellation ] CancellationToken cancellationToken )
770+ {
771+ if ( this . Client is null )
772+ {
773+ throw new InvalidOperationException ( "Client must be set before calling this method." ) ;
774+ }
775+
776+ // We deliberately make a callback right away such that the request ID for it collides with the request ID that served THIS request.
777+ await this . Client . DoSomethingAsync ( cancellationToken ) ;
778+ yield return "Hello" ;
779+ }
780+
748781 private async IAsyncEnumerable < int > GetNumbersAsync ( int totalCount , bool endWithException , [ EnumeratorCancellation ] CancellationToken cancellationToken )
749782 {
750783 for ( int i = 1 ; i <= totalCount ; i ++ )
@@ -763,6 +796,11 @@ private async IAsyncEnumerable<int> GetNumbersAsync(int totalCount, bool endWith
763796 }
764797 }
765798
799+ protected class Client : IClient
800+ {
801+ public Task DoSomethingAsync ( CancellationToken cancellationToken ) => Task . CompletedTask ;
802+ }
803+
766804 [ DataContract ]
767805 protected class CompoundEnumerableResult
768806 {
0 commit comments