Skip to content
Merged
123 changes: 123 additions & 0 deletions client/src/callback_based.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//! This module implements support for callback-based gRPC service that has a callback invoked for
//! every gRPC call instead of directly using the network.

use anyhow::anyhow;
use bytes::{BufMut, BytesMut};
use futures_util::future::BoxFuture;
use futures_util::stream;
use http::{HeaderMap, Request, Response};
use http_body_util::{BodyExt, StreamBody, combinators::BoxBody};
use hyper::body::{Bytes, Frame};
use std::{
sync::Arc,
task::{Context, Poll},
};
use tonic::{Status, metadata::GRPC_CONTENT_TYPE};
use tower::Service;

/// gRPC request for use by a callback.
pub struct GrpcRequest {
/// Fully qualified gRPC service name.
pub service: String,
/// RPC name.
pub rpc: String,
/// Request headers.
pub headers: HeaderMap,
/// Protobuf bytes of the request.
pub proto: Bytes,
}

/// Successful gRPC response returned by a callback.
pub struct GrpcSuccessResponse {
/// Response headers.
pub headers: HeaderMap,

/// Response proto bytes.
pub proto: Vec<u8>,
}

/// gRPC service that invokes the given callback on each call.
#[derive(Clone)]
pub struct CallbackBasedGrpcService {
/// Callback to invoke on each RPC call.
#[allow(clippy::type_complexity)] // Signature is not that complex
pub callback: Arc<
dyn Fn(GrpcRequest) -> BoxFuture<'static, Result<GrpcSuccessResponse, Status>>
+ Send
+ Sync,
>,
}

impl Service<Request<tonic::body::Body>> for CallbackBasedGrpcService {
type Response = http::Response<tonic::body::Body>;
type Error = anyhow::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
let callback = self.callback.clone();

Box::pin(async move {
// Build req
let (parts, body) = req.into_parts();
let mut path_parts = parts.uri.path().trim_start_matches('/').split('/');
let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes();
// Body is flag saying whether compressed (we do not support that), then 32-bit length,
// then the actual proto.
if req_body.len() < 5 {
return Err(anyhow!("Too few request bytes: {}", req_body.len()));
} else if req_body[0] != 0 {
return Err(anyhow!("Compression not supported"));
}
let req_proto_len =
u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize;
if req_body.len() < 5 + req_proto_len {
return Err(anyhow!(
"Expected request body length at least {}, got {}",
5 + req_proto_len,
req_body.len()
));
}
let req = GrpcRequest {
service: path_parts.next().unwrap_or_default().to_owned(),
rpc: path_parts.next().unwrap_or_default().to_owned(),
headers: parts.headers,
proto: req_body.slice(5..5 + req_proto_len),
};

// Invoke and handle response
match (callback)(req).await {
Ok(success) => {
// Create body bytes which requires a flag saying whether compressed, then
// message len, then actual message. So we create a Bytes for those 5 prepend
// parts, then stream it alongside the user-provided Vec. This allows us to
// avoid copying the vec
let mut body_prepend = BytesMut::with_capacity(5);
body_prepend.put_u8(0); // 0 means no compression
body_prepend.put_u32(success.proto.len() as u32);
let stream = stream::iter(vec![
Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))),
Ok::<_, Status>(Frame::data(Bytes::from(success.proto))),
]);
let stream_body = StreamBody::new(stream);
let full_body = BoxBody::new(stream_body).boxed();

// Build response appending headers
let mut resp_builder = Response::builder()
.status(200)
.header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
for (key, value) in success.headers.iter() {
resp_builder = resp_builder.header(key, value);
}
Ok(resp_builder
.body(tonic::body::Body::new(full_body))
.map_err(|e| anyhow!(e))?)
}
Err(status) => Ok(status.into_http()),
}
})
}
}
80 changes: 53 additions & 27 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#[macro_use]
extern crate tracing;

pub mod callback_based;
mod metrics;
mod proxy;
mod raw;
Expand Down Expand Up @@ -35,7 +36,7 @@ pub use workflow_handle::{
};

use crate::{
metrics::{GrpcMetricSvc, MetricsContext},
metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
raw::{AttachMetricLabels, sealed::RawClientLike},
sealed::WfHandleClient,
workflow_handle::UntypedWorkflowHandle,
Expand Down Expand Up @@ -432,34 +433,59 @@ impl ClientOptions {
metrics_meter: Option<TemporalMeter>,
) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
{
let channel = Channel::from_shared(self.target_url.to_string())?;
let channel = self.add_tls_to_channel(channel).await?;
let channel = if let Some(keep_alive) = self.keep_alive.as_ref() {
channel
.keep_alive_while_idle(true)
.http2_keep_alive_interval(keep_alive.interval)
.keep_alive_timeout(keep_alive.timeout)
} else {
channel
};
let channel = if let Some(origin) = self.override_origin.clone() {
channel.origin(origin)
} else {
channel
};
// If there is a proxy, we have to connect that way
let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
proxy.connect_endpoint(&channel).await?
} else {
channel.connect().await?
};
let service = ServiceBuilder::new()
.layer_fn(move |channel| GrpcMetricSvc {
inner: channel,
self.connect_no_namespace_with_service_override(metrics_meter, None)
.await
}

/// Attempt to establish a connection to the Temporal server and return a gRPC client which is
/// intercepted with retry, default headers functionality, and metrics if provided. If a
/// service_override is present, network-specific options are ignored and the callback is
/// invoked for each gRPC call.
///
/// See [RetryClient] for more
pub async fn connect_no_namespace_with_service_override(
&self,
metrics_meter: Option<TemporalMeter>,
service_override: Option<callback_based::CallbackBasedGrpcService>,
) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
{
let service = if let Some(service_override) = service_override {
GrpcMetricSvc {
inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
metrics: metrics_meter.clone().map(MetricsContext::new),
disable_errcode_label: self.disable_error_code_metric_tags,
})
.service(channel);
}
} else {
let channel = Channel::from_shared(self.target_url.to_string())?;
let channel = self.add_tls_to_channel(channel).await?;
let channel = if let Some(keep_alive) = self.keep_alive.as_ref() {
channel
.keep_alive_while_idle(true)
.http2_keep_alive_interval(keep_alive.interval)
.keep_alive_timeout(keep_alive.timeout)
} else {
channel
};
let channel = if let Some(origin) = self.override_origin.clone() {
channel.origin(origin)
} else {
channel
};
// If there is a proxy, we have to connect that way
let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
proxy.connect_endpoint(&channel).await?
} else {
channel.connect().await?
};
ServiceBuilder::new()
.layer_fn(move |channel| GrpcMetricSvc {
inner: ChannelOrGrpcOverride::Channel(channel),
metrics: metrics_meter.clone().map(MetricsContext::new),
disable_errcode_label: self.disable_error_code_metric_tags,
})
.service(channel)
};

let headers = Arc::new(RwLock::new(ClientHeaders {
user_headers: self.headers.clone().unwrap_or_default(),
api_key: self.api_key.clone(),
Expand Down
38 changes: 33 additions & 5 deletions client/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{AttachMetricLabels, CallType, dbg_panic};
use crate::{AttachMetricLabels, CallType, callback_based, dbg_panic};
use futures_util::TryFutureExt;
use futures_util::future::Either;
use futures_util::{FutureExt, future::BoxFuture};
use std::{
fmt,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
Expand Down Expand Up @@ -205,19 +208,37 @@ fn code_as_screaming_snake(code: &Code) -> &'static str {
/// Implements metrics functionality for gRPC (really, any http) calls
#[derive(Debug, Clone)]
pub struct GrpcMetricSvc {
pub(crate) inner: Channel,
pub(crate) inner: ChannelOrGrpcOverride,
// If set to none, metrics are a no-op
pub(crate) metrics: Option<MetricsContext>,
pub(crate) disable_errcode_label: bool,
}

#[derive(Clone)]
pub(crate) enum ChannelOrGrpcOverride {
Channel(Channel),
GrpcOverride(callback_based::CallbackBasedGrpcService),
}

impl fmt::Debug for ChannelOrGrpcOverride {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ChannelOrGrpcOverride::Channel(inner) => fmt::Debug::fmt(inner, f),
ChannelOrGrpcOverride::GrpcOverride(_) => f.write_str("<callback-based-grpc-service>"),
}
}
}

impl Service<http::Request<Body>> for GrpcMetricSvc {
type Response = http::Response<Body>;
type Error = tonic::transport::Error;
type Error = Box<dyn std::error::Error + Send + Sync>;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not believe changing this will cause any issues or serious performance concerns, but would like to have it double checked

type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
match &mut self.inner {
ChannelOrGrpcOverride::Channel(inner) => inner.poll_ready(cx).map_err(Into::into),
ChannelOrGrpcOverride::GrpcOverride(inner) => inner.poll_ready(cx).map_err(Into::into),
}
}

fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
Expand Down Expand Up @@ -245,7 +266,14 @@ impl Service<http::Request<Body>> for GrpcMetricSvc {
metrics
})
});
let callfut = self.inner.call(req);
let callfut = match &mut self.inner {
ChannelOrGrpcOverride::Channel(inner) => {
Either::Left(inner.call(req).map_err(Into::into))
}
ChannelOrGrpcOverride::GrpcOverride(inner) => {
Either::Right(inner.call(req).map_err(Into::into))
}
};
let errcode_label_disabled = self.disable_errcode_label;
async move {
let started = Instant::now();
Expand Down
2 changes: 2 additions & 0 deletions core-c-bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ crate-type = ["cdylib"]
[dependencies]
anyhow = "1.0"
async-trait = "0.1"
futures-util = { version = "0.3", default-features = false }
http = "1.1"
libc = "0.2"
prost = { workspace = true }
# We rely on Cargo semver rules not updating a 0.x to 0.y. Per the rand
Expand Down
Loading
Loading