diff --git a/src/data_channel/mod.rs b/src/data_channel/mod.rs index 3903154..ca3d936 100644 --- a/src/data_channel/mod.rs +++ b/src/data_channel/mod.rs @@ -139,22 +139,29 @@ impl DataChannel { Ok(data_channel) } - /// Read reads a packet of len(p) bytes as binary data + /// Read reads a packet of len(p) bytes as binary data. + /// + /// See [`sctp::stream::Stream::read_sctp`]. pub async fn read(&self, buf: &mut [u8]) -> Result { self.read_data_channel(buf).await.map(|(n, _)| n) } - /// ReadDataChannel reads a packet of len(p) bytes + /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and + /// `true` if the data read is a string. + /// + /// See [`sctp::stream::Stream::read_sctp`]. pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> { loop { //TODO: add handling of cancel read_data_channel let (mut n, ppi) = match self.stream.read_sctp(buf).await { + Ok((0, PayloadProtocolIdentifier::Unknown)) => { + // The incoming stream was reset or the reading half was shutdown + return Ok((0, false)); + } Ok((n, ppi)) => (n, ppi), Err(err) => { - // When the peer sees that an incoming stream was - // reset, it also resets its corresponding outgoing stream. - self.stream.shutdown(Shutdown::Both).await?; - + // Shutdown the stream and send the reset request to the remote. + self.close().await?; return Err(err.into()); } }; @@ -261,18 +268,20 @@ impl DataChannel { (true, _) => PayloadProtocolIdentifier::String, }; - self.messages_sent.fetch_add(1, Ordering::SeqCst); - self.bytes_sent.fetch_add(data_len, Ordering::SeqCst); - - if data_len == 0 { + let n = if data_len == 0 { let _ = self .stream .write_sctp(&Bytes::from_static(&[0]), ppi) .await?; - Ok(0) + 0 } else { - Ok(self.stream.write_sctp(data, ppi).await?) - } + let n = self.stream.write_sctp(data, ppi).await?; + self.bytes_sent.fetch_add(n, Ordering::SeqCst); + n + }; + + self.messages_sent.fetch_add(1, Ordering::SeqCst); + Ok(n) } async fn write_data_channel_ack(&self) -> Result { diff --git a/src/error.rs b/src/error.rs index 7146a2e..a96358a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -60,12 +60,3 @@ impl PartialEq for Error { false } } - -impl PartialEq for Error { - fn eq(&self, other: &sctp::Error) -> bool { - if let Error::Sctp(e) = self { - return e == other; - } - false - } -}