Skip to content

Commit 25f9bc2

Browse files
committed
refactor(wasi-sockets): simplify UDP implementation
This introduces quite a few changes compared to TCP, which should most probably be integrated there as well Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
1 parent 48c8f01 commit 25f9bc2

2 files changed

Lines changed: 85 additions & 122 deletions

File tree

crates/wasi/src/preview2/host/udp.rs

Lines changed: 75 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::preview2::{
99
Table,
1010
};
1111
use crate::preview2::{Pollable, SocketResult, WasiView};
12-
use cap_net_ext::PoolExt;
12+
use cap_net_ext::{AddressFamily, PoolExt};
1313
use io_lifetimes::AsSocketlike;
1414
use rustix::io::Errno;
1515
use rustix::net::sockopt;
@@ -29,7 +29,10 @@ fn start_bind(
2929
let socket = table.get_resource(&this)?;
3030
match socket.udp_state {
3131
UdpState::Default => {}
32-
_ => return Err(ErrorCode::NotInProgress.into()),
32+
UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => {
33+
return Err(ErrorCode::ConcurrencyConflict.into())
34+
}
35+
UdpState::Bound | UdpState::Connected => return Err(ErrorCode::AlreadyBound.into()),
3336
}
3437

3538
let network = table.get_resource(&network)?;
@@ -51,56 +54,11 @@ fn start_bind(
5154
fn finish_bind(table: &mut Table, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
5255
let socket = table.get_resource_mut(&this)?;
5356
match socket.udp_state {
54-
UdpState::BindStarted => {}
55-
_ => return Err(ErrorCode::NotInProgress.into()),
56-
}
57-
58-
socket.udp_state = UdpState::Bound;
59-
60-
Ok(())
61-
}
62-
63-
fn address_family(table: &Table, this: Resource<udp::UdpSocket>) -> SocketResult<IpAddressFamily> {
64-
let socket = table.get_resource(&this)?;
65-
66-
// If `SO_DOMAIN` is available, use it.
67-
//
68-
// TODO: OpenBSD also supports this; upstream PRs are posted.
69-
#[cfg(not(any(
70-
windows,
71-
target_os = "ios",
72-
target_os = "macos",
73-
target_os = "netbsd",
74-
target_os = "openbsd"
75-
)))]
76-
{
77-
use rustix::net::AddressFamily;
78-
79-
let family = sockopt::get_socket_domain(socket.udp_socket())?;
80-
let family = match family {
81-
AddressFamily::INET => IpAddressFamily::Ipv4,
82-
AddressFamily::INET6 => IpAddressFamily::Ipv6,
83-
_ => return Err(ErrorCode::NotSupported.into()),
84-
};
85-
Ok(family)
86-
}
87-
88-
// When `SO_DOMAIN` is not available, emulate it.
89-
#[cfg(any(
90-
windows,
91-
target_os = "ios",
92-
target_os = "macos",
93-
target_os = "netbsd",
94-
target_os = "openbsd"
95-
))]
96-
{
97-
if let Ok(_) = sockopt::get_ipv6_unicast_hops(socket.udp_socket()) {
98-
return Ok(IpAddressFamily::Ipv6);
99-
}
100-
if let Ok(_) = sockopt::get_ip_ttl(socket.udp_socket()) {
101-
return Ok(IpAddressFamily::Ipv4);
57+
UdpState::BindStarted => {
58+
socket.udp_state = UdpState::Bound;
59+
Ok(())
10260
}
103-
Err(ErrorCode::NotSupported.into())
61+
_ => Err(ErrorCode::NotInProgress.into()),
10462
}
10563
}
10664

@@ -127,70 +85,64 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
12785
remote_address: IpSocketAddress,
12886
) -> SocketResult<()> {
12987
let table = self.table_mut();
130-
let r = {
131-
let socket = table.get_resource(&this)?;
132-
match socket.udp_state {
133-
UdpState::Default => {
134-
let family = address_family(table, Resource::new_borrow(this.rep()))?;
135-
let addr = match family {
136-
IpAddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(),
137-
IpAddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(),
138-
};
139-
start_bind(
140-
table,
141-
Resource::new_borrow(this.rep()),
142-
Resource::new_borrow(network.rep()),
143-
SocketAddr::new(addr, 0).into(),
144-
)?;
145-
finish_bind(table, Resource::new_borrow(this.rep()))?;
146-
}
147-
UdpState::BindStarted => {
148-
finish_bind(table, Resource::new_borrow(this.rep()))?;
149-
}
150-
UdpState::Bound => {}
151-
UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
152-
_ => return Err(ErrorCode::NotInProgress.into()),
88+
let socket = table.get_resource(&this)?;
89+
match socket.udp_state {
90+
UdpState::Default => {
91+
let addr = match socket.family {
92+
AddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(),
93+
AddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(),
94+
};
95+
start_bind(
96+
table,
97+
Resource::new_borrow(this.rep()),
98+
Resource::new_borrow(network.rep()),
99+
SocketAddr::new(addr, 0).into(),
100+
)?;
101+
finish_bind(table, Resource::new_borrow(this.rep()))?;
153102
}
154-
155-
let socket = table.get_resource(&this)?;
156-
let network = table.get_resource(&network)?;
157-
let connecter = network.pool.udp_connecter(remote_address)?;
158-
159-
// Do an OS `connect`. Our socket is non-blocking, so it'll either...
160-
{
161-
let view = &*socket
162-
.udp_socket()
163-
.as_socketlike_view::<cap_std::net::UdpSocket>();
164-
let r = connecter.connect_existing_udp_socket(view);
165-
r
103+
UdpState::Bound => {}
104+
UdpState::BindStarted | UdpState::Connecting | UdpState::ConnectReady => {
105+
return Err(ErrorCode::ConcurrencyConflict.into())
166106
}
167-
};
107+
UdpState::Connected => return Err(ErrorCode::AlreadyConnected.into()),
108+
}
168109

169-
match r {
110+
let socket = table.get_resource(&this)?;
111+
let network = table.get_resource(&network)?;
112+
let connecter = network.pool.udp_connecter(remote_address)?;
113+
114+
// Do an OS `connect`. Our socket is non-blocking, so it'll either...
115+
let res = connecter.connect_existing_udp_socket(
116+
&*socket
117+
.udp_socket()
118+
.as_socketlike_view::<cap_std::net::UdpSocket>(),
119+
);
120+
match res {
170121
// succeed immediately,
171122
Ok(()) => {
172123
let socket = table.get_resource_mut(&this)?;
173124
socket.udp_state = UdpState::ConnectReady;
174-
return Ok(());
125+
Ok(())
175126
}
176127
// continue in progress,
177-
Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {}
128+
Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => {
129+
let socket = table.get_resource_mut(&this)?;
130+
socket.udp_state = UdpState::Connecting;
131+
Ok(())
132+
}
178133
// or fail immediately.
179-
Err(err) => return Err(err.into()),
134+
Err(err) => Err(err.into()),
180135
}
181-
182-
let socket = table.get_resource_mut(&this)?;
183-
socket.udp_state = UdpState::Connecting;
184-
185-
Ok(())
186136
}
187137

188138
fn finish_connect(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
189139
let table = self.table_mut();
190140
let socket = table.get_resource_mut(&this)?;
191-
192141
match socket.udp_state {
193-
UdpState::ConnectReady => {}
142+
UdpState::ConnectReady => {
143+
socket.udp_state = UdpState::Connected;
144+
Ok(())
145+
}
194146
UdpState::Connecting => {
195147
// Do a `poll` to test for completion, using a timeout of zero
196148
// to avoid blocking.
@@ -202,21 +154,21 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
202154
0,
203155
) {
204156
Ok(0) => return Err(ErrorCode::WouldBlock.into()),
205-
Ok(_) => (),
206-
Err(err) => Err(err).unwrap(),
157+
Ok(_) => {}
158+
Err(err) => return Err(err.into()),
207159
}
208160

209161
// Check whether the connect succeeded.
210162
match sockopt::get_socket_error(socket.udp_socket()) {
211-
Ok(Ok(())) => {}
212-
Err(err) | Ok(Err(err)) => return Err(err.into()),
163+
Ok(Ok(())) => {
164+
socket.udp_state = UdpState::Connected;
165+
Ok(())
166+
}
167+
Err(err) | Ok(Err(err)) => Err(err.into()),
213168
}
214169
}
215-
_ => return Err(ErrorCode::NotInProgress.into()),
216-
};
217-
218-
socket.udp_state = UdpState::Connected;
219-
Ok(())
170+
_ => Err(ErrorCode::NotInProgress.into()),
171+
}
220172
}
221173

222174
fn receive(
@@ -232,7 +184,7 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
232184
let socket = table.get_resource(&this)?;
233185

234186
let udp_socket = socket.udp_socket();
235-
let mut datagrams = Vec::with_capacity(max_results.try_into().unwrap_or(usize::MAX));
187+
let mut datagrams = vec![];
236188
let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
237189
match socket.udp_state {
238190
UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()),
@@ -352,8 +304,12 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
352304
&mut self,
353305
this: Resource<udp::UdpSocket>,
354306
) -> Result<IpAddressFamily, anyhow::Error> {
355-
let family = address_family(self.table(), this)?;
356-
Ok(family)
307+
let table = self.table();
308+
let socket = table.get_resource(&this)?;
309+
match socket.family {
310+
AddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
311+
AddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
312+
}
357313
}
358314

359315
fn ipv6_only(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<bool> {
@@ -477,12 +433,12 @@ impl<T: WasiView> crate::preview2::host::udp::udp::HostUdpSocket for T {
477433
}
478434
}
479435

480-
// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`.
481-
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html>
482-
#[cfg(not(windows))]
483-
const INPROGRESS: Errno = Errno::INPROGRESS;
484-
485-
// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`.
486-
// <https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect>
487-
#[cfg(windows)]
488-
const INPROGRESS: Errno = Errno::WOULDBLOCK;
436+
const INPROGRESS: Errno = if cfg!(windows) {
437+
// On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`.
438+
// <https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-connect>
439+
Errno::WOULDBLOCK
440+
} else {
441+
// On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`.
442+
// <https://pubs.opengroup.org/onlinepubs/9699919799/functions/connect.html>
443+
Errno::INPROGRESS
444+
};

crates/wasi/src/preview2/udp.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ pub struct UdpSocket {
4343

4444
/// The current state in the bind/connect progression.
4545
pub(crate) udp_state: UdpState,
46+
47+
/// Socket address family.
48+
pub(crate) family: AddressFamily,
4649
}
4750

4851
#[async_trait]
4952
impl Subscribe for UdpSocket {
5053
async fn ready(&mut self) {
5154
// Some states are ready immediately.
5255
match self.udp_state {
53-
UdpState::BindStarted | UdpState::ConnectReady => return,
56+
UdpState::BindStarted => return,
5457
_ => {}
5558
}
5659

@@ -68,16 +71,20 @@ impl UdpSocket {
6871
// Create a new host socket and set it to non-blocking, which is needed
6972
// by our async implementation.
7073
let udp_socket = cap_std::net::UdpSocket::new(family, Blocking::No)?;
71-
Self::from_udp_socket(udp_socket)
74+
Self::from_udp_socket(udp_socket, family)
7275
}
7376

74-
pub fn from_udp_socket(udp_socket: cap_std::net::UdpSocket) -> io::Result<Self> {
77+
pub fn from_udp_socket(
78+
udp_socket: cap_std::net::UdpSocket,
79+
family: AddressFamily,
80+
) -> io::Result<Self> {
7581
let fd = udp_socket.into_raw_socketlike();
7682
let std_socket = unsafe { std::net::UdpSocket::from_raw_socketlike(fd) };
7783
let socket = with_ambient_tokio_runtime(|| tokio::net::UdpSocket::try_from(std_socket))?;
7884
Ok(Self {
7985
inner: Arc::new(socket),
8086
udp_state: UdpState::Default,
87+
family,
8188
})
8289
}
8390

0 commit comments

Comments
 (0)