Skip to content

Commit 2e50292

Browse files
Add checking for header sanity
Co-authored-by: Daniel Abramov <[email protected]>
1 parent f916b33 commit 2e50292

1 file changed

Lines changed: 69 additions & 16 deletions

File tree

src/handshake/machine.rs

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
2020
impl<Stream> HandshakeMachine<Stream> {
2121
/// Start reading data from the peer.
2222
pub fn start_read(stream: Stream) -> Self {
23-
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
23+
Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
2424
}
2525
/// Start writing data to the peer.
2626
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
@@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
4141
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
4242
trace!("Doing handshake round.");
4343
match self.state {
44-
HandshakeState::Reading(mut buf) => {
44+
HandshakeState::Reading(mut buf, mut attack_check) => {
4545
let read = buf.read_from(&mut self.stream).no_block()?;
4646
match read {
4747
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
48-
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
49-
buf.advance(size);
50-
RoundResult::StageFinished(StageResult::DoneReading {
51-
result: obj,
52-
stream: self.stream,
53-
tail: buf.into_vec(),
48+
Some(count) => {
49+
attack_check.check_incoming_packet_size(count)?;
50+
// TODO: this is slow for big headers with too many small packets.
51+
// The parser has to be reworked in order to work on streams instead
52+
// of buffers.
53+
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
54+
buf.advance(size);
55+
RoundResult::StageFinished(StageResult::DoneReading {
56+
result: obj,
57+
stream: self.stream,
58+
tail: buf.into_vec(),
59+
})
60+
} else {
61+
RoundResult::Incomplete(HandshakeMachine {
62+
state: HandshakeState::Reading(buf, attack_check),
63+
..self
64+
})
5465
})
55-
} else {
56-
RoundResult::Incomplete(HandshakeMachine {
57-
state: HandshakeState::Reading(buf),
58-
..self
59-
})
60-
}),
66+
}
6167
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
62-
state: HandshakeState::Reading(buf),
68+
state: HandshakeState::Reading(buf, attack_check),
6369
..self
6470
})),
6571
}
@@ -119,7 +125,54 @@ pub trait TryParse: Sized {
119125
#[derive(Debug)]
120126
enum HandshakeState {
121127
/// Reading data from the peer.
122-
Reading(ReadBuffer),
128+
Reading(ReadBuffer, AttackCheck),
123129
/// Sending data to the peer.
124130
Writing(Cursor<Vec<u8>>),
125131
}
132+
133+
/// Attack mitigation. Contains counters needed to prevent DoS attacks
134+
/// and reject valid but useless headers.
135+
#[derive(Debug)]
136+
pub(crate) struct AttackCheck {
137+
/// Number of HTTP header successful reads (TCP packets).
138+
number_of_packets: usize,
139+
/// Total number of bytes in HTTP header.
140+
number_of_bytes: usize,
141+
}
142+
143+
impl AttackCheck {
144+
/// Initialize attack checking for incoming buffer.
145+
fn new() -> Self {
146+
Self { number_of_packets: 0, number_of_bytes: 0 }
147+
}
148+
149+
/// Check the size of an incoming packet. To be called immediately after `read()`
150+
/// passing its returned bytes count as `size`.
151+
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
152+
self.number_of_packets += 1;
153+
self.number_of_bytes += size;
154+
155+
// TODO: these values are hardcoded. Instead of making them configurable,
156+
// rework the way HTTP header is parsed to remove this check at all.
157+
const MAX_BYTES: usize = 65536;
158+
const MAX_PACKETS: usize = 512;
159+
const MIN_PACKET_SIZE: usize = 128;
160+
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
161+
162+
if self.number_of_bytes > MAX_BYTES {
163+
return Err(Error::AttackAttempt);
164+
}
165+
166+
if self.number_of_packets > MAX_PACKETS {
167+
return Err(Error::AttackAttempt);
168+
}
169+
170+
if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD {
171+
if self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes {
172+
return Err(Error::AttackAttempt);
173+
}
174+
}
175+
176+
Ok(())
177+
}
178+
}

0 commit comments

Comments
 (0)