diff --git a/libs/mongodb-client/src/lib.rs b/libs/mongodb-client/src/lib.rs index 51282e7cccd7..71ffe989bcb1 100644 --- a/libs/mongodb-client/src/lib.rs +++ b/libs/mongodb-client/src/lib.rs @@ -137,49 +137,68 @@ impl FromStr for MongoConnectionString { let hosts: Result, Error> = hosts_section .split(',') .map(|address| { - let mut parts = address.split(':'); - - let hostname = match parts.next() { - Some(part) => { - if part.is_empty() { + let (hostname, port) = if address.starts_with('[') { + let end_bracket_idx = match address.rfind(']') { + Some(end_bracket_idx) => end_bracket_idx, + None => { return Err(ErrorKind::invalid_argument(format!( - "invalid server address: \"{address}\"; hostname cannot be empty" + "invalid server address: \"{address}\"; missing closing bracket for IPv6 address" )) .into()); } - part - } - None => { - return Err( - ErrorKind::invalid_argument(format!("invalid server address: \"{address}\"")).into(), - ); - } - }; + }; - let port = match parts.next() { - Some(part) => { - let port = u16::from_str(part).map_err(|_| { + let (host, port_str) = address.split_at(end_bracket_idx + 1); + + let port = if !port_str.is_empty() { + let port_str = port_str.strip_prefix(':').ok_or_else(|| { ErrorKind::invalid_argument(format!( - "port must be valid 16-bit unsigned integer, instead got: {part}" + "invalid server address: \"{address}\"; invalid characters after IPv6 address" )) })?; - if port == 0 { - return Err(ErrorKind::invalid_argument(format!( - "invalid server address: \"{address}\"; port must be non-zero" - )) - .into()); + Some(parse_port(port_str, address)?) + } else { + None + }; + + (host, port) + } else { + let mut parts = address.split(':'); + let hostname = match parts.next() { + Some(part) => { + if part.is_empty() { + return Err(ErrorKind::invalid_argument(format!( + "invalid server address: \"{address}\"; hostname cannot be empty" + )) + .into()); + } + part } - if parts.next().is_some() { - return Err(ErrorKind::invalid_argument(format!( - "address \"{address}\" contains more than one unescaped ':'" - )) - .into()); + None => { + return Err( + ErrorKind::invalid_argument(format!("invalid server address: \"{address}\"")).into(), + ); } + }; - Some(port) - } - None => None, + let port = match parts.next() { + Some(part) => { + let port = parse_port(part, address)?; + + if parts.next().is_some() { + return Err(ErrorKind::invalid_argument(format!( + "address \"{address}\" contains more than one unescaped ':'" + )) + .into()); + } + + Some(port) + } + None => None, + }; + + (hostname, port) }; Ok((hostname.to_lowercase(), port)) @@ -222,6 +241,23 @@ fn percent_decode(s: &str, err_message: &str) -> Result { } } +fn parse_port(port_str: &str, address: &str) -> Result { + let port = u16::from_str(port_str).map_err(|_| { + ErrorKind::invalid_argument(format!( + "port must be valid 16-bit unsigned integer, instead got: {port_str}" + )) + })?; + + if port == 0 { + return Err(ErrorKind::invalid_argument(format!( + "invalid server address: \"{address}\"; port must be non-zero" + )) + .into()); + } + + Ok(port) +} + #[cfg(test)] mod tests { use crate::MongoConnectionString; @@ -301,4 +337,38 @@ mod tests { hosts ); } + + #[test] + fn ipv6_host() { + let s = "mongodb://[::1]/test"; + let MongoConnectionString { hosts, .. } = s.parse().unwrap(); + assert_eq!(vec![(String::from("[::1]"), None)], hosts); + } + + #[test] + fn ipv6_host_and_port() { + let s = "mongodb://[::1]:27017/test"; + let MongoConnectionString { hosts, .. } = s.parse().unwrap(); + assert_eq!(vec![(String::from("[::1]"), Some(27017))], hosts); + } + + #[test] + fn multiple_hosts_including_ipv6() { + let s = "mongodb://[::1]:27017,localhost:27018/test"; + let MongoConnectionString { hosts, .. } = s.parse().unwrap(); + + assert_eq!( + vec![ + (String::from("[::1]"), Some(27017)), + (String::from("localhost"), Some(27018)) + ], + hosts + ); + } + + #[test] + fn ipv6_host_and_zero_port() { + let s = "mongodb://[::1]:0/test"; + assert!(s.parse::().is_err()); + } }