Skip to content

Commit 1097e6b

Browse files
authored
Merge pull request #2 from Watfaq/tests
add tests
2 parents 1df695b + 597edac commit 1097e6b

File tree

10 files changed

+123
-89
lines changed

10 files changed

+123
-89
lines changed

.cargo/config.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[target.x86_64-unknown-linux-gnu]
2+
runner = 'sudo -E'

.devcontainer/devcontainer.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
2+
// README at: https://github.com/devcontainers/templates/tree/main/src/rust
3+
{
4+
"name": "Rust",
5+
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
6+
"image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye",
7+
"features": {
8+
"ghcr.io/devcontainers/features/rust:1": {
9+
"version": "latest",
10+
"profile": "minimal"
11+
}
12+
},
13+
"runArgs": [
14+
// TODO: figure out the exact cap-add ?
15+
"--privileged"
16+
],
17+
"remoteUser": "root"
18+
}

.github/dependabot.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# To get started with Dependabot version updates, you'll need to specify which
2+
# package ecosystems to update and where the package manifests are located.
3+
# Please see the documentation for more information:
4+
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5+
# https://containers.dev/guide/dependabot
6+
7+
version: 2
8+
updates:
9+
- package-ecosystem: "devcontainers"
10+
directory: "/"
11+
schedule:
12+
interval: weekly

.github/workflows/rust.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ env:
1111

1212
jobs:
1313
build:
14+
strategy:
15+
matrix:
16+
os: [ubuntu-latest, macos-latest]
1417

15-
runs-on: ubuntu-latest
18+
runs-on: ${{ matrix.os }}
1619

1720
steps:
1821
- uses: actions/checkout@v4

benches/bench_find_process_by_socket.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use criterion::{criterion_group, criterion_main, Criterion};
2-
use sock2proc::{FindProc, FindProcImpl};
2+
use sock2proc::find_process_name;
33

44
fn run_find_process_by_socket() {
55
let dst = std::net::SocketAddr::new(
66
std::net::IpAddr::V4(std::net::Ipv4Addr::new(8, 8, 8, 8)),
77
80,
88
);
9-
let _process_name = FindProcImpl::resolve(None, Some(dst), libc::IPPROTO_TCP);
9+
let _process_name = find_process_name(None, Some(dst), sock2proc::NetworkProtocol::TCP);
1010
}
1111

1212
fn criterion_benchmark(c: &mut Criterion) {

src/lib.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
mod platform;
22
mod utils;
33

4-
pub use libc::{IPPROTO_TCP, IPPROTO_UDP};
5-
pub use platform::{FindProc, FindProcImpl};
4+
#[derive(PartialEq, Clone, Copy, Debug)]
5+
#[repr(u8)]
6+
pub enum NetworkProtocol {
7+
TCP = 6,
8+
UDP = 17,
9+
}
10+
pub use platform::find_process_name;

src/platform/linux.rs

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,14 @@ use netlink_packet_sock_diag::{
1616
};
1717
use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr};
1818

19-
use super::FindProc;
19+
use crate::{utils::pre_condition, NetworkProtocol};
2020

21-
pub struct FindProcImpl;
22-
23-
impl FindProc for FindProcImpl {
24-
fn resolve(
25-
src: Option<std::net::SocketAddr>,
26-
dst: Option<std::net::SocketAddr>,
27-
proto: i32,
28-
) -> Option<String> {
29-
resolve(src, dst, proto)
30-
}
31-
}
32-
33-
fn resolve(
21+
pub fn find_process_name(
3422
src: Option<std::net::SocketAddr>,
3523
dst: Option<std::net::SocketAddr>,
36-
proto: i32,
24+
proto: NetworkProtocol,
3725
) -> Option<String> {
38-
if !crate::utils::check(src, dst) {
39-
return None;
40-
}
41-
if proto != libc::IPPROTO_TCP || proto != libc::IPPROTO_UDP {
26+
if !pre_condition(src, dst) {
4227
return None;
4328
}
4429

@@ -49,7 +34,7 @@ fn resolve(
4934
fn resolve_uid_inode(
5035
src: Option<std::net::SocketAddr>,
5136
dst: Option<std::net::SocketAddr>,
52-
proto: i32,
37+
proto: NetworkProtocol,
5338
) -> Option<(u32, u32)> {
5439
let mut socket = Socket::new(NETLINK_SOCK_DIAG).unwrap();
5540
let _port_number = socket.bind_auto().unwrap().port_number();
@@ -90,11 +75,12 @@ fn resolve_uid_inode(
9075
// Before calling serialize, it is important to check that the buffer in
9176
// which we're emitting is big enough for the packet, other
9277
// `serialize()` panics.
93-
assert_eq!(buf.len(), packet.buffer_len());
78+
assert_eq!(buf.len(), packet.buffer_len(), "Buffer is too small");
9479

9580
packet.serialize(&mut buf[..]);
9681

97-
if let Err(_) = socket.send(&buf[..], 0) {
82+
if let Err(e) = socket.send(&buf[..], 0) {
83+
eprintln!("Failed to send packet: {:?}", e);
9884
return None;
9985
}
10086

src/platform/macos.rs

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use libc::{IPPROTO_TCP, IPPROTO_UDP};
21
use sysctl::Sysctl;
32

4-
use crate::utils::{check, is_ipv6};
5-
use crate::FindProc;
3+
use crate::utils::{is_ipv6, pre_condition};
4+
use crate::NetworkProtocol;
65

76
use std::io;
87
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
@@ -15,28 +14,24 @@ const PROCPIDPATHINFOSIZE: usize = 1024;
1514
const PROCCALLNUMPIDINFO: i32 = 0x2;
1615

1716
static STRUCT_SIZE: AtomicUsize = AtomicUsize::new(0);
18-
const STRUCT_SIZE_SETTER: Once = Once::new();
19-
20-
pub struct FindProcImpl;
21-
22-
impl FindProc for FindProcImpl {
23-
fn resolve(
24-
src: Option<std::net::SocketAddr>,
25-
dst: Option<std::net::SocketAddr>,
26-
proto: i32,
27-
) -> Option<String> {
28-
if !check(src, dst) {
29-
return None;
30-
}
31-
find_process_name(src, dst, proto).ok()
32-
}
17+
static STRUCT_SIZE_SETTER: Once = Once::new();
18+
19+
pub fn find_process_name(
20+
src: Option<std::net::SocketAddr>,
21+
dst: Option<std::net::SocketAddr>,
22+
proto: NetworkProtocol,
23+
) -> Option<String> {
24+
find_process_name_inner(src, dst, proto).ok()
3325
}
3426

35-
fn find_process_name(
27+
fn find_process_name_inner(
3628
src: Option<std::net::SocketAddr>,
3729
dst: Option<std::net::SocketAddr>,
38-
proto: i32,
30+
proto: NetworkProtocol,
3931
) -> Result<String, io::Error> {
32+
if !pre_condition(src, dst) {
33+
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid input"));
34+
}
4035
STRUCT_SIZE_SETTER.call_once(|| {
4136
let default = "".to_string();
4237
let ctl = sysctl::Ctl::new("kern.osrelease").unwrap();
@@ -53,14 +48,8 @@ fn find_process_name(
5348

5449
// see: https://github.com/apple-oss-distributions/xnu/blob/94d3b452840153a99b38a3a9659680b2a006908e/bsd/netinet/in_pcblist.c#L292
5550
let spath = match proto {
56-
IPPROTO_TCP => "net.inet.tcp.pcblist_n",
57-
IPPROTO_UDP => "net.inet.udp.pcblist_n",
58-
_ => {
59-
return Err(io::Error::new(
60-
io::ErrorKind::InvalidInput,
61-
"Invalid network",
62-
))
63-
}
51+
NetworkProtocol::TCP => "net.inet.tcp.pcblist_n",
52+
NetworkProtocol::UDP => "net.inet.udp.pcblist_n",
6453
};
6554

6655
let is_ipv4 = !is_ipv6(src, dst);
@@ -69,7 +58,12 @@ fn find_process_name(
6958
let value = ctl.value().unwrap();
7059
let buf = value.as_struct().unwrap();
7160
let struct_size = STRUCT_SIZE.load(std::sync::atomic::Ordering::Relaxed);
72-
let item_size = struct_size + if proto == IPPROTO_TCP { 208 } else { 0 };
61+
let item_size = struct_size
62+
+ if proto == NetworkProtocol::TCP {
63+
208
64+
} else {
65+
0
66+
};
7367

7468
// see https://github.com/apple-oss-distributions/xnu/blob/94d3b452840153a99b38a3a9659680b2a006908e/bsd/netinet/in_pcb.h#L451
7569
// offset of flag is 44
@@ -144,7 +138,7 @@ fn find_process_name(
144138
fn get_pid(bytes: &[u8]) -> u32 {
145139
assert_eq!(bytes.len(), 4);
146140
let mut pid_bytes = [0; 4];
147-
pid_bytes.copy_from_slice(&bytes);
141+
pid_bytes.copy_from_slice(bytes);
148142
if cfg!(target_endian = "big") {
149143
u32::from_be_bytes(pid_bytes)
150144
} else {

src/platform/mod.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,39 @@
1-
pub trait FindProc {
2-
fn resolve(
3-
src: Option<std::net::SocketAddr>,
4-
dst: Option<std::net::SocketAddr>,
5-
proto: i32,
6-
) -> Option<String>;
7-
}
8-
91
#[cfg(target_os = "linux")]
102
mod linux;
113
#[cfg(target_os = "linux")]
12-
pub use linux::FindProcImpl;
4+
pub use linux::find_process_name;
135

146
#[cfg(target_os = "macos")]
157
mod macos;
168
#[cfg(target_os = "macos")]
17-
pub use macos::FindProcImpl;
9+
pub use macos::find_process_name;
1810

1911
#[cfg(test)]
2012
mod tests {
2113

22-
use super::*;
14+
use std::net::TcpListener;
2315

2416
#[test]
25-
fn test_compile() {
26-
let dst = std::net::SocketAddr::new(
27-
std::net::IpAddr::V4(std::net::Ipv4Addr::new(8, 8, 8, 8)),
28-
80,
29-
);
30-
let _process_name = FindProcImpl::resolve(None, Some(dst), libc::IPPROTO_TCP);
17+
fn test_get_find_tcp_socket() {
18+
let socket = TcpListener::bind("127.0.0.1:0").unwrap();
19+
let addr = socket.local_addr().unwrap();
20+
let path = super::find_process_name(Some(addr), None, crate::NetworkProtocol::TCP);
21+
22+
assert!(path.is_some());
23+
24+
let current_exe = std::env::current_exe().unwrap();
25+
assert_eq!(path.unwrap(), current_exe.to_str().unwrap());
26+
}
27+
28+
#[test]
29+
fn test_get_find_udp_socket() {
30+
let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
31+
let addr = socket.local_addr().unwrap();
32+
let path = super::find_process_name(Some(addr), None, crate::NetworkProtocol::UDP);
33+
34+
assert!(path.is_some());
35+
36+
let current_exe = std::env::current_exe().unwrap();
37+
assert_eq!(path.unwrap(), current_exe.to_str().unwrap());
3138
}
3239
}

src/utils.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
1-
pub(crate) fn check(src: Option<std::net::SocketAddr>, dst: Option<std::net::SocketAddr>) -> bool {
2-
if src.is_none() && dst.is_none() {
3-
false
4-
} else if src.is_some() && dst.is_some() {
5-
let inner1 = src.unwrap();
6-
let inner2 = dst.unwrap();
7-
(inner1.is_ipv4() && inner2.is_ipv6()) || (inner2.is_ipv4() && inner1.is_ipv6())
8-
} else {
9-
true
1+
pub(crate) fn pre_condition(
2+
src: Option<std::net::SocketAddr>,
3+
dst: Option<std::net::SocketAddr>,
4+
) -> bool {
5+
match (src, dst) {
6+
(None, None) => false,
7+
(Some(_), None) => true,
8+
(None, Some(_)) => true,
9+
(Some(left), Some(right)) => {
10+
// it was (inner1.is_ipv4() && inner2.is_ipv6()) || (inner2.is_ipv4() && inner1.is_ipv6())
11+
(left.is_ipv4() && right.is_ipv4()) || (left.is_ipv6() && right.is_ipv6())
12+
}
1013
}
1114
}
1215

13-
pub(crate) fn is_ipv6(src: Option<std::net::SocketAddr>, dst: Option<std::net::SocketAddr>) -> bool {
14-
if src.is_some() {
15-
src.unwrap().is_ipv6()
16-
} else {
17-
dst.unwrap().is_ipv6()
16+
pub(crate) fn is_ipv6(
17+
src: Option<std::net::SocketAddr>,
18+
dst: Option<std::net::SocketAddr>,
19+
) -> bool {
20+
match (src, dst) {
21+
(Some(addr), None) => addr.is_ipv6(),
22+
(None, Some(addr)) => addr.is_ipv6(),
23+
(Some(left), Some(right)) => left.is_ipv6() || right.is_ipv6(),
24+
_ => false,
1825
}
19-
}
26+
}

0 commit comments

Comments
 (0)