diff --git a/examples/ping.rs b/examples/ping.rs index 842619a..a7f78bd 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -5,6 +5,7 @@ use webrtc_sctp::Error; use bytes::Bytes; use clap::{App, AppSettings, Arg}; +use std::net::Shutdown; use std::sync::Arc; use tokio::net::UdpSocket; use tokio::signal; @@ -108,7 +109,7 @@ async fn main() -> Result<(), Error> { signal::ctrl_c().await.expect("failed to listen for event"); println!("Closing stream and association..."); - stream.close().await?; + stream.shutdown(Shutdown::Both).await?; a.close().await?; let _ = done_rx.recv().await; diff --git a/examples/pong.rs b/examples/pong.rs index a1dcad7..0d423fb 100644 --- a/examples/pong.rs +++ b/examples/pong.rs @@ -4,6 +4,7 @@ use webrtc_sctp::Error; use bytes::Bytes; use clap::{App, AppSettings, Arg}; +use std::net::Shutdown; use std::sync::Arc; use std::time::Duration; use tokio::net::UdpSocket; @@ -99,7 +100,7 @@ async fn main() -> Result<(), Error> { signal::ctrl_c().await.expect("failed to listen for event"); println!("Closing stream and association..."); - stream.close().await?; + stream.shutdown(Shutdown::Both).await?; a.close().await?; let _ = done_rx.recv().await; diff --git a/src/association/association_internal.rs b/src/association/association_internal.rs index 48d5bca..2d8e0f0 100644 --- a/src/association/association_internal.rs +++ b/src/association/association_internal.rs @@ -294,8 +294,11 @@ impl AssociationInternal { fn unregister_stream(&mut self, stream_identifier: u16) { let s = self.streams.remove(&stream_identifier); if let Some(s) = s { - s.closed.store(true, Ordering::SeqCst); - s.read_notifier.notify_waiters(); + // NOTE: shutdown is not used here because it resets the stream. + if !s.read_shutdown.swap(true, Ordering::SeqCst) { + s.read_notifier.notify_waiters(); + } + s.write_shutdown.store(true, Ordering::SeqCst); } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 96c9348..9cb9641 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -12,6 +12,7 @@ use bytes::Bytes; use std::fmt; use std::future::Future; use std::io; +use std::net::Shutdown; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize, Ordering}; use std::sync::Arc; @@ -76,7 +77,8 @@ pub struct Stream { pub(crate) reassembly_queue: Mutex, pub(crate) sequence_number: AtomicU16, pub(crate) read_notifier: Notify, - pub(crate) closed: AtomicBool, + pub(crate) read_shutdown: AtomicBool, + pub(crate) write_shutdown: AtomicBool, pub(crate) unordered: AtomicBool, pub(crate) reliability_type: AtomicU8, //ReliabilityType, pub(crate) reliability_value: AtomicU32, @@ -97,7 +99,8 @@ impl fmt::Debug for Stream { .field("default_payload_type", &self.default_payload_type) .field("reassembly_queue", &self.reassembly_queue) .field("sequence_number", &self.sequence_number) - .field("closed", &self.closed) + .field("read_shutdown", &self.read_shutdown) + .field("write_shutdown", &self.write_shutdown) .field("unordered", &self.unordered) .field("reliability_type", &self.reliability_type) .field("reliability_value", &self.reliability_value) @@ -130,7 +133,8 @@ impl Stream { reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)), sequence_number: AtomicU16::new(0), read_notifier: Notify::new(), - closed: AtomicBool::new(false), + read_shutdown: AtomicBool::new(false), + write_shutdown: AtomicBool::new(false), unordered: AtomicBool::new(false), reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable, reliability_value: AtomicU32::new(0), @@ -167,37 +171,38 @@ impl Stream { self.reliability_value.store(rel_val, Ordering::SeqCst); } - /// read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. - /// Returns EOF when the stream is reset or an error if the stream is closed - /// otherwise. + /// Reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. + /// + /// Returns EOF when the stream is reset or an error if `p` is too short. + /// Returns `0` if the reading half of this stream is shutdown. pub async fn read(&self, p: &mut [u8]) -> Result { let (n, _) = self.read_sctp(p).await?; Ok(n) } - /// read_sctp reads a packet of len(p) bytes and returns the associated Payload - /// Protocol Identifier. - /// Returns EOF when the stream is reset or an error if the stream is closed - /// otherwise. + /// Reads a packet of len(p) bytes and returns the associated Payload Protocol Identifier. + /// + /// Returns EOF when the stream is reset or an error if `p` is too short. + /// Returns `(0, PayloadProtocolIdentifier::Unknown)` if the reading half of this stream is shutdown. pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> { - while !self.closed.load(Ordering::SeqCst) { + loop { + if self.read_shutdown.load(Ordering::SeqCst) { + return Ok((0, PayloadProtocolIdentifier::Unknown)); + } + let result = { let mut reassembly_queue = self.reassembly_queue.lock().await; reassembly_queue.read(p) }; - if result.is_ok() { - return result; - } else if let Err(err) = result { - if Error::ErrShortBuffer == err { - return Err(err); + match result { + Ok(_) | Err(Error::ErrShortBuffer) => return result, + Err(_) => { + // wait for the next chunk to become available + self.read_notifier.notified().await; } } - - self.read_notifier.notified().await; } - - Err(Error::ErrStreamClosed) } pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) { @@ -257,14 +262,22 @@ impl Stream { } } - /// write writes len(p) bytes from p with the default Payload Protocol Identifier + /// Writes `p` to the DTLS connection with the default Payload Protocol Identifier. + /// + /// Returns an error if the write half of this stream is shutdown or `p` is too large. pub async fn write(&self, p: &Bytes) -> Result { self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into()) .await } - /// write_sctp writes len(p) bytes from p to the DTLS connection + /// Writes `p` to the DTLS connection with the given Payload Protocol Identifier. + /// + /// Returns an error if the write half of this stream is shutdown or `p` is too large. pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result { + if self.write_shutdown.load(Ordering::SeqCst) { + return Err(Error::ErrStreamClosed); + } + if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize { return Err(Error::ErrOutboundPacketTooLarge); } @@ -339,16 +352,43 @@ impl Stream { chunks } - /// Close closes the write-direction of the stream. - /// Future calls to write are not permitted after calling Close. + /// Closes both read and write halves of this stream. + /// + /// Use [`Stream::shutdown`] instead. + #[deprecated] pub async fn close(&self) -> Result<()> { - if !self.closed.load(Ordering::SeqCst) { - // Reset the outgoing stream + self.shutdown(Shutdown::Both).await + } + + /// Shuts down the read, write, or both halves of this stream. + /// + /// This function will cause all pending and future I/O on the specified portions to return + /// immediately with an appropriate value (see the documentation of [`Shutdown`]). + /// + /// Resets the stream when both halves of this stream are shutdown. + pub async fn shutdown(&self, how: Shutdown) -> Result<()> { + if self.read_shutdown.load(Ordering::SeqCst) && self.write_shutdown.load(Ordering::SeqCst) { + return Ok(()); + } + + if how == Shutdown::Write || how == Shutdown::Both { + self.write_shutdown.store(true, Ordering::SeqCst); + } + + if how == Shutdown::Read || how == Shutdown::Both { + if !self.read_shutdown.swap(true, Ordering::SeqCst) { + self.read_notifier.notify_waiters(); + } + } + + if how == Shutdown::Both + || (self.read_shutdown.load(Ordering::SeqCst) + && self.write_shutdown.load(Ordering::SeqCst)) + { + // Reset the stream // https://tools.ietf.org/html/rfc6525 self.send_reset_request(self.stream_identifier).await?; } - self.closed.store(true, Ordering::SeqCst); - self.read_notifier.notify_waiters(); // broadcast regardless Ok(()) } @@ -724,8 +764,9 @@ impl AsyncWrite for PollStream { Some(fut) => fut, None => { let stream = self.stream.clone(); - self.shutdown_fut - .get_or_insert(Box::pin(async move { stream.close().await })) + self.shutdown_fut.get_or_insert(Box::pin(async move { + stream.shutdown(Shutdown::Write).await + })) } }; diff --git a/src/stream/stream_test.rs b/src/stream/stream_test.rs index cab45af..8976313 100644 --- a/src/stream/stream_test.rs +++ b/src/stream/stream_test.rs @@ -72,6 +72,80 @@ async fn test_stream_amount_on_buffered_amount_low() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_stream() -> std::result::Result<(), io::Error> { + let s = Stream::new( + "test_poll_stream".to_owned(), + 0, + 4096, + Arc::new(AtomicU32::new(4096)), + Arc::new(AtomicU8::new(AssociationState::Established as u8)), + None, + Arc::new(PendingQueue::new()), + ); + + // getters + assert_eq!(0, s.stream_identifier()); + assert_eq!(0, s.buffered_amount()); + assert_eq!(0, s.buffered_amount_low_threshold()); + assert_eq!(0, s.get_num_bytes_in_reassembly_queue().await); + + // setters + s.set_default_payload_type(PayloadProtocolIdentifier::Binary); + s.set_reliability_params(true, ReliabilityType::Reliable, 0); + + // write + let n = s.write(&Bytes::from("Hello ")).await?; + assert_eq!(6, n); + assert_eq!(6, s.buffered_amount()); + let n = s + .write_sctp(&Bytes::from("world"), PayloadProtocolIdentifier::Binary) + .await?; + assert_eq!(5, n); + assert_eq!(11, s.buffered_amount()); + + // async read + // 1. pretend that we've received a chunk + s.handle_data(ChunkPayloadData { + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[0, 1, 2, 3, 4]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; + // 2. read it + let mut buf = [0; 5]; + s.read(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + // shutdown write + s.shutdown(Shutdown::Write).await?; + // write must fail + assert!(s.write(&Bytes::from("error")).await.is_err()); + // read should continue working + s.handle_data(ChunkPayloadData { + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[5, 6, 7, 8, 9]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; + let mut buf = [0; 5]; + s.read(&mut buf).await?; + assert_eq!(buf, [5, 6, 7, 8, 9]); + + // shutdown read + s.shutdown(Shutdown::Read).await?; + // read must return 0 + assert_eq!(Ok(0), s.read(&mut buf).await); + + Ok(()) +} + #[tokio::test] async fn test_poll_stream() -> std::result::Result<(), io::Error> { let s = Arc::new(Stream::new( @@ -114,10 +188,23 @@ async fn test_poll_stream() -> std::result::Result<(), io::Error> { poll_stream.read(&mut buf).await?; assert_eq!(buf, [0, 1, 2, 3, 4]); - // shutdown + // shutdown write poll_stream.shutdown().await?; - assert_eq!(true, sc.closed.load(Ordering::Relaxed)); - assert!(poll_stream.read(&mut buf).await.is_err()); + // write must fail + assert!(poll_stream.write(&[1, 2, 3]).await.is_err()); + // read should continue working + sc.handle_data(ChunkPayloadData { + unordered: true, + beginning_fragment: true, + ending_fragment: true, + user_data: Bytes::from_static(&[5, 6, 7, 8, 9]), + payload_type: PayloadProtocolIdentifier::Binary, + ..Default::default() + }) + .await; + let mut buf = [0; 5]; + poll_stream.read(&mut buf).await?; + assert_eq!(buf, [5, 6, 7, 8, 9]); // misc. let clone = poll_stream.clone();