@@ -14,6 +14,7 @@ use nix::sys::socket::{setsockopt, sockopt::TcpNoDelay};
1414use rustls:: { ClientConfig , ClientConnection , ServerName } ;
1515
1616use rustls:: client:: { ServerCertVerified , ServerCertVerifier } ;
17+ use webpki_roots:: TLS_SERVER_ROOTS ;
1718
1819// Helper function to set TCP_NODELAY that works on both Unix and Windows
1920fn set_tcp_nodelay ( stream : & TcpStream ) -> io:: Result < ( ) > {
@@ -40,7 +41,7 @@ pub struct RequestConfig {
4041 pub port : u16 ,
4142 pub headers : Vec < ( String , String ) > ,
4243 pub body : Option < String > ,
43- pub is_ht : bool ,
44+ pub is_https : bool ,
4445}
4546
4647impl RequestConfig {
@@ -49,7 +50,7 @@ impl RequestConfig {
4950 let host = uri. host_str ( ) . unwrap_or ( "127.0.0.1" ) . to_string ( ) ;
5051 let port = uri. port_or_known_default ( ) . unwrap_or ( 80 ) ;
5152 let path = if uri. path ( ) . is_empty ( ) { "/" } else { uri. path ( ) } ;
52- let is_ht = uri. scheme ( ) == "ht " ;
53+ let is_https = uri. scheme ( ) == "https " ;
5354
5455 // Parse custom headers
5556 let mut parsed_headers = Vec :: new ( ) ;
@@ -79,7 +80,7 @@ impl RequestConfig {
7980 port,
8081 headers : parsed_headers,
8182 body : body. map ( |s| s. to_string ( ) ) ,
82- is_ht ,
83+ is_https ,
8384 } )
8485 }
8586
@@ -252,7 +253,21 @@ struct Stats {
252253 status_counts : HashMap < i16 , u64 > ,
253254}
254255
255-
256+ fn create_tls_config ( ) -> ClientConfig {
257+ let mut root_cert_store = rustls:: RootCertStore :: empty ( ) ;
258+ root_cert_store. add_trust_anchors ( TLS_SERVER_ROOTS . iter ( ) . map ( |ta| {
259+ rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints (
260+ ta. subject ,
261+ ta. spki ,
262+ ta. name_constraints ,
263+ )
264+ } ) ) ;
265+
266+ ClientConfig :: builder ( )
267+ . with_safe_defaults ( )
268+ . with_root_certificates ( root_cert_store)
269+ . with_no_client_auth ( )
270+ }
256271
257272fn create_insecure_tls_config ( ) -> ClientConfig {
258273 let root_cert_store = rustls:: RootCertStore :: empty ( ) ;
@@ -317,20 +332,29 @@ fn resolve_address(url: &str) -> Result<SocketAddr, String> {
317332}
318333
319334// Helper function to create connection stream
320- fn create_connection_stream ( addr : SocketAddr , host : & str , is_ht : bool ) -> Result < ConnectionStream , String > {
335+ fn create_connection_stream ( addr : SocketAddr , host : & str , request_config : & RequestConfig ) -> Result < ConnectionStream , String > {
321336 let stream = TcpStream :: connect ( addr)
322337 . map_err ( |e| format ! ( "Failed to connect: {}" , e) ) ?;
323338
324339 // Set socket options
325340 set_tcp_nodelay ( & stream)
326341 . map_err ( |e| format ! ( "Failed to set TCP_NODELAY: {}" , e) ) ?;
327342
328- if is_ht {
329- let tls_config = create_insecure_tls_config ( ) ;
343+ if request_config. is_https {
344+ let tls_config = if request_config. is_https {
345+ // Use insecure config for localhost to handle self-signed certificates
346+ if request_config. host == "127.0.0.1" || request_config. host == "localhost" {
347+ Some ( Arc :: new ( create_insecure_tls_config ( ) ) )
348+ } else {
349+ Some ( Arc :: new ( create_tls_config ( ) ) )
350+ }
351+ } else {
352+ None
353+ } ;
330354 let server_name = ServerName :: try_from ( host)
331355 . map_err ( |e| format ! ( "Invalid server name: {}" , e) ) ?;
332356
333- let tls_conn = ClientConnection :: new ( Arc :: new ( tls_config) , server_name)
357+ let tls_conn = ClientConnection :: new ( tls_config. unwrap ( ) , server_name)
334358 . map_err ( |e| format ! ( "Failed to create TLS connection: {}" , e) ) ?;
335359
336360 Ok ( ConnectionStream :: Tls ( tls_conn, stream) )
@@ -408,13 +432,13 @@ pub fn health_check(params: &BattleParams) -> Result<(), String> {
408432}
409433
410434fn attempt_health_check ( addr : & SocketAddr , request_config : & RequestConfig ) -> Result < ( ) , String > {
411- let mut stream = create_connection_stream ( * addr, & request_config. host , request_config. is_ht )
435+ let mut stream = create_connection_stream ( * addr, & request_config. host , request_config)
412436 . map_err ( |e| format ! ( "Failed to create connection: {}" , e) ) ?;
413437
414438 // Small delay to let connection stabilize
415439 std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 50 ) ) ;
416440
417- if request_config. is_ht {
441+ if request_config. is_https {
418442 complete_tls_handshake ( & mut stream)
419443 . map_err ( |e| format ! ( "TLS handshake failed: {}" , e) ) ?;
420444 }
@@ -543,7 +567,7 @@ fn run_thread(
543567
544568 let request_bytes = request_config. build_request ( ) . into_bytes ( ) ;
545569
546- let connection_stream = if request_config. is_ht {
570+ let connection_stream = if request_config. is_https {
547571 // Create TLS connection
548572 let server_name = ServerName :: try_from ( request_config. host . as_str ( ) )
549573 . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidInput , format ! ( "Invalid DNS name: {}" , e) ) ) ?;
@@ -582,7 +606,7 @@ fn run_thread(
582606 let conn = connections_map. get_mut ( & token) . unwrap ( ) ;
583607
584608 // Handle TLS handshake for HTTPS connections
585- if request_config. is_ht && conn. stream . is_handshaking ( ) && event. is_writable ( ) {
609+ if request_config. is_https && conn. stream . is_handshaking ( ) && event. is_writable ( ) {
586610 match conn. stream . write ( & [ ] ) {
587611 Ok ( _) => {
588612 if !conn. stream . is_handshaking ( ) {
@@ -609,7 +633,7 @@ fn run_thread(
609633 }
610634 }
611635
612- if event. is_writable ( ) && !conn. request_sent && ( !conn. stream . is_handshaking ( ) || !request_config. is_ht ) {
636+ if event. is_writable ( ) && !conn. request_sent && ( !conn. stream . is_handshaking ( ) || !request_config. is_https ) {
613637 // Send prebuilt request
614638 match conn. stream . write ( & conn. request_bytes ) {
615639 Ok ( _) => {
0 commit comments