Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Runtime.InteropServices;
using System.Threading;
Expand Down Expand Up @@ -45,7 +44,7 @@ internal sealed class MsQuicStream : QuicStreamProvider
private GCHandle _sendHandle;

// Used to check if StartAsync has been called.
private StartState _started;
private bool _started;

private ReadState _readState;
private long _readErrorCode = -1;
Expand Down Expand Up @@ -75,24 +74,24 @@ internal MsQuicStream(MsQuicConnection connection, QUIC_STREAM_OPEN_FLAG flags,

_ptr = nativeObjPtr;


_sendResettableCompletionSource = new ResettableCompletionSource<uint>();
_receiveResettableCompletionSource = new ResettableCompletionSource<uint>();
_shutdownWriteResettableCompletionSource = new ResettableCompletionSource<uint>();
SetCallbackHandler();

if (inbound)
{
_started = StartState.Finished;
_started = true;
_canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAG.UNIDIRECTIONAL);
_canRead = true;
}
else
{
_started = StartState.None;
_canWrite = true;
_canRead = !flags.HasFlag(QUIC_STREAM_OPEN_FLAG.UNIDIRECTIONAL);
StartWrites();
}

_sendResettableCompletionSource = new ResettableCompletionSource<uint>();
_receiveResettableCompletionSource = new ResettableCompletionSource<uint>();
_shutdownWriteResettableCompletionSource = new ResettableCompletionSource<uint>();

SetCallbackHandler();
}

internal override bool CanRead => _canRead;
Expand Down Expand Up @@ -186,6 +185,7 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
throw new OperationCanceledException("Sending has already been aborted on the stream");
}
}

CancellationTokenRegistration registration = cancellationToken.Register(() =>
{
bool shouldComplete = false;
Expand All @@ -204,13 +204,11 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
}
});

// Implicit start on first write.
if (_started == StartState.None)
// Make sure start has completed
if (!_started)
{
_started = StartState.Started;

// TODO can optimize this by not having this method be async.
await StartWritesAsync();
await _sendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false);
_started = true;
}

return registration;
Expand Down Expand Up @@ -636,8 +634,6 @@ private uint HandleStartComplete()
bool shouldComplete = false;
lock (_sync)
{
_started = StartState.Finished;

// Check send state before completing as send cancellation is shared between start and send.
if (_sendState == SendState.None)
{
Expand All @@ -647,7 +643,7 @@ private uint HandleStartComplete()

if (shouldComplete)
{
_sendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);
_sendResettableCompletionSource.Complete(0);
}

if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
Expand Down Expand Up @@ -988,14 +984,14 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync(
return _sendResettableCompletionSource.GetTypelessValueTask();
}

private ValueTask<uint> StartWritesAsync()
private void StartWrites()
{
Debug.Assert(!_started);
uint status = MsQuicApi.Api.StreamStartDelegate(
_ptr,
(uint)QUIC_STREAM_START_FLAG.ASYNC);

QuicExceptionHelpers.ThrowIfFailed(status, "Could not start stream.");
return _sendResettableCompletionSource.GetValueTask();
}

private void ReceiveComplete(int bufferLength)
Expand All @@ -1018,13 +1014,6 @@ private void ThrowIfDisposed()
}
}

private enum StartState
{
None,
Started,
Finished
}

private enum ReadState
{
None,
Expand Down
14 changes: 14 additions & 0 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,20 @@ public async Task CallDifferentWriteMethodsWorks()
Assert.Equal(24, res);
}

[Fact]
public async Task GetStreamIdWithoutStartWorks()
{
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);
}

private static async Task CreateAndTestBidirectionalStream(QuicConnection c1, QuicConnection c2)
{
using (QuicStream s1 = c1.OpenBidirectionalStream())
Expand Down