Skip to content

Commit 5424512

Browse files
committed
Guard client stream commands by generation
1 parent b2192e5 commit 5424512

1 file changed

Lines changed: 99 additions & 5 deletions

File tree

crates/slipstream-client/src/streams.rs

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ static INVARIANT_REPORTER: InvariantReporter = InvariantReporter::new(1_000_000)
2727
pub(crate) struct ClientState {
2828
ready: bool,
2929
closing: bool,
30+
connection_generation: usize,
3031
streams: HashMap<u64, ClientStream>,
3132
multi_stream_mode: bool,
3233
command_tx: mpsc::UnboundedSender<Command>,
@@ -438,6 +439,7 @@ impl ClientState {
438439
Self {
439440
ready: false,
440441
closing: false,
442+
connection_generation: 0,
441443
streams: HashMap::new(),
442444
multi_stream_mode: false,
443445
command_tx,
@@ -557,6 +559,7 @@ impl ClientState {
557559
}
558560
self.ready = false;
559561
self.closing = false;
562+
self.connection_generation = self.connection_generation.wrapping_add(1);
560563
self.multi_stream_mode = false;
561564
self.path_events.clear();
562565
self.acceptor.reset();
@@ -620,6 +623,22 @@ fn check_stream_invariants(state: &ClientState, stream_id: u64, context: &str) {
620623
}
621624
}
622625

626+
fn command_generation_matches(
627+
state: &ClientState,
628+
stream_id: u64,
629+
generation: usize,
630+
label: &str,
631+
) -> bool {
632+
if generation == state.connection_generation {
633+
return true;
634+
}
635+
debug!(
636+
"stream {}: ignoring stale {} generation={} current_generation={}",
637+
stream_id, label, generation, state.connection_generation
638+
);
639+
false
640+
}
641+
623642
struct ClientStream {
624643
write_tx: mpsc::UnboundedSender<StreamWrite>,
625644
read_abort_tx: Option<oneshot::Sender<()>>,
@@ -659,13 +678,16 @@ pub(crate) enum Command {
659678
},
660679
StreamReadError {
661680
stream_id: u64,
681+
generation: usize,
662682
},
663683
StreamWriteError {
664684
stream_id: u64,
685+
generation: usize,
665686
},
666687
StreamWriteDrained {
667688
stream_id: u64,
668689
bytes: usize,
690+
generation: usize,
669691
},
670692
}
671693

@@ -1071,6 +1093,7 @@ mod tests {
10711093
Command::StreamWriteDrained {
10721094
stream_id,
10731095
bytes: 0,
1096+
generation: 0,
10741097
},
10751098
);
10761099
assert!(
@@ -1108,6 +1131,7 @@ mod tests {
11081131
Command::StreamWriteDrained {
11091132
stream_id,
11101133
bytes: 0,
1134+
generation: 0,
11111135
},
11121136
);
11131137

@@ -1162,6 +1186,45 @@ mod tests {
11621186
});
11631187
}
11641188

1189+
#[test]
1190+
fn stale_task_command_is_ignored_after_reconnect() {
1191+
let (command_tx, _command_rx) = mpsc::unbounded_channel();
1192+
let data_notify = Arc::new(Notify::new());
1193+
let acceptor = acceptor::ClientAcceptor::new();
1194+
let mut state = ClientState::new(command_tx, data_notify, false, acceptor);
1195+
let stream_id = 4;
1196+
let (write_tx, _write_rx) = mpsc::unbounded_channel();
1197+
let (read_abort_tx, _read_abort_rx) = oneshot::channel();
1198+
1199+
state.streams.insert(
1200+
stream_id,
1201+
ClientStream {
1202+
write_tx,
1203+
read_abort_tx: Some(read_abort_tx),
1204+
data_rx: None,
1205+
tx_bytes: 0,
1206+
recv_state: StreamRecvState::Open,
1207+
send_state: StreamSendState::Open,
1208+
flow: FlowControlState::default(),
1209+
},
1210+
);
1211+
state.connection_generation = 1;
1212+
1213+
handle_command(
1214+
std::ptr::null_mut(),
1215+
&mut state as *mut _,
1216+
Command::StreamReadError {
1217+
stream_id,
1218+
generation: 0,
1219+
},
1220+
);
1221+
1222+
assert!(
1223+
state.streams.contains_key(&stream_id),
1224+
"stale task command from old generation must not mutate current stream state"
1225+
);
1226+
}
1227+
11651228
#[test]
11661229
fn acceptor_backpressure_blocks_new_connections() {
11671230
let _guard = ResetOnDrop::new(|| acceptor::ClientAcceptor::set_test_limit(0));
@@ -1310,6 +1373,7 @@ pub(crate) fn handle_command(
13101373
let (write_tx, write_rx) = mpsc::unbounded_channel();
13111374
let command_tx = state.command_tx.clone();
13121375
let (read_abort_tx, read_abort_rx) = oneshot::channel();
1376+
let generation = state.connection_generation;
13131377
state.streams.insert(
13141378
stream_id,
13151379
ClientStream {
@@ -1326,6 +1390,7 @@ pub(crate) fn handle_command(
13261390
stream_id,
13271391
read_half,
13281392
read_abort_rx,
1393+
generation,
13291394
command_tx.clone(),
13301395
data_tx,
13311396
data_notify,
@@ -1334,6 +1399,7 @@ pub(crate) fn handle_command(
13341399
stream_id,
13351400
write_half,
13361401
write_rx,
1402+
generation,
13371403
command_tx,
13381404
send_buffer_bytes,
13391405
);
@@ -1429,7 +1495,13 @@ pub(crate) fn handle_command(
14291495
}
14301496
check_stream_invariants(state, stream_id, "StreamClosed");
14311497
}
1432-
Command::StreamReadError { stream_id } => {
1498+
Command::StreamReadError {
1499+
stream_id,
1500+
generation,
1501+
} => {
1502+
if !command_generation_matches(state, stream_id, generation, "StreamReadError") {
1503+
return;
1504+
}
14331505
if let Some(stream) = state.streams.remove(&stream_id) {
14341506
warn!(
14351507
"stream {}: tcp read error rx_bytes={} tx_bytes={} queued={} consumed_offset={} fin_offset={:?}",
@@ -1445,7 +1517,13 @@ pub(crate) fn handle_command(
14451517
}
14461518
unsafe { abort_stream_bidi(cnx, stream_id, SLIPSTREAM_INTERNAL_ERROR) };
14471519
}
1448-
Command::StreamWriteError { stream_id } => {
1520+
Command::StreamWriteError {
1521+
stream_id,
1522+
generation,
1523+
} => {
1524+
if !command_generation_matches(state, stream_id, generation, "StreamWriteError") {
1525+
return;
1526+
}
14491527
if let Some(stream) = state.streams.remove(&stream_id) {
14501528
warn!(
14511529
"stream {}: tcp write error rx_bytes={} tx_bytes={} queued={} consumed_offset={} fin_offset={:?}",
@@ -1461,7 +1539,14 @@ pub(crate) fn handle_command(
14611539
}
14621540
unsafe { abort_stream_bidi(cnx, stream_id, SLIPSTREAM_INTERNAL_ERROR) };
14631541
}
1464-
Command::StreamWriteDrained { stream_id, bytes } => {
1542+
Command::StreamWriteDrained {
1543+
stream_id,
1544+
bytes,
1545+
generation,
1546+
} => {
1547+
if !command_generation_matches(state, stream_id, generation, "StreamWriteDrained") {
1548+
return;
1549+
}
14651550
let mut remove_stream = false;
14661551
if let Some(stream) = state.streams.get_mut(&stream_id) {
14671552
if stream.flow.discarding {
@@ -1512,6 +1597,7 @@ fn spawn_client_reader(
15121597
stream_id: u64,
15131598
mut read_half: tokio::net::tcp::OwnedReadHalf,
15141599
mut read_abort_rx: oneshot::Receiver<()>,
1600+
generation: usize,
15151601
command_tx: mpsc::UnboundedSender<Command>,
15161602
data_tx: mpsc::Sender<Vec<u8>>,
15171603
data_notify: Arc<Notify>,
@@ -1539,7 +1625,10 @@ fn spawn_client_reader(
15391625
continue;
15401626
}
15411627
Err(_) => {
1542-
let _ = command_tx.send(Command::StreamReadError { stream_id });
1628+
let _ = command_tx.send(Command::StreamReadError {
1629+
stream_id,
1630+
generation,
1631+
});
15431632
break;
15441633
}
15451634
}
@@ -1555,6 +1644,7 @@ fn spawn_client_writer(
15551644
stream_id: u64,
15561645
mut write_half: tokio::net::tcp::OwnedWriteHalf,
15571646
mut write_rx: mpsc::UnboundedReceiver<StreamWrite>,
1647+
generation: usize,
15581648
command_tx: mpsc::UnboundedSender<Command>,
15591649
coalesce_max_bytes: usize,
15601650
) {
@@ -1586,12 +1676,16 @@ fn spawn_client_writer(
15861676
}
15871677
let len = buffer.len();
15881678
if write_half.write_all(&buffer).await.is_err() {
1589-
let _ = command_tx.send(Command::StreamWriteError { stream_id });
1679+
let _ = command_tx.send(Command::StreamWriteError {
1680+
stream_id,
1681+
generation,
1682+
});
15901683
return;
15911684
}
15921685
let _ = command_tx.send(Command::StreamWriteDrained {
15931686
stream_id,
15941687
bytes: len,
1688+
generation,
15951689
});
15961690
if saw_fin {
15971691
let _ = write_half.shutdown().await;

0 commit comments

Comments
 (0)