Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions lib/src/bolt/request/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl Serialize for Routing {
Routing::No => serializer.serialize_none(),
Routing::Yes(routing) => {
let mut map = serializer.serialize_map(Some(routing.len()))?;
for (k, v) in routing {
for (k, v) in routing.iter() {
map.serialize_entry(k.value.as_str(), v.value.as_str())?;
}
map.end()
Expand Down Expand Up @@ -76,7 +76,7 @@ mod tests {
#[test]
fn serialize() {
let route = RouteBuilder::new(
Routing::Yes(vec![("address".into(), "localhost:7687".into())]),
Routing::Yes([("address".into(), "localhost:7687".into())].into()),
vec!["bookmark".into()],
)
.with_db(Database::from("neo4j"))
Expand All @@ -99,7 +99,7 @@ mod tests {
#[test]
fn serialize_no_db() {
let builder = RouteBuilder::new(
Routing::Yes(vec![("address".into(), "localhost:7687".into())]),
Routing::Yes([("address".into(), "localhost:7687".into())].into()),
vec!["bookmark".into()],
);
let route = builder.build(Version::V4_3);
Expand All @@ -121,7 +121,7 @@ mod tests {
#[test]
fn serialize_no_db_v4_4() {
let builder = RouteBuilder::new(
Routing::Yes(vec![("address".into(), "localhost:7687".into())]),
Routing::Yes([("address".into(), "localhost:7687".into())].into()),
vec!["bookmark".into()],
);
let route = builder.build(Version::V4_4);
Expand All @@ -147,7 +147,7 @@ mod tests {
#[test]
fn serialize_with_db_v4_4() {
let builder = RouteBuilder::new(
Routing::Yes(vec![("address".into(), "localhost:7687".into())]),
Routing::Yes([("address".into(), "localhost:7687".into())].into()),
vec!["bookmark".into()],
);
let route = builder
Expand Down
135 changes: 87 additions & 48 deletions lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ pub struct Connection {

impl Connection {
pub(crate) async fn new(info: &ConnectionInfo) -> Result<Self> {
let mut connection = Self::prepare(info).await?;
let hello = info.to_hello(connection.version);
let mut connection = Self::prepare(&info.prepare).await?;
let hello = info.init.to_hello(connection.version);
connection.hello(hello).await?;
Ok(connection)
}
Expand All @@ -63,14 +63,14 @@ impl Connection {
self.version
}

pub(crate) async fn prepare(info: &ConnectionInfo) -> Result<Self> {
let mut stream = match &info.host {
Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?,
Host::Ipv4(ip) => TcpStream::connect((*ip, info.port)).await?,
Host::Ipv6(ip) => TcpStream::connect((*ip, info.port)).await?,
pub(crate) async fn prepare(opts: &PrepareOpts) -> Result<Self> {
let mut stream = match &opts.host {
Host::Domain(domain) => TcpStream::connect((&**domain, opts.port)).await?,
Host::Ipv4(ip) => TcpStream::connect((*ip, opts.port)).await?,
Host::Ipv6(ip) => TcpStream::connect((*ip, opts.port)).await?,
};

Ok(match &info.encryption {
Ok(match &opts.encryption {
Some((connector, domain)) => {
let mut stream = connector.connect(domain.clone(), stream).await?;
let version = Self::init(&mut stream).await?;
Expand Down Expand Up @@ -269,7 +269,7 @@ impl Connection {
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Routing {
No,
Yes(Vec<(BoltString, BoltString)>),
Yes(Arc<[(BoltString, BoltString)]>),
}

impl From<Routing> for Option<BoltMap> {
Expand All @@ -278,8 +278,8 @@ impl From<Routing> for Option<BoltMap> {
Routing::No => None,
Routing::Yes(routing) => Some(
routing
.into_iter()
.map(|(k, v)| (k, BoltType::String(v)))
.iter()
.map(|(k, v)| (k.clone(), BoltType::String(v.clone())))
.collect(),
),
}
Expand All @@ -302,24 +302,78 @@ impl Display for Routing {
}
}

#[derive(Clone)]
pub(crate) struct PrepareOpts {
pub(crate) host: Host<Arc<str>>,
pub(crate) port: u16,
pub(crate) encryption: Option<(TlsConnector, ServerName<'static>)>,
}

impl Debug for PrepareOpts {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PrepareOpts")
.field("host", &self.host)
.field("port", &self.port)
.field("encryption", &self.encryption.is_some())
.finish()
}
}

#[derive(Clone)]
pub(crate) struct InitOpts {
pub(crate) user: Arc<str>,
pub(crate) password: Arc<str>,
pub(crate) routing: Routing,
}

impl Debug for InitOpts {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InitOpts")
.field("user", &self.user)
.field("password", &"***")
.field("routing", &self.routing)
.finish()
}
}

impl InitOpts {
#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))]
pub(crate) fn to_hello(&self, version: Version) -> BoltRequest {
HelloBuilder::new(&*self.user, &*self.password)
.with_routing(self.routing.clone())
.build(version)
}

#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
pub(crate) fn to_hello(&self, version: Version) -> Hello {
match self.routing {
Routing::No => HelloBuilder::new(&self.user, &self.password).build(version),
Routing::Yes(ref routing) => HelloBuilder::new(&self.user, &self.password)
.with_routing(
routing
.iter()
.map(|(k, v)| (k.value.as_str(), v.value.as_str())),
)
.build(version),
}
}
}

#[derive(Clone)]
pub(crate) struct ConnectionInfo {
pub user: Arc<str>,
pub password: Arc<str>,
pub host: Host<Arc<str>>,
pub port: u16,
pub routing: Routing,
pub encryption: Option<(TlsConnector, ServerName<'static>)>,
pub(crate) prepare: PrepareOpts,
pub(crate) init: InitOpts,
}

impl Debug for ConnectionInfo {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnectionInfo")
.field("user", &self.user)
.field("user", &self.init.user)
.field("password", &"***")
.field("host", &self.host)
.field("port", &self.port)
.field("routing", &self.routing)
.field("encryption", &self.encryption.is_some())
.field("host", &self.prepare.host)
.field("port", &self.prepare.port)
.field("routing", &self.init.routing)
.field("encryption", &self.prepare.encryption.is_some())
.finish_non_exhaustive()
}
}
Expand Down Expand Up @@ -360,7 +414,8 @@ impl ConnectionInfo {
"Client-side routing is in experimental mode.",
"It is possible that operations against a cluster (such as Aura) will fail."
));
Routing::Yes(url.routing_context())
let context = url.routing_context();
Routing::Yes(context.into())
} else {
Routing::No
};
Expand All @@ -373,14 +428,19 @@ impl ConnectionInfo {
Host::Ipv6(d) => Host::Ipv6(d),
};

Ok(Self {
user: user.into(),
password: password.into(),
let prepare = PrepareOpts {
host,
port: url.port(),
encryption,
};

let init = InitOpts {
user: user.into(),
password: password.into(),
routing,
})
};

Ok(Self { prepare, init })
}

fn tls_connector(
Expand Down Expand Up @@ -432,27 +492,6 @@ impl ConnectionInfo {

Ok((connector, domain))
}

#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))]
pub(crate) fn to_hello(&self, version: Version) -> BoltRequest {
HelloBuilder::new(&*self.user, &*self.password)
.with_routing(self.routing.clone())
.build(version)
}

#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
pub(crate) fn to_hello(&self, version: Version) -> Hello {
match self.routing {
Routing::No => HelloBuilder::new(&self.user, &self.password).build(version),
Routing::Yes(ref routing) => HelloBuilder::new(&self.user, &self.password)
.with_routing(
routing
.iter()
.map(|(k, v)| (k.value.as_str(), v.value.as_str())),
)
.build(version),
}
}
}

#[derive(Clone, Debug)]
Expand Down
5 changes: 3 additions & 2 deletions lib/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use {
crate::routing::{ClusterRoutingTableProvider, RoutedConnectionManager},
crate::summary::ResultSummary,
log::debug,
std::sync::Arc,
};

use crate::graph::ConnectionPoolManager::Direct;
Expand Down Expand Up @@ -75,11 +76,11 @@ impl Graph {
&config.password,
&config.tls_config,
)?;
if matches!(info.routing, Routing::Yes(_)) {
if matches!(info.init.routing, Routing::Yes(_)) {
debug!("Routing enabled, creating a routed connection manager");
let pool = Routed(RoutedConnectionManager::new(
&config,
Box::new(ClusterRoutingTableProvider),
Arc::new(ClusterRoutingTableProvider),
)?);
Ok(Graph {
config: config.into_live_config(),
Expand Down
8 changes: 3 additions & 5 deletions lib/src/routing/connection_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Default for ConnectionRegistry {
async fn refresh_routing_table(
config: Config,
registry: Arc<ConnectionRegistry>,
provider: Arc<Box<dyn RoutingTableProvider>>,
provider: Arc<dyn RoutingTableProvider>,
bookmarks: &[String],
) -> Result<u64, Error> {
debug!("Routing table expired or empty, refreshing...");
Expand Down Expand Up @@ -109,7 +109,7 @@ async fn refresh_routing_table(
pub(crate) fn start_background_updater(
config: &Config,
registry: Arc<ConnectionRegistry>,
provider: Arc<Box<dyn RoutingTableProvider>>,
provider: Arc<dyn RoutingTableProvider>,
) -> Sender<RegistryCommand> {
let config_clone = config.clone();
let (tx, mut rx) = mpsc::channel(1);
Expand Down Expand Up @@ -266,9 +266,7 @@ mod tests {
let ttl = refresh_routing_table(
config.clone(),
registry.clone(),
Arc::new(Box::new(TestRoutingTableProvider::new(
cluster_routing_table,
))),
Arc::new(TestRoutingTableProvider::new(cluster_routing_table)),
&[],
)
.await
Expand Down
8 changes: 3 additions & 5 deletions lib/src/routing/routed_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use crate::{Config, Error, Operation};
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use futures::lock::Mutex;
use log::{debug, error};
use std::sync::Arc;
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::Sender;

#[derive(Clone)]
Expand All @@ -24,7 +23,7 @@ pub struct RoutedConnectionManager {
}

impl RoutedConnectionManager {
pub fn new(config: &Config, provider: Box<dyn RoutingTableProvider>) -> Result<Self, Error> {
pub fn new(config: &Config, provider: Arc<dyn RoutingTableProvider>) -> Result<Self, Error> {
let backoff = Arc::new(
ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(1))
Expand All @@ -35,8 +34,7 @@ impl RoutedConnectionManager {
);

let connection_registry = Arc::new(ConnectionRegistry::default());
let channel =
start_background_updater(config, connection_registry.clone(), provider.into());
let channel = start_background_updater(config, connection_registry.clone(), provider);
Ok(RoutedConnectionManager {
load_balancing_strategy: Arc::new(RoundRobinStrategy::default()),
bookmarks: Arc::new(Mutex::new(vec![])),
Expand Down
2 changes: 1 addition & 1 deletion lib/src/routing/routing_table_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl RoutingTableProvider for ClusterRoutingTableProvider {
&config.tls_config,
)?;
let mut connection = Connection::new(&info).await?;
let mut builder = RouteBuilder::new(info.routing, bookmarks);
let mut builder = RouteBuilder::new(info.init.routing, bookmarks);
if let Some(db) = config.db.clone() {
builder = builder.with_db(db);
}
Expand Down
8 changes: 4 additions & 4 deletions lib/src/types/serde/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ mod tests {
#[test]
fn tuple_struct_from_map_fails() {
// We do not support this since maps are unordered and
// we cannot gurantee that the values are in the same
// we cannot guarantee that the values are in the same
// order as the tuple struct fields.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize)]
struct Person(String, u8);
Expand Down Expand Up @@ -2063,9 +2063,9 @@ mod tests {
let bolt = BoltLocalTime::from(time);
let bolt = BoltType::LocalTime(bolt);

let acutal = bolt.to::<(NaiveTime, Option<Offset>)>().unwrap();
assert_eq!(acutal.0, time);
assert_eq!(acutal.1, None);
let actual = bolt.to::<(NaiveTime, Option<Offset>)>().unwrap();
assert_eq!(actual.0, time);
assert_eq!(actual.1, None);
}

#[test]
Expand Down