@@ -27,6 +27,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
2727
2828 static readonly Exception ChannelClosedException = new IOException ( "Channel is closed" ) ;
2929 static readonly Action < Task , object > HandshakeCompletionCallback = new Action < Task , object > ( HandleHandshakeCompleted ) ;
30+ static readonly Action < Task < int > , object > UnwrapCompletedCallback = new Action < Task < int > , object > ( UnwrapCompleted ) ;
3031
3132 readonly SslStream sslStream ;
3233 readonly MediationStream mediationStream ;
@@ -40,6 +41,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4041 bool firedChannelRead ;
4142 volatile FlushMode flushMode = FlushMode . ForceFlush ;
4243 IByteBuffer pendingSslStreamReadBuffer ;
44+ int pendingSslStreamReadLength ;
4345 Task < int > pendingSslStreamReadFuture ;
4446
4547 public TlsHandler ( TlsSettings settings )
@@ -342,10 +344,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
342344 Contract . Assert ( this . pendingSslStreamReadBuffer != null ) ;
343345
344346 outputBuffer = this . pendingSslStreamReadBuffer ;
345- outputBufferLength = outputBuffer . WritableBytes ;
347+ outputBufferLength = this . pendingSslStreamReadLength ;
346348
347349 this . pendingSslStreamReadFuture = null ;
348350 this . pendingSslStreamReadBuffer = null ;
351+ this . pendingSslStreamReadLength = 0 ;
349352 }
350353 else
351354 {
@@ -358,6 +361,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
358361 int currentPacketLength = packetLengths [ packetIndex ] ;
359362 this . mediationStream . ExpandSource ( currentPacketLength ) ;
360363
364+ while ( true )
365+ {
366+ int totalRead = 0 ;
361367 if ( currentReadFuture != null )
362368 {
363369 // there was a read pending already, so we make sure we completed that first
@@ -366,10 +372,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
366372 {
367373 // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
368374
369- continue ;
375+ break ;
370376 }
371377
372378 int read = currentReadFuture . Result ;
379+ totalRead += read ;
373380
374381 if ( read == 0 )
375382 {
@@ -382,19 +389,19 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
382389
383390 currentReadFuture = null ;
384391 outputBuffer = null ;
385- if ( this . mediationStream . SourceReadableBytes == 0 )
392+ if ( this . mediationStream . TotalReadableBytes == 0 )
386393 {
387394 // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
388395
389396 if ( read < outputBufferLength )
390397 {
391398 // SslStream returned non-full buffer and there's no more input to go through ->
392399 // typically it means SslStream is done reading current frame so we skip
393- continue ;
400+ break ;
394401 }
395402
396403 // we've read out `read` bytes out of current packet to fulfil previously outstanding read
397- outputBufferLength = currentPacketLength - read ;
404+ outputBufferLength = currentPacketLength - totalRead ;
398405 if ( outputBufferLength <= 0 )
399406 {
400407 // after feeding to SslStream current frame it read out more bytes than current packet size
@@ -417,27 +424,15 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
417424 outputBuffer = ctx . Allocator . Buffer ( outputBufferLength ) ;
418425 currentReadFuture = this . ReadFromSslStreamAsync ( outputBuffer , outputBufferLength ) ;
419426 }
427+ }
420428
421- // read out the rest of SslStream's output (if any) at risk of going async
422- // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading
423- while ( true )
424- {
425429 if ( currentReadFuture != null )
426430 {
427- if ( ! currentReadFuture . IsCompleted )
428- {
429- break ;
430- }
431- int read = currentReadFuture . Result ;
432- AddBufferToOutput ( outputBuffer , read , output ) ;
433- }
434- outputBuffer = ctx . Allocator . Buffer ( FallbackReadBufferSize ) ;
435- currentReadFuture = this . ReadFromSslStreamAsync ( outputBuffer , FallbackReadBufferSize ) ;
436- }
437-
438431 pending = true ;
439432 this . pendingSslStreamReadBuffer = outputBuffer ;
440433 this . pendingSslStreamReadFuture = currentReadFuture ;
434+ this . pendingSslStreamReadLength = outputBufferLength ;
435+ }
441436 }
442437 catch ( Exception ex )
443438 {
@@ -458,6 +453,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
458453 outputBuffer . SafeRelease ( ) ;
459454 }
460455 }
456+
457+ if ( pending )
458+ {
459+ //Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
460+ this . pendingSslStreamReadFuture ? . ContinueWith ( UnwrapCompletedCallback , this , TaskContinuationOptions . None ) ;
461+ }
462+ }
463+ }
464+
465+ static void UnwrapCompleted ( Task < int > task , object state )
466+ {
467+ // Mono(with legacy provider) finish ReadAsync in async,
468+ // so extra check is needed to receive data in async
469+ var self = ( TlsHandler ) state ;
470+ Debug . Assert ( self . capturedContext . Executor . InEventLoop ) ;
471+
472+ //Ignore task completed in Unwrap
473+ if ( task == self . pendingSslStreamReadFuture )
474+ {
475+ IByteBuffer buf = self . pendingSslStreamReadBuffer ;
476+ int outputBufferLength = self . pendingSslStreamReadLength ;
477+
478+ self . pendingSslStreamReadFuture = null ;
479+ self . pendingSslStreamReadBuffer = null ;
480+ self . pendingSslStreamReadLength = 0 ;
481+
482+ while ( true )
483+ {
484+ switch ( task . Status )
485+ {
486+ case TaskStatus . RanToCompletion :
487+ {
488+ //The logic is the same as the one in Unwrap()
489+ var read = task . Result ;
490+ //Stream Closed
491+ if ( read == 0 )
492+ return ;
493+ self . capturedContext . FireChannelRead ( buf . SetWriterIndex ( buf . WriterIndex + read ) ) ;
494+
495+ if ( self . mediationStream . TotalReadableBytes == 0 )
496+ {
497+ self . capturedContext . FireChannelReadComplete ( ) ;
498+ self . mediationStream . ResetSource ( self . capturedContext . Allocator ) ;
499+
500+ if ( read < outputBufferLength )
501+ {
502+ // SslStream returned non-full buffer and there's no more input to go through ->
503+ // typically it means SslStream is done reading current frame so we skip
504+ return ;
505+ }
506+ }
507+
508+ outputBufferLength = self . mediationStream . TotalReadableBytes ;
509+ if ( outputBufferLength <= 0 )
510+ outputBufferLength = FallbackReadBufferSize ;
511+
512+ buf = self . capturedContext . Allocator . Buffer ( outputBufferLength ) ;
513+ task = self . ReadFromSslStreamAsync ( buf , outputBufferLength ) ;
514+ if ( task . IsCompleted )
515+ {
516+ continue ;
517+ }
518+
519+ self . pendingSslStreamReadFuture = task ;
520+ self . pendingSslStreamReadBuffer = buf ;
521+ self . pendingSslStreamReadLength = outputBufferLength ;
522+ task . ContinueWith ( UnwrapCompletedCallback , self , TaskContinuationOptions . ExecuteSynchronously ) ;
523+ return ;
524+ }
525+
526+ case TaskStatus . Canceled :
527+ case TaskStatus . Faulted :
528+ {
529+ buf . SafeRelease ( ) ;
530+ self . HandleFailure ( task . Exception ) ;
531+ return ;
532+ }
533+
534+ default :
535+ {
536+ buf . SafeRelease ( ) ;
537+ throw new ArgumentOutOfRangeException ( nameof ( task ) , "Unexpected task status: " + task . Status ) ;
538+ }
539+ }
540+ }
461541 }
462542 }
463543
0 commit comments