diff --git a/src/test/lib/HandshakeTest.cpp b/src/test/lib/HandshakeTest.cpp index cdb55b65b4..5ae60c1967 100644 --- a/src/test/lib/HandshakeTest.cpp +++ b/src/test/lib/HandshakeTest.cpp @@ -4100,8 +4100,51 @@ QuicTestHandshakeSpecificLossPatterns( } } +QUIC_STATUS +ConnectionPoolServerConnectionCallback( + _In_ MsQuicConnection* /* Connection */, + _In_opt_ void* /* Context */, + _Inout_ QUIC_CONNECTION_EVENT* Event + ) +{ + if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) { + auto Stream = + new(std::nothrow) MsQuicStream( + Event->PEER_STREAM_STARTED.Stream, + CleanUpAutoDelete, + [](MsQuicStream* Stream, void*, QUIC_STREAM_EVENT* Event){ + if (Event->Type == QUIC_STREAM_EVENT_RECEIVE) { + auto SendBuffer = new (std::nothrow) QUIC_BUFFER[Event->RECEIVE.BufferCount + 1]; + // The first QUIC_BUFFER is a dummy to store the total buffer length + SendBuffer[0].Buffer = nullptr; + SendBuffer[0].Length = (uint32_t)Event->RECEIVE.TotalBufferLength; + for (uint32_t i = 0; i < Event->RECEIVE.BufferCount; i++) { + SendBuffer[i + 1] = Event->RECEIVE.Buffers[i]; + } + auto Status= Stream->Send(&SendBuffer[1], Event->RECEIVE.BufferCount, QUIC_SEND_FLAG_FIN, SendBuffer); + if (QUIC_FAILED(Status)) { + TEST_FAILURE("Return Send Failed with 0x%x", Status); + Stream->Shutdown(Status); + return QUIC_STATUS_SUCCESS; + } + Event->RECEIVE.TotalBufferLength = 0; + return QUIC_STATUS_PENDING; + } else if (Event->Type == QUIC_STREAM_EVENT_SEND_COMPLETE) { + auto SendBuffer = (QUIC_BUFFER*) Event->SEND_COMPLETE.ClientContext; + uint64_t TotalBufferLength = SendBuffer[0].Length; + Stream->ReceiveComplete(TotalBufferLength); + delete[] SendBuffer; + } + return QUIC_STATUS_SUCCESS; + }); + UNREFERENCED_PARAMETER(Stream); // This stream will clean itself up so it's not leaked here. + } + return QUIC_STATUS_SUCCESS; +} + struct ConnectionPoolConnectionContext { CxPlatEvent ConnectedEvent{}; + CxPlatEvent DataReceivedEvent{}; uint16_t IdealProcessor{}; uint16_t PartitionIndex{}; bool Connected{false}; @@ -4170,6 +4213,10 @@ QuicTestConnectionPoolCreate( QuicAddrSetToDuoNic(&ServerAddr.SockAddr); } + const uint32_t SendBytes = 200; + UniquePtrArray SendData(new(std::nothrow) uint8_t[SendBytes]); + QUIC_BUFFER SendBuffer{SendBytes, SendData.get()}; + // // Make sure to create the connection contexts before the connections, // to ensure they are not freed before the connection is closed. @@ -4183,6 +4230,9 @@ QuicTestConnectionPoolCreate( UniquePtrArray Connections(new(std::nothrow) ConnectionScope[NumberOfConnections]); TEST_NOT_EQUAL(nullptr, Connections); + UniquePtrArray Streams(new(std::nothrow) StreamScope[NumberOfConnections]); + TEST_NOT_EQUAL(nullptr, Streams); + for (uint32_t i = 0; i < NumberOfConnections; ++i) { ContextPtrs[i] = &Contexts[i]; Contexts[i].Connected = false; @@ -4190,7 +4240,7 @@ QuicTestConnectionPoolCreate( Contexts[i].PartitionIndex = 0; } - MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, MsQuicConnection::NoOpCallback); + MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, ConnectionPoolServerConnectionCallback); TEST_QUIC_SUCCEEDED(Listener.GetInitStatus()); if (TestCibirSupport) { @@ -4239,6 +4289,34 @@ QuicTestConnectionPoolCreate( QUIC_STATUS Status = MsQuic->ConnectionPoolCreate(&PoolConfig, &(Connections.get()->Handle)); if (XdpSupported) { TEST_QUIC_SUCCEEDED(Status); + for (uint32_t i = 0; i < NumberOfConnections; i++) { + // Send data on each connection while the handshake is progressing + TEST_QUIC_SUCCEEDED( + MsQuic->StreamOpen( + Connections[i], + QUIC_STREAM_OPEN_FLAG_NONE, + [](HQUIC, void* Context, QUIC_STREAM_EVENT* Event){ + if (Event->Type == QUIC_STREAM_EVENT_START_COMPLETE) { + if (QUIC_FAILED(Event->START_COMPLETE.Status)) { + TEST_FAILURE("Stream start failed 0x%x", Event->START_COMPLETE.Status); + } + } else if (Event->Type == QUIC_STREAM_EVENT_RECEIVE) { + auto* Ctxt = (ConnectionPoolConnectionContext*)Context; + Ctxt->DataReceivedEvent.Set(); + } + return QUIC_STATUS_SUCCESS; + }, + &Contexts[i], + &Streams[i].Handle)); + + TEST_QUIC_SUCCEEDED( + MsQuic->StreamSend( + Streams[i], + &SendBuffer, + 1, + QUIC_SEND_FLAG_START | QUIC_SEND_FLAG_FIN, + nullptr)); + } for (uint32_t i = 0; i < NumberOfConnections; i++) { // // Verify the client connection is connected. @@ -4250,6 +4328,7 @@ QuicTestConnectionPoolCreate( if (Contexts[i].SpuriousNotification) { TEST_FAILURE("Context %u received notification for a failed connection", i); } + Contexts[i].DataReceivedEvent.WaitTimeout(TestWaitTimeout); } TEST_EQUAL(NumberOfConnections, Listener.AcceptedConnectionCount); } else {