Skip to content

Commit 331ea1e

Browse files
committed
Handle if the last data in async path.
1 parent de886b4 commit 331ea1e

1 file changed

Lines changed: 100 additions & 20 deletions

File tree

src/DotNetty.Handlers/Tls/TlsHandler.cs

Lines changed: 100 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
2727

2828
static readonly Exception ChannelClosedException = new IOException("Channel is closed");
2929
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
30+
static readonly Action<Task<int>, object> UnwrapCompletedCallback = new Action<Task<int>, object>(UnwrapCompleted);
3031

3132
readonly SslStream sslStream;
3233
readonly MediationStream mediationStream;
@@ -40,6 +41,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4041
bool firedChannelRead;
4142
volatile FlushMode flushMode = FlushMode.ForceFlush;
4243
IByteBuffer pendingSslStreamReadBuffer;
44+
int pendingSslStreamReadLength;
4345
Task<int> pendingSslStreamReadFuture;
4446

4547
public TlsHandler(TlsSettings settings)
@@ -342,10 +344,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
342344
Contract.Assert(this.pendingSslStreamReadBuffer != null);
343345

344346
outputBuffer = this.pendingSslStreamReadBuffer;
345-
outputBufferLength = outputBuffer.WritableBytes;
347+
outputBufferLength = this.pendingSslStreamReadLength;
346348

347349
this.pendingSslStreamReadFuture = null;
348350
this.pendingSslStreamReadBuffer = null;
351+
this.pendingSslStreamReadLength = 0;
349352
}
350353
else
351354
{
@@ -358,6 +361,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
358361
int currentPacketLength = packetLengths[packetIndex];
359362
this.mediationStream.ExpandSource(currentPacketLength);
360363

364+
while (true)
365+
{
366+
int totalRead = 0;
361367
if (currentReadFuture != null)
362368
{
363369
// there was a read pending already, so we make sure we completed that first
@@ -366,10 +372,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
366372
{
367373
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
368374

369-
continue;
375+
break;
370376
}
371377

372378
int read = currentReadFuture.Result;
379+
totalRead += read;
373380

374381
if (read == 0)
375382
{
@@ -382,19 +389,19 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
382389

383390
currentReadFuture = null;
384391
outputBuffer = null;
385-
if (this.mediationStream.SourceReadableBytes == 0)
392+
if (this.mediationStream.TotalReadableBytes == 0)
386393
{
387394
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
388395

389396
if (read < outputBufferLength)
390397
{
391398
// SslStream returned non-full buffer and there's no more input to go through ->
392399
// typically it means SslStream is done reading current frame so we skip
393-
continue;
400+
break;
394401
}
395402

396403
// we've read out `read` bytes out of current packet to fulfil previously outstanding read
397-
outputBufferLength = currentPacketLength - read;
404+
outputBufferLength = currentPacketLength - totalRead;
398405
if (outputBufferLength <= 0)
399406
{
400407
// after feeding to SslStream current frame it read out more bytes than current packet size
@@ -417,27 +424,15 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
417424
outputBuffer = ctx.Allocator.Buffer(outputBufferLength);
418425
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength);
419426
}
427+
}
420428

421-
// read out the rest of SslStream's output (if any) at risk of going async
422-
// using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading
423-
while (true)
424-
{
425429
if (currentReadFuture != null)
426430
{
427-
if (!currentReadFuture.IsCompleted)
428-
{
429-
break;
430-
}
431-
int read = currentReadFuture.Result;
432-
AddBufferToOutput(outputBuffer, read, output);
433-
}
434-
outputBuffer = ctx.Allocator.Buffer(FallbackReadBufferSize);
435-
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, FallbackReadBufferSize);
436-
}
437-
438431
pending = true;
439432
this.pendingSslStreamReadBuffer = outputBuffer;
440433
this.pendingSslStreamReadFuture = currentReadFuture;
434+
this.pendingSslStreamReadLength = outputBufferLength;
435+
}
441436
}
442437
catch (Exception ex)
443438
{
@@ -458,6 +453,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
458453
outputBuffer.SafeRelease();
459454
}
460455
}
456+
457+
if (pending)
458+
{
459+
//Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
460+
this.pendingSslStreamReadFuture?.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None);
461+
}
462+
}
463+
}
464+
465+
static void UnwrapCompleted(Task<int> task, object state)
466+
{
467+
// Mono(with legacy provider) finish ReadAsync in async,
468+
// so extra check is needed to receive data in async
469+
var self = (TlsHandler)state;
470+
Debug.Assert(self.capturedContext.Executor.InEventLoop);
471+
472+
//Ignore task completed in Unwrap
473+
if (task == self.pendingSslStreamReadFuture)
474+
{
475+
IByteBuffer buf = self.pendingSslStreamReadBuffer;
476+
int outputBufferLength = self.pendingSslStreamReadLength;
477+
478+
self.pendingSslStreamReadFuture = null;
479+
self.pendingSslStreamReadBuffer = null;
480+
self.pendingSslStreamReadLength = 0;
481+
482+
while (true)
483+
{
484+
switch (task.Status)
485+
{
486+
case TaskStatus.RanToCompletion:
487+
{
488+
//The logic is the same as the one in Unwrap()
489+
var read = task.Result;
490+
//Stream Closed
491+
if (read == 0)
492+
return;
493+
self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));
494+
495+
if (self.mediationStream.TotalReadableBytes == 0)
496+
{
497+
self.capturedContext.FireChannelReadComplete();
498+
self.mediationStream.ResetSource(self.capturedContext.Allocator);
499+
500+
if (read < outputBufferLength)
501+
{
502+
// SslStream returned non-full buffer and there's no more input to go through ->
503+
// typically it means SslStream is done reading current frame so we skip
504+
return;
505+
}
506+
}
507+
508+
outputBufferLength = self.mediationStream.TotalReadableBytes;
509+
if (outputBufferLength <= 0)
510+
outputBufferLength = FallbackReadBufferSize;
511+
512+
buf = self.capturedContext.Allocator.Buffer(outputBufferLength);
513+
task = self.ReadFromSslStreamAsync(buf, outputBufferLength);
514+
if (task.IsCompleted)
515+
{
516+
continue;
517+
}
518+
519+
self.pendingSslStreamReadFuture = task;
520+
self.pendingSslStreamReadBuffer = buf;
521+
self.pendingSslStreamReadLength = outputBufferLength;
522+
task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously);
523+
return;
524+
}
525+
526+
case TaskStatus.Canceled:
527+
case TaskStatus.Faulted:
528+
{
529+
buf.SafeRelease();
530+
self.HandleFailure(task.Exception);
531+
return;
532+
}
533+
534+
default:
535+
{
536+
buf.SafeRelease();
537+
throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
538+
}
539+
}
540+
}
461541
}
462542
}
463543

0 commit comments

Comments
 (0)