Skip to content

Commit 2c36552

Browse files
committed
FIxed some more weird https issues
1 parent 3b2a04f commit 2c36552

1 file changed

Lines changed: 37 additions & 13 deletions

File tree

src/battle.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use nix::sys::socket::{setsockopt, sockopt::TcpNoDelay};
1414
use rustls::{ClientConfig, ClientConnection, ServerName};
1515

1616
use rustls::client::{ServerCertVerified, ServerCertVerifier};
17+
use webpki_roots::TLS_SERVER_ROOTS;
1718

1819
// Helper function to set TCP_NODELAY that works on both Unix and Windows
1920
fn set_tcp_nodelay(stream: &TcpStream) -> io::Result<()> {
@@ -40,7 +41,7 @@ pub struct RequestConfig {
4041
pub port: u16,
4142
pub headers: Vec<(String, String)>,
4243
pub body: Option<String>,
43-
pub is_ht: bool,
44+
pub is_https: bool,
4445
}
4546

4647
impl RequestConfig {
@@ -49,7 +50,7 @@ impl RequestConfig {
4950
let host = uri.host_str().unwrap_or("127.0.0.1").to_string();
5051
let port = uri.port_or_known_default().unwrap_or(80);
5152
let path = if uri.path().is_empty() { "/" } else { uri.path() };
52-
let is_ht = uri.scheme() == "ht";
53+
let is_https = uri.scheme() == "https";
5354

5455
// Parse custom headers
5556
let mut parsed_headers = Vec::new();
@@ -79,7 +80,7 @@ impl RequestConfig {
7980
port,
8081
headers: parsed_headers,
8182
body: body.map(|s| s.to_string()),
82-
is_ht,
83+
is_https,
8384
})
8485
}
8586

@@ -252,7 +253,21 @@ struct Stats {
252253
status_counts: HashMap<i16, u64>,
253254
}
254255

255-
256+
fn create_tls_config() -> ClientConfig {
257+
let mut root_cert_store = rustls::RootCertStore::empty();
258+
root_cert_store.add_trust_anchors(TLS_SERVER_ROOTS.iter().map(|ta| {
259+
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
260+
ta.subject,
261+
ta.spki,
262+
ta.name_constraints,
263+
)
264+
}));
265+
266+
ClientConfig::builder()
267+
.with_safe_defaults()
268+
.with_root_certificates(root_cert_store)
269+
.with_no_client_auth()
270+
}
256271

257272
fn create_insecure_tls_config() -> ClientConfig {
258273
let root_cert_store = rustls::RootCertStore::empty();
@@ -317,20 +332,29 @@ fn resolve_address(url: &str) -> Result<SocketAddr, String> {
317332
}
318333

319334
// Helper function to create connection stream
320-
fn create_connection_stream(addr: SocketAddr, host: &str, is_ht: bool) -> Result<ConnectionStream, String> {
335+
fn create_connection_stream(addr: SocketAddr, host: &str, request_config: &RequestConfig) -> Result<ConnectionStream, String> {
321336
let stream = TcpStream::connect(addr)
322337
.map_err(|e| format!("Failed to connect: {}", e))?;
323338

324339
// Set socket options
325340
set_tcp_nodelay(&stream)
326341
.map_err(|e| format!("Failed to set TCP_NODELAY: {}", e))?;
327342

328-
if is_ht {
329-
let tls_config = create_insecure_tls_config();
343+
if request_config.is_https {
344+
let tls_config = if request_config.is_https {
345+
// Use insecure config for localhost to handle self-signed certificates
346+
if request_config.host == "127.0.0.1" || request_config.host == "localhost" {
347+
Some(Arc::new(create_insecure_tls_config()))
348+
} else {
349+
Some(Arc::new(create_tls_config()))
350+
}
351+
} else {
352+
None
353+
};
330354
let server_name = ServerName::try_from(host)
331355
.map_err(|e| format!("Invalid server name: {}", e))?;
332356

333-
let tls_conn = ClientConnection::new(Arc::new(tls_config), server_name)
357+
let tls_conn = ClientConnection::new(tls_config.unwrap(), server_name)
334358
.map_err(|e| format!("Failed to create TLS connection: {}", e))?;
335359

336360
Ok(ConnectionStream::Tls(tls_conn, stream))
@@ -408,13 +432,13 @@ pub fn health_check(params: &BattleParams) -> Result<(), String> {
408432
}
409433

410434
fn attempt_health_check(addr: &SocketAddr, request_config: &RequestConfig) -> Result<(), String> {
411-
let mut stream = create_connection_stream(*addr, &request_config.host, request_config.is_ht)
435+
let mut stream = create_connection_stream(*addr, &request_config.host, request_config)
412436
.map_err(|e| format!("Failed to create connection: {}", e))?;
413437

414438
// Small delay to let connection stabilize
415439
std::thread::sleep(std::time::Duration::from_millis(50));
416440

417-
if request_config.is_ht {
441+
if request_config.is_https {
418442
complete_tls_handshake(&mut stream)
419443
.map_err(|e| format!("TLS handshake failed: {}", e))?;
420444
}
@@ -543,7 +567,7 @@ fn run_thread(
543567

544568
let request_bytes = request_config.build_request().into_bytes();
545569

546-
let connection_stream = if request_config.is_ht {
570+
let connection_stream = if request_config.is_https {
547571
// Create TLS connection
548572
let server_name = ServerName::try_from(request_config.host.as_str())
549573
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid DNS name: {}", e)))?;
@@ -582,7 +606,7 @@ fn run_thread(
582606
let conn = connections_map.get_mut(&token).unwrap();
583607

584608
// Handle TLS handshake for HTTPS connections
585-
if request_config.is_ht && conn.stream.is_handshaking() && event.is_writable() {
609+
if request_config.is_https && conn.stream.is_handshaking() && event.is_writable() {
586610
match conn.stream.write(&[]) {
587611
Ok(_) => {
588612
if !conn.stream.is_handshaking() {
@@ -609,7 +633,7 @@ fn run_thread(
609633
}
610634
}
611635

612-
if event.is_writable() && !conn.request_sent && (!conn.stream.is_handshaking() || !request_config.is_ht) {
636+
if event.is_writable() && !conn.request_sent && (!conn.stream.is_handshaking() || !request_config.is_https) {
613637
// Send prebuilt request
614638
match conn.stream.write(&conn.request_bytes) {
615639
Ok(_) => {

0 commit comments

Comments
 (0)