@@ -27,6 +27,7 @@ static INVARIANT_REPORTER: InvariantReporter = InvariantReporter::new(1_000_000)
2727pub ( crate ) struct ClientState {
2828 ready : bool ,
2929 closing : bool ,
30+ connection_generation : usize ,
3031 streams : HashMap < u64 , ClientStream > ,
3132 multi_stream_mode : bool ,
3233 command_tx : mpsc:: UnboundedSender < Command > ,
@@ -438,6 +439,7 @@ impl ClientState {
438439 Self {
439440 ready : false ,
440441 closing : false ,
442+ connection_generation : 0 ,
441443 streams : HashMap :: new ( ) ,
442444 multi_stream_mode : false ,
443445 command_tx,
@@ -557,6 +559,7 @@ impl ClientState {
557559 }
558560 self . ready = false ;
559561 self . closing = false ;
562+ self . connection_generation = self . connection_generation . wrapping_add ( 1 ) ;
560563 self . multi_stream_mode = false ;
561564 self . path_events . clear ( ) ;
562565 self . acceptor . reset ( ) ;
@@ -620,6 +623,22 @@ fn check_stream_invariants(state: &ClientState, stream_id: u64, context: &str) {
620623 }
621624}
622625
626+ fn command_generation_matches (
627+ state : & ClientState ,
628+ stream_id : u64 ,
629+ generation : usize ,
630+ label : & str ,
631+ ) -> bool {
632+ if generation == state. connection_generation {
633+ return true ;
634+ }
635+ debug ! (
636+ "stream {}: ignoring stale {} generation={} current_generation={}" ,
637+ stream_id, label, generation, state. connection_generation
638+ ) ;
639+ false
640+ }
641+
623642struct ClientStream {
624643 write_tx : mpsc:: UnboundedSender < StreamWrite > ,
625644 read_abort_tx : Option < oneshot:: Sender < ( ) > > ,
@@ -659,13 +678,16 @@ pub(crate) enum Command {
659678 } ,
660679 StreamReadError {
661680 stream_id : u64 ,
681+ generation : usize ,
662682 } ,
663683 StreamWriteError {
664684 stream_id : u64 ,
685+ generation : usize ,
665686 } ,
666687 StreamWriteDrained {
667688 stream_id : u64 ,
668689 bytes : usize ,
690+ generation : usize ,
669691 } ,
670692}
671693
@@ -1071,6 +1093,7 @@ mod tests {
10711093 Command :: StreamWriteDrained {
10721094 stream_id,
10731095 bytes : 0 ,
1096+ generation : 0 ,
10741097 } ,
10751098 ) ;
10761099 assert ! (
@@ -1108,6 +1131,7 @@ mod tests {
11081131 Command :: StreamWriteDrained {
11091132 stream_id,
11101133 bytes : 0 ,
1134+ generation : 0 ,
11111135 } ,
11121136 ) ;
11131137
@@ -1162,6 +1186,45 @@ mod tests {
11621186 } ) ;
11631187 }
11641188
1189+ #[ test]
1190+ fn stale_task_command_is_ignored_after_reconnect ( ) {
1191+ let ( command_tx, _command_rx) = mpsc:: unbounded_channel ( ) ;
1192+ let data_notify = Arc :: new ( Notify :: new ( ) ) ;
1193+ let acceptor = acceptor:: ClientAcceptor :: new ( ) ;
1194+ let mut state = ClientState :: new ( command_tx, data_notify, false , acceptor) ;
1195+ let stream_id = 4 ;
1196+ let ( write_tx, _write_rx) = mpsc:: unbounded_channel ( ) ;
1197+ let ( read_abort_tx, _read_abort_rx) = oneshot:: channel ( ) ;
1198+
1199+ state. streams . insert (
1200+ stream_id,
1201+ ClientStream {
1202+ write_tx,
1203+ read_abort_tx : Some ( read_abort_tx) ,
1204+ data_rx : None ,
1205+ tx_bytes : 0 ,
1206+ recv_state : StreamRecvState :: Open ,
1207+ send_state : StreamSendState :: Open ,
1208+ flow : FlowControlState :: default ( ) ,
1209+ } ,
1210+ ) ;
1211+ state. connection_generation = 1 ;
1212+
1213+ handle_command (
1214+ std:: ptr:: null_mut ( ) ,
1215+ & mut state as * mut _ ,
1216+ Command :: StreamReadError {
1217+ stream_id,
1218+ generation : 0 ,
1219+ } ,
1220+ ) ;
1221+
1222+ assert ! (
1223+ state. streams. contains_key( & stream_id) ,
1224+ "stale task command from old generation must not mutate current stream state"
1225+ ) ;
1226+ }
1227+
11651228 #[ test]
11661229 fn acceptor_backpressure_blocks_new_connections ( ) {
11671230 let _guard = ResetOnDrop :: new ( || acceptor:: ClientAcceptor :: set_test_limit ( 0 ) ) ;
@@ -1310,6 +1373,7 @@ pub(crate) fn handle_command(
13101373 let ( write_tx, write_rx) = mpsc:: unbounded_channel ( ) ;
13111374 let command_tx = state. command_tx . clone ( ) ;
13121375 let ( read_abort_tx, read_abort_rx) = oneshot:: channel ( ) ;
1376+ let generation = state. connection_generation ;
13131377 state. streams . insert (
13141378 stream_id,
13151379 ClientStream {
@@ -1326,6 +1390,7 @@ pub(crate) fn handle_command(
13261390 stream_id,
13271391 read_half,
13281392 read_abort_rx,
1393+ generation,
13291394 command_tx. clone ( ) ,
13301395 data_tx,
13311396 data_notify,
@@ -1334,6 +1399,7 @@ pub(crate) fn handle_command(
13341399 stream_id,
13351400 write_half,
13361401 write_rx,
1402+ generation,
13371403 command_tx,
13381404 send_buffer_bytes,
13391405 ) ;
@@ -1429,7 +1495,13 @@ pub(crate) fn handle_command(
14291495 }
14301496 check_stream_invariants ( state, stream_id, "StreamClosed" ) ;
14311497 }
1432- Command :: StreamReadError { stream_id } => {
1498+ Command :: StreamReadError {
1499+ stream_id,
1500+ generation,
1501+ } => {
1502+ if !command_generation_matches ( state, stream_id, generation, "StreamReadError" ) {
1503+ return ;
1504+ }
14331505 if let Some ( stream) = state. streams . remove ( & stream_id) {
14341506 warn ! (
14351507 "stream {}: tcp read error rx_bytes={} tx_bytes={} queued={} consumed_offset={} fin_offset={:?}" ,
@@ -1445,7 +1517,13 @@ pub(crate) fn handle_command(
14451517 }
14461518 unsafe { abort_stream_bidi ( cnx, stream_id, SLIPSTREAM_INTERNAL_ERROR ) } ;
14471519 }
1448- Command :: StreamWriteError { stream_id } => {
1520+ Command :: StreamWriteError {
1521+ stream_id,
1522+ generation,
1523+ } => {
1524+ if !command_generation_matches ( state, stream_id, generation, "StreamWriteError" ) {
1525+ return ;
1526+ }
14491527 if let Some ( stream) = state. streams . remove ( & stream_id) {
14501528 warn ! (
14511529 "stream {}: tcp write error rx_bytes={} tx_bytes={} queued={} consumed_offset={} fin_offset={:?}" ,
@@ -1461,7 +1539,14 @@ pub(crate) fn handle_command(
14611539 }
14621540 unsafe { abort_stream_bidi ( cnx, stream_id, SLIPSTREAM_INTERNAL_ERROR ) } ;
14631541 }
1464- Command :: StreamWriteDrained { stream_id, bytes } => {
1542+ Command :: StreamWriteDrained {
1543+ stream_id,
1544+ bytes,
1545+ generation,
1546+ } => {
1547+ if !command_generation_matches ( state, stream_id, generation, "StreamWriteDrained" ) {
1548+ return ;
1549+ }
14651550 let mut remove_stream = false ;
14661551 if let Some ( stream) = state. streams . get_mut ( & stream_id) {
14671552 if stream. flow . discarding {
@@ -1512,6 +1597,7 @@ fn spawn_client_reader(
15121597 stream_id : u64 ,
15131598 mut read_half : tokio:: net:: tcp:: OwnedReadHalf ,
15141599 mut read_abort_rx : oneshot:: Receiver < ( ) > ,
1600+ generation : usize ,
15151601 command_tx : mpsc:: UnboundedSender < Command > ,
15161602 data_tx : mpsc:: Sender < Vec < u8 > > ,
15171603 data_notify : Arc < Notify > ,
@@ -1539,7 +1625,10 @@ fn spawn_client_reader(
15391625 continue ;
15401626 }
15411627 Err ( _) => {
1542- let _ = command_tx. send( Command :: StreamReadError { stream_id } ) ;
1628+ let _ = command_tx. send( Command :: StreamReadError {
1629+ stream_id,
1630+ generation,
1631+ } ) ;
15431632 break ;
15441633 }
15451634 }
@@ -1555,6 +1644,7 @@ fn spawn_client_writer(
15551644 stream_id : u64 ,
15561645 mut write_half : tokio:: net:: tcp:: OwnedWriteHalf ,
15571646 mut write_rx : mpsc:: UnboundedReceiver < StreamWrite > ,
1647+ generation : usize ,
15581648 command_tx : mpsc:: UnboundedSender < Command > ,
15591649 coalesce_max_bytes : usize ,
15601650) {
@@ -1586,12 +1676,16 @@ fn spawn_client_writer(
15861676 }
15871677 let len = buffer. len ( ) ;
15881678 if write_half. write_all ( & buffer) . await . is_err ( ) {
1589- let _ = command_tx. send ( Command :: StreamWriteError { stream_id } ) ;
1679+ let _ = command_tx. send ( Command :: StreamWriteError {
1680+ stream_id,
1681+ generation,
1682+ } ) ;
15901683 return ;
15911684 }
15921685 let _ = command_tx. send ( Command :: StreamWriteDrained {
15931686 stream_id,
15941687 bytes : len,
1688+ generation,
15951689 } ) ;
15961690 if saw_fin {
15971691 let _ = write_half. shutdown ( ) . await ;
0 commit comments