Skip to content
5 changes: 5 additions & 0 deletions src/crypto/ed25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ impl PublicKey {
self.0.to_bytes()
}

/// Get the public key as a byte slice.
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}

/// Try to parse a public key from a byte array containing the actual key as produced by
/// `to_bytes`.
pub fn try_from_bytes(k: &[u8]) -> Result<PublicKey, ParseError> {
Expand Down
74 changes: 61 additions & 13 deletions src/crypto/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,20 @@ impl NoiseContext {
}
}

fn get_handshake_dh_remote_pubkey(&self) -> Result<&[u8], NegotiationError> {
let NoiseState::Handshake(ref noise) = self.noise else {
tracing::error!(target: LOG_TARGET, "invalid state to get remote public key");
return Err(NegotiationError::StateMismatch);
};

let Some(dh_remote_pubkey) = noise.get_remote_static() else {
tracing::error!(target: LOG_TARGET, "expected remote public key at the end of XX session");
return Err(NegotiationError::IoError(std::io::ErrorKind::InvalidData));
};

Ok(dh_remote_pubkey)
}

/// Convert Noise into transport mode.
fn into_transport(self) -> Result<NoiseContext, NegotiationError> {
let transport = match self.noise {
Expand Down Expand Up @@ -656,17 +670,35 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NoiseSocket<S> {
}
}

/// Try to parse `PeerId` from received `NoiseHandshakePayload`
fn parse_peer_id(buf: &[u8]) -> Result<PeerId, NegotiationError> {
match handshake_schema::NoiseHandshakePayload::decode(buf) {
Ok(payload) => {
let identity = payload.identity_key.ok_or(NegotiationError::PeerIdMissing)?;
/// Parse the `PeerId` from received `NoiseHandshakePayload` and verify the payload signature.
fn parse_and_verify_peer_id(
payload: handshake_schema::NoiseHandshakePayload,
dh_remote_pubkey: &[u8],
) -> Result<PeerId, NegotiationError> {
let identity = payload.identity_key.ok_or(NegotiationError::PeerIdMissing)?;
let remote_public_key = PublicKey::from_protobuf_encoding(&identity)?;
let remote_key_signature =
payload.identity_sig.ok_or(NegotiationError::BadSignature).map_err(|err| {
tracing::debug!(target: LOG_TARGET, "payload without signature");
err
})?;

let peer_id = PeerId::from_public_key(&remote_public_key);

if !remote_public_key.verify(
&[STATIC_KEY_DOMAIN.as_bytes(), dh_remote_pubkey].concat(),
&remote_key_signature,
) {
tracing::debug!(
target: LOG_TARGET,
?peer_id,
"failed to verify remote public key signature"
);

let public_key = PublicKey::from_protobuf_encoding(&identity)?;
Ok(PeerId::from_public_key(&public_key))
}
Err(err) => Err(ParseError::from(err).into()),
return Err(NegotiationError::BadSignature);
}

Ok(peer_id)
}

/// Perform Noise handshake.
Expand All @@ -680,7 +712,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
tracing::debug!(target: LOG_TARGET, ?role, "start noise handshake");

let mut noise = NoiseContext::new(keypair, role)?;
let peer = match role {
let payload = match role {
Role::Dialer => {
// write initial message
let first_message = noise.first_message(Role::Dialer)?;
Expand All @@ -689,13 +721,20 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(

// read back response which contains the remote peer id
let message = noise.read_handshake_message(&mut io).await?;
// Decode the remote identity message.
let payload = handshake_schema::NoiseHandshakePayload::decode(message)
.map_err(ParseError::from)
.map_err(|err| {
tracing::error!(target: LOG_TARGET, ?err, "failed to decode remote identity message");
err
})?;

// send the final message which contains local peer id
let second_message = noise.second_message()?;
let _ = io.write(&second_message).await?;
io.flush().await?;

parse_peer_id(&message)?
payload
}
Role::Listener => {
// read remote's first message
Expand All @@ -708,10 +747,14 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(

// read remote's second message which contains their peer id
let message = noise.read_handshake_message(&mut io).await?;
parse_peer_id(&message)?
// Decode the remote identity message.
handshake_schema::NoiseHandshakePayload::decode(message).map_err(ParseError::from)?
}
};

let dh_remote_pubkey = noise.get_handshake_dh_remote_pubkey()?;
let peer = parse_and_verify_peer_id(payload, dh_remote_pubkey)?;

Ok((
NoiseSocket::new(
io,
Expand Down Expand Up @@ -789,7 +832,12 @@ mod tests {

#[test]
fn invalid_peer_id_schema() {
match parse_peer_id(&vec![1, 2, 3, 4]).unwrap_err() {
let payload = handshake_schema::NoiseHandshakePayload {
identity_key: Some(vec![1, 2, 3, 4]),
identity_sig: None,
extensions: None,
};
match parse_and_verify_peer_id(payload, &[0]).unwrap_err() {
NegotiationError::ParseError(_) => {}
_ => panic!("invalid error"),
}
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ pub enum NegotiationError {
/// The peer ID was not provided by the noise handshake.
#[error("`PeerId` missing from Noise handshake")]
PeerIdMissing,
/// The remote peer ID is not the same as the one expected.
#[error("The signature of the remote identity's public key does not verify")]
BadSignature,
/// The negotiation operation timed out.
#[error("Operation timed out")]
Timeout,
Expand Down