diff --git a/client/src/callback_based.rs b/client/src/callback_based.rs new file mode 100644 index 000000000..9fea36ed5 --- /dev/null +++ b/client/src/callback_based.rs @@ -0,0 +1,123 @@ +//! This module implements support for callback-based gRPC service that has a callback invoked for +//! every gRPC call instead of directly using the network. + +use anyhow::anyhow; +use bytes::{BufMut, BytesMut}; +use futures_util::future::BoxFuture; +use futures_util::stream; +use http::{HeaderMap, Request, Response}; +use http_body_util::{BodyExt, StreamBody, combinators::BoxBody}; +use hyper::body::{Bytes, Frame}; +use std::{ + sync::Arc, + task::{Context, Poll}, +}; +use tonic::{Status, metadata::GRPC_CONTENT_TYPE}; +use tower::Service; + +/// gRPC request for use by a callback. +pub struct GrpcRequest { + /// Fully qualified gRPC service name. + pub service: String, + /// RPC name. + pub rpc: String, + /// Request headers. + pub headers: HeaderMap, + /// Protobuf bytes of the request. + pub proto: Bytes, +} + +/// Successful gRPC response returned by a callback. +pub struct GrpcSuccessResponse { + /// Response headers. + pub headers: HeaderMap, + + /// Response proto bytes. + pub proto: Vec, +} + +/// gRPC service that invokes the given callback on each call. +#[derive(Clone)] +pub struct CallbackBasedGrpcService { + /// Callback to invoke on each RPC call. 
+ #[allow(clippy::type_complexity)] // Signature is not that complex + pub callback: Arc< + dyn Fn(GrpcRequest) -> BoxFuture<'static, Result> + + Send + + Sync, + >, +} + +impl Service> for CallbackBasedGrpcService { + type Response = http::Response; + type Error = anyhow::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let callback = self.callback.clone(); + + Box::pin(async move { + // Build req + let (parts, body) = req.into_parts(); + let mut path_parts = parts.uri.path().trim_start_matches('/').split('/'); + let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes(); + // Body is flag saying whether compressed (we do not support that), then 32-bit length, + // then the actual proto. + if req_body.len() < 5 { + return Err(anyhow!("Too few request bytes: {}", req_body.len())); + } else if req_body[0] != 0 { + return Err(anyhow!("Compression not supported")); + } + let req_proto_len = + u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize; + if req_body.len() < 5 + req_proto_len { + return Err(anyhow!( + "Expected request body length at least {}, got {}", + 5 + req_proto_len, + req_body.len() + )); + } + let req = GrpcRequest { + service: path_parts.next().unwrap_or_default().to_owned(), + rpc: path_parts.next().unwrap_or_default().to_owned(), + headers: parts.headers, + proto: req_body.slice(5..5 + req_proto_len), + }; + + // Invoke and handle response + match (callback)(req).await { + Ok(success) => { + // Create body bytes which requires a flag saying whether compressed, then + // message len, then actual message. So we create a Bytes for those 5 prepend + // parts, then stream it alongside the user-provided Vec. 
This allows us to + // avoid copying the vec + let mut body_prepend = BytesMut::with_capacity(5); + body_prepend.put_u8(0); // 0 means no compression + body_prepend.put_u32(success.proto.len() as u32); + let stream = stream::iter(vec![ + Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))), + Ok::<_, Status>(Frame::data(Bytes::from(success.proto))), + ]); + let stream_body = StreamBody::new(stream); + let full_body = BoxBody::new(stream_body).boxed(); + + // Build response appending headers + let mut resp_builder = Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE); + for (key, value) in success.headers.iter() { + resp_builder = resp_builder.header(key, value); + } + Ok(resp_builder + .body(tonic::body::Body::new(full_body)) + .map_err(|e| anyhow!(e))?) + } + Err(status) => Ok(status.into_http()), + } + }) + } +} diff --git a/client/src/lib.rs b/client/src/lib.rs index 8b0cf6dea..2b713ffa0 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -7,6 +7,7 @@ #[macro_use] extern crate tracing; +pub mod callback_based; mod metrics; mod proxy; mod raw; @@ -35,7 +36,7 @@ pub use workflow_handle::{ }; use crate::{ - metrics::{GrpcMetricSvc, MetricsContext}, + metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext}, raw::{AttachMetricLabels, sealed::RawClientLike}, sealed::WfHandleClient, workflow_handle::UntypedWorkflowHandle, @@ -434,34 +435,59 @@ impl ClientOptions { metrics_meter: Option, ) -> Result>, ClientInitError> { - let channel = Channel::from_shared(self.target_url.to_string())?; - let channel = self.add_tls_to_channel(channel).await?; - let channel = if let Some(keep_alive) = self.keep_alive.as_ref() { - channel - .keep_alive_while_idle(true) - .http2_keep_alive_interval(keep_alive.interval) - .keep_alive_timeout(keep_alive.timeout) - } else { - channel - }; - let channel = if let Some(origin) = self.override_origin.clone() { - channel.origin(origin) - } else { - channel - }; - // If there is a proxy, we 
have to connect that way - let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() { - proxy.connect_endpoint(&channel).await? - } else { - channel.connect().await? - }; - let service = ServiceBuilder::new() - .layer_fn(move |channel| GrpcMetricSvc { - inner: channel, + self.connect_no_namespace_with_service_override(metrics_meter, None) + .await + } + + /// Attempt to establish a connection to the Temporal server and return a gRPC client which is + /// intercepted with retry, default headers functionality, and metrics if provided. If a + /// service_override is present, network-specific options are ignored and the callback is + /// invoked for each gRPC call. + /// + /// See [RetryClient] for more + pub async fn connect_no_namespace_with_service_override( + &self, + metrics_meter: Option, + service_override: Option, + ) -> Result>, ClientInitError> + { + let service = if let Some(service_override) = service_override { + GrpcMetricSvc { + inner: ChannelOrGrpcOverride::GrpcOverride(service_override), metrics: metrics_meter.clone().map(MetricsContext::new), disable_errcode_label: self.disable_error_code_metric_tags, - }) - .service(channel); + } + } else { + let channel = Channel::from_shared(self.target_url.to_string())?; + let channel = self.add_tls_to_channel(channel).await?; + let channel = if let Some(keep_alive) = self.keep_alive.as_ref() { + channel + .keep_alive_while_idle(true) + .http2_keep_alive_interval(keep_alive.interval) + .keep_alive_timeout(keep_alive.timeout) + } else { + channel + }; + let channel = if let Some(origin) = self.override_origin.clone() { + channel.origin(origin) + } else { + channel + }; + // If there is a proxy, we have to connect that way + let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() { + proxy.connect_endpoint(&channel).await? + } else { + channel.connect().await? 
+ }; + ServiceBuilder::new() + .layer_fn(move |channel| GrpcMetricSvc { + inner: ChannelOrGrpcOverride::Channel(channel), + metrics: metrics_meter.clone().map(MetricsContext::new), + disable_errcode_label: self.disable_error_code_metric_tags, + }) + .service(channel) + }; + let headers = Arc::new(RwLock::new(ClientHeaders { user_headers: self.headers.clone().unwrap_or_default(), api_key: self.api_key.clone(), diff --git a/client/src/metrics.rs b/client/src/metrics.rs index 9ca802258..fa3179011 100644 --- a/client/src/metrics.rs +++ b/client/src/metrics.rs @@ -1,6 +1,9 @@ -use crate::{AttachMetricLabels, CallType, dbg_panic}; +use crate::{AttachMetricLabels, CallType, callback_based, dbg_panic}; +use futures_util::TryFutureExt; +use futures_util::future::Either; use futures_util::{FutureExt, future::BoxFuture}; use std::{ + fmt, sync::Arc, task::{Context, Poll}, time::{Duration, Instant}, @@ -205,19 +208,37 @@ fn code_as_screaming_snake(code: &Code) -> &'static str { /// Implements metrics functionality for gRPC (really, any http) calls #[derive(Debug, Clone)] pub struct GrpcMetricSvc { - pub(crate) inner: Channel, + pub(crate) inner: ChannelOrGrpcOverride, // If set to none, metrics are a no-op pub(crate) metrics: Option, pub(crate) disable_errcode_label: bool, } +#[derive(Clone)] +pub(crate) enum ChannelOrGrpcOverride { + Channel(Channel), + GrpcOverride(callback_based::CallbackBasedGrpcService), +} + +impl fmt::Debug for ChannelOrGrpcOverride { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ChannelOrGrpcOverride::Channel(inner) => fmt::Debug::fmt(inner, f), + ChannelOrGrpcOverride::GrpcOverride(_) => f.write_str(""), + } + } +} + impl Service> for GrpcMetricSvc { type Response = http::Response; - type Error = tonic::transport::Error; + type Error = Box; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) + match &mut self.inner { + 
ChannelOrGrpcOverride::Channel(inner) => inner.poll_ready(cx).map_err(Into::into), + ChannelOrGrpcOverride::GrpcOverride(inner) => inner.poll_ready(cx).map_err(Into::into), + } } fn call(&mut self, mut req: http::Request) -> Self::Future { @@ -245,7 +266,14 @@ impl Service> for GrpcMetricSvc { metrics }) }); - let callfut = self.inner.call(req); + let callfut = match &mut self.inner { + ChannelOrGrpcOverride::Channel(inner) => { + Either::Left(inner.call(req).map_err(Into::into)) + } + ChannelOrGrpcOverride::GrpcOverride(inner) => { + Either::Right(inner.call(req).map_err(Into::into)) + } + }; let errcode_label_disabled = self.disable_errcode_label; async move { let started = Instant::now(); diff --git a/core-c-bridge/Cargo.toml b/core-c-bridge/Cargo.toml index 6c98f4702..a52c5bf3c 100644 --- a/core-c-bridge/Cargo.toml +++ b/core-c-bridge/Cargo.toml @@ -10,6 +10,8 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0" async-trait = "0.1" +futures-util = { version = "0.3", default-features = false } +http = "1.1" libc = "0.2" prost = { workspace = true } # We rely on Cargo semver rules not updating a 0.x to 0.y. Per the rand diff --git a/core-c-bridge/include/temporal-sdk-core-c-bridge.h b/core-c-bridge/include/temporal-sdk-core-c-bridge.h index 728046a12..19ce209ce 100644 --- a/core-c-bridge/include/temporal-sdk-core-c-bridge.h +++ b/core-c-bridge/include/temporal-sdk-core-c-bridge.h @@ -58,6 +58,15 @@ typedef struct TemporalCoreCancellationToken TemporalCoreCancellationToken; typedef struct TemporalCoreClient TemporalCoreClient; +/** + * Representation of gRPC request for the callback. + * + * Note, temporal_core_client_grpc_override_request_respond is effectively the "free" call for + * each request. Each request _must_ call that and the request can no longer be valid after that + * call. 
+ */ +typedef struct TemporalCoreClientGrpcOverrideRequest TemporalCoreClientGrpcOverrideRequest; + typedef struct TemporalCoreEphemeralServer TemporalCoreEphemeralServer; typedef struct TemporalCoreForwardedLog TemporalCoreForwardedLog; @@ -114,6 +123,20 @@ typedef struct TemporalCoreClientHttpConnectProxyOptions { struct TemporalCoreByteArrayRef password; } TemporalCoreClientHttpConnectProxyOptions; +/** + * Callback that is invoked for every gRPC call if set on the client options. + * + * Note, temporal_core_client_grpc_override_request_respond is effectively the "free" call for + * each request. Each request _must_ call that and the request can no longer be valid after that + * call. However, all of that work and the respond call may be done well after this callback + * returns. No data lifetime is related to the callback invocation itself. + * + * Implementers should return as soon as possible and perform the network request in the + * background. + */ +typedef void (*TemporalCoreClientGrpcOverrideCallback)(struct TemporalCoreClientGrpcOverrideRequest *request, + void *user_data); + typedef struct TemporalCoreClientOptions { struct TemporalCoreByteArrayRef target_url; struct TemporalCoreByteArrayRef client_name; @@ -125,6 +148,21 @@ typedef struct TemporalCoreClientOptions { const struct TemporalCoreClientRetryOptions *retry_options; const struct TemporalCoreClientKeepAliveOptions *keep_alive_options; const struct TemporalCoreClientHttpConnectProxyOptions *http_connect_proxy_options; + /** + * If this is set, all gRPC calls go through it and no connection is made to server. The client + * connection call usually calls this for "GetSystemInfo" before the connect is complete. See + * the callback documentation for more important information about usage and data lifetimes. + * + * When a callback is set, target_url is not used to connect, but it must be set to a valid URL + * anyways in case it is used for logging or other reasons. 
Similarly, other connect-specific + * fields like tls_options, keep_alive_options, and http_connect_proxy_options will be + * completely ignored if a callback is set. + */ + TemporalCoreClientGrpcOverrideCallback grpc_override_callback; + /** + * Optional user data passed to each callback call. + */ + void *grpc_override_callback_user_data; } TemporalCoreClientOptions; typedef struct TemporalCoreByteArray { @@ -147,6 +185,36 @@ typedef void (*TemporalCoreClientConnectCallback)(void *user_data, struct TemporalCoreClient *success, const struct TemporalCoreByteArray *fail); +/** + * Response provided to temporal_core_client_grpc_override_request_respond. All values referenced + * inside here must live until that call returns. + */ +typedef struct TemporalCoreClientGrpcOverrideResponse { + /** + * Numeric gRPC status code, see https://grpc.io/docs/guides/status-codes/. 0 is success, non-0 + * is failure. + */ + int32_t status_code; + /** + * Headers for the response if any. Note, this is meant for user-defined metadata/headers, and + * not the gRPC system headers (like :status or content-type). + */ + TemporalCoreMetadataRef headers; + /** + * Protobuf bytes for a successful response. Ignored if status_code is non-0. + */ + struct TemporalCoreByteArrayRef success_proto; + /** + * UTF-8 failure message. Ignored if status_code is 0. + */ + struct TemporalCoreByteArrayRef fail_message; + /** + * Optional details for the gRPC failure. If non-empty, this should be a protobuf-serialized + * google.rpc.Status. Ignored if status_code is 0. 
+ */ + struct TemporalCoreByteArrayRef fail_details; +} TemporalCoreClientGrpcOverrideResponse; + typedef struct TemporalCoreRpcCallOptions { enum TemporalCoreRpcService service; struct TemporalCoreByteArrayRef rpc; @@ -656,6 +724,43 @@ void temporal_core_client_update_metadata(struct TemporalCoreClient *client, void temporal_core_client_update_api_key(struct TemporalCoreClient *client, struct TemporalCoreByteArrayRef api_key); +/** + * Get a reference to the service name. + * + * Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. + */ +struct TemporalCoreByteArrayRef temporal_core_client_grpc_override_request_service(const struct TemporalCoreClientGrpcOverrideRequest *req); + +/** + * Get a reference to the RPC name. + * + * Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. + */ +struct TemporalCoreByteArrayRef temporal_core_client_grpc_override_request_rpc(const struct TemporalCoreClientGrpcOverrideRequest *req); + +/** + * Get a reference to the service headers. + * + * Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. + */ +TemporalCoreMetadataRef temporal_core_client_grpc_override_request_headers(const struct TemporalCoreClientGrpcOverrideRequest *req); + +/** + * Get a reference to the request protobuf bytes. + * + * Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. + */ +struct TemporalCoreByteArrayRef temporal_core_client_grpc_override_request_proto(const struct TemporalCoreClientGrpcOverrideRequest *req); + +/** + * Complete the request, freeing all request data. + * + * The data referenced in the response must live until this function returns. Once this call is + * made, none of the request data should be considered valid. 
+ */ +void temporal_core_client_grpc_override_request_respond(struct TemporalCoreClientGrpcOverrideRequest *req, + struct TemporalCoreClientGrpcOverrideResponse resp); + /** * Client, options, and user data must live through callback. */ diff --git a/core-c-bridge/src/client.rs b/core-c-bridge/src/client.rs index 49cbe433b..f7e6a7cef 100644 --- a/core-c-bridge/src/client.rs +++ b/core-c-bridge/src/client.rs @@ -2,13 +2,21 @@ use crate::{ ByteArray, ByteArrayRef, CancellationToken, MetadataRef, UserDataHandle, runtime::Runtime, }; -use std::{str::FromStr, time::Duration}; +use futures_util::FutureExt; +use prost::bytes::Bytes; +use std::str::FromStr; +use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::time::Duration; use temporal_client::{ ClientKeepAliveConfig, ClientOptions as CoreClientOptions, ClientOptionsBuilder, ClientTlsConfig, CloudService, ConfiguredClient, HealthService, HttpConnectProxyOptions, OperatorService, RetryClient, RetryConfig, TemporalServiceClientWithMetrics, TestService, - TlsConfig, WorkflowService, + TlsConfig, WorkflowService, callback_based, }; +use tokio::sync::oneshot; use tonic::metadata::MetadataKey; use url::Url; @@ -24,6 +32,17 @@ pub struct ClientOptions { pub retry_options: *const ClientRetryOptions, pub keep_alive_options: *const ClientKeepAliveOptions, pub http_connect_proxy_options: *const ClientHttpConnectProxyOptions, + /// If this is set, all gRPC calls go through it and no connection is made to server. The client + /// connection call usually calls this for "GetSystemInfo" before the connect is complete. See + /// the callback documentation for more important information about usage and data lifetimes. + /// + /// When a callback is set, target_url is not used to connect, but it must be set to a valid URL + /// anyways in case it is used for logging or other reasons.
Similarly, other connect-specific + /// fields like tls_options, keep_alive_options, and http_connect_proxy_options will be + /// completely ignored if a callback is set. + pub grpc_override_callback: ClientGrpcOverrideCallback, + /// Optional user data passed to each callback call. + pub grpc_override_callback_user_data: *mut libc::c_void, } #[repr(C)] @@ -102,12 +121,19 @@ pub extern "C" fn temporal_core_client_connect( return; } }; + // Create override if present + let service_override = options.grpc_override_callback.map(|cb| { + create_callback_based_grpc_service(runtime, cb, options.grpc_override_callback_user_data) + }); // Spawn async call let user_data = UserDataHandle(user_data); let core = runtime.core.clone(); runtime.core.tokio_handle().spawn(async move { match core_options - .connect_no_namespace(core.telemetry().get_temporal_metric_meter()) + .connect_no_namespace_with_service_override( + core.telemetry().get_temporal_metric_meter(), + service_override, + ) .await { Ok(core) => { @@ -132,6 +158,75 @@ pub extern "C" fn temporal_core_client_connect( }); } +fn create_callback_based_grpc_service( + runtime: &Runtime, + cb: unsafe extern "C" fn(request: *mut ClientGrpcOverrideRequest, user_data: *mut libc::c_void), + user_data: *mut libc::c_void, +) -> callback_based::CallbackBasedGrpcService { + let runtime = runtime.clone(); + let user_data = Arc::new(UserDataHandle(user_data)); + callback_based::CallbackBasedGrpcService { + callback: Arc::new(move |req| { + let runtime = runtime.clone(); + let user_data = user_data.clone(); + async move { + // Create a oneshot sender/receiver for the result + let (sender, receiver) = oneshot::channel(); + + // Create boxed request that is dropped when the caller sets the response. If the + // caller does not, this will be a memory leak. + // + // We have to cast this to a literal pointer integer because we use spawn_blocking + // and Rust can't validate things in either of two approaches.
The first approach, + // just moving the *mut to spawn_blocking closure, will not work because it is not + // send (even if you wrap it in a marked-send struct). The second approach, moving + // the box to the closure and into_raw'ing it there won't work because Rust thinks + // the "req" param to spawn_blocking may outlive this closure even though we're + // confident in our oneshot use this will never happen. + let req_ptr = Box::into_raw(Box::new(ClientGrpcOverrideRequest { + core: req, + built_headers: OnceLock::new(), + response_sender: sender, + })) as usize; + + // We want to make sure it reached user code. If spawn_blocking fails _and_ it + // didn't reach user code, it is on us to drop the box. + let reached_user_code = Arc::new(AtomicBool::new(false)); + + // Spawn the callback as blocking, failing on join failure. We use spawn_blocking + // just in case the user is doing something blocking in their closure, but we ask + // them not to. + let reached_user_code_clone = reached_user_code.clone(); + let spawn_ret = runtime + .core + .tokio_handle() + .spawn_blocking(move || unsafe { + reached_user_code_clone.store(true, Ordering::Relaxed); + cb( + req_ptr as *mut ClientGrpcOverrideRequest, + user_data.clone().0, + ); + }) + .await; + if let Err(err) = spawn_ret { + // Re-own box so it can be dropped if never reached user code + if !reached_user_code.load(Ordering::Relaxed) { + let _ = unsafe { Box::from_raw(req_ptr as *mut ClientGrpcOverrideRequest) }; + } + return Err(tonic::Status::internal(format!("{err}"))); + } + + // Wait result and return. The receiver failure in theory can never happen. If it + // does, it means somehow the sender was dropped, but our code ensures the sender + // is not dropped until a value is sent. That's why we're panicking here instead + // of turning this into a Tonic error.
+ receiver.await.expect("Unexpected receiver failure") + } + .boxed() + }), + } +} + #[unsafe(no_mangle)] pub extern "C" fn temporal_core_client_free(client: *mut Client) { unsafe { @@ -160,6 +255,164 @@ pub extern "C" fn temporal_core_client_update_api_key(client: *mut Client, api_k .set_api_key(api_key.to_option_string()); } +/// Callback that is invoked for every gRPC call if set on the client options. +/// +/// Note, temporal_core_client_grpc_override_request_respond is effectively the "free" call for +/// each request. Each request _must_ call that and the request can no longer be valid after that +/// call. However, all of that work and the respond call may be done well after this callback +/// returns. No data lifetime is related to the callback invocation itself. +/// +/// Implementers should return as soon as possible and perform the network request in the +/// background. +pub type ClientGrpcOverrideCallback = Option< + unsafe extern "C" fn(request: *mut ClientGrpcOverrideRequest, user_data: *mut libc::c_void), +>; + +/// Representation of gRPC request for the callback. +/// +/// Note, temporal_core_client_grpc_override_request_respond is effectively the "free" call for +/// each request. Each request _must_ call that and the request can no longer be valid after that +/// call. +pub struct ClientGrpcOverrideRequest { + core: callback_based::GrpcRequest, + built_headers: OnceLock<String>, + response_sender: oneshot::Sender>, +} + +// Expected to be passed to user thread. built_headers is a OnceLock so lazy init is thread-safe. +unsafe impl Send for ClientGrpcOverrideRequest {} +unsafe impl Sync for ClientGrpcOverrideRequest {} + +/// Response provided to temporal_core_client_grpc_override_request_respond. All values referenced +/// inside here must live until that call returns. +#[repr(C)] +pub struct ClientGrpcOverrideResponse { + /// Numeric gRPC status code, see https://grpc.io/docs/guides/status-codes/. 0 is success, non-0 + /// is failure. 
+ pub status_code: i32, + + /// Headers for the response if any.
Note, this is meant for user-defined metadata/headers, and + /// not the gRPC system headers (like :status or content-type). + pub headers: MetadataRef, + + /// Protobuf bytes for a successful response. Ignored if status_code is non-0. + pub success_proto: ByteArrayRef, + + /// UTF-8 failure message. Ignored if status_code is 0. + pub fail_message: ByteArrayRef, + + /// Optional details for the gRPC failure. If non-empty, this should be a protobuf-serialized + /// google.rpc.Status. Ignored if status_code is 0. + pub fail_details: ByteArrayRef, +} + +/// Get a reference to the service name. +/// +/// Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_grpc_override_request_service( + req: *const ClientGrpcOverrideRequest, +) -> ByteArrayRef { + let req = unsafe { &*req }; + req.core.service.as_str().into() +} + +/// Get a reference to the RPC name. +/// +/// Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_grpc_override_request_rpc( + req: *const ClientGrpcOverrideRequest, +) -> ByteArrayRef { + let req = unsafe { &*req }; + req.core.rpc.as_str().into() +} + +/// Get a reference to the service headers. +/// +/// Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_grpc_override_request_headers( + req: *const ClientGrpcOverrideRequest, +) -> MetadataRef { + let req = unsafe { &*req }; + // Lazily create the headers on first access + let headers = req.built_headers.get_or_init(|| { + req.core + .headers + .iter() + .filter_map(|(name, value)| value.to_str().ok().map(|val| (name.as_str(), val))) + .flat_map(|(k, v)| [k, v]) + .collect::>() + .join("\n") + }); + headers.as_str().into() +} + +/// Get a reference to the request protobuf bytes. 
+/// +/// Note, this is only valid until temporal_core_client_grpc_override_request_respond is called. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_grpc_override_request_proto( + req: *const ClientGrpcOverrideRequest, +) -> ByteArrayRef { + let req = unsafe { &*req }; + (&*req.core.proto).into() +} + +/// Complete the request, freeing all request data. +/// +/// The data referenced in the response must live until this function returns. Once this call is +/// made, none of the request data should be considered valid. +#[unsafe(no_mangle)] +pub extern "C" fn temporal_core_client_grpc_override_request_respond( + req: *mut ClientGrpcOverrideRequest, + resp: ClientGrpcOverrideResponse, +) { + // This will be dropped at the end of this call + let req = unsafe { Box::from_raw(req) }; + // Ignore failure if receiver no longer around (e.g. maybe a cancellation) + let _ = req + .response_sender + .send(resp.build_grpc_override_response()); +} + +impl ClientGrpcOverrideResponse { + #[allow(clippy::result_large_err)] // Tonic status, even though big, is reasonable as an Err + fn build_grpc_override_response( + self, + ) -> Result { + let headers = Self::client_headers_from_metadata_ref(self.headers) + .map_err(tonic::Status::internal)?; + if self.status_code == 0 { + Ok(callback_based::GrpcSuccessResponse { + headers, + proto: self.success_proto.to_vec(), + }) + } else { + Err(tonic::Status::with_details_and_metadata( + tonic::Code::from_i32(self.status_code), + self.fail_message.to_string(), + Bytes::copy_from_slice(self.fail_details.to_slice()), + tonic::metadata::MetadataMap::from_headers(headers), + )) + } + } + + fn client_headers_from_metadata_ref(headers: MetadataRef) -> Result { + let key_values = headers.to_str_map_on_newlines(); + let mut header_map = http::HeaderMap::with_capacity(key_values.len()); + for (k, v) in key_values.into_iter() { + let name = http::HeaderName::try_from(k) + .map_err(|e| format!("Invalid header name '{k}': {e}"))?; + let 
value = http::HeaderValue::from_str(v) + .map_err(|e| format!("Invalid header value '{v}': {e}"))?; + header_map.insert(name, value); + } + Ok(header_map) + } +} + #[repr(C)] pub struct RpcCallOptions { pub service: RpcService, diff --git a/core-c-bridge/src/tests/context.rs b/core-c-bridge/src/tests/context.rs index c011a3f6e..95b223d58 100644 --- a/core-c-bridge/src/tests/context.rs +++ b/core-c-bridge/src/tests/context.rs @@ -277,6 +277,15 @@ impl Context { } pub fn client_connect(self: &Arc, options: Box) -> anyhow::Result<()> { + Self::client_connect_with_override(self, options, None, std::ptr::null_mut()) + } + + pub fn client_connect_with_override( + self: &Arc, + options: Box, + grpc_override_callback: crate::client::ClientGrpcOverrideCallback, + grpc_override_callback_user_data: *mut libc::c_void, + ) -> anyhow::Result<()> { let metadata = options .headers .as_ref() @@ -339,6 +348,8 @@ impl Context { retry_options: &*retry_options, keep_alive_options: pointer_or_null(keep_alive_options.as_deref()), http_connect_proxy_options: pointer_or_null(proxy_options.as_deref()), + grpc_override_callback, + grpc_override_callback_user_data, }); let client_options_ptr = &*client_options as *const _; diff --git a/core-c-bridge/src/tests/mod.rs b/core-c-bridge/src/tests/mod.rs index 0d73cbda4..8f1c3ceb9 100644 --- a/core-c-bridge/src/tests/mod.rs +++ b/core-c-bridge/src/tests/mod.rs @@ -1,12 +1,22 @@ -use crate::client::RpcService; +use crate::ByteArrayRef; +use crate::client::{ + ClientGrpcOverrideRequest, ClientGrpcOverrideResponse, RpcService, + temporal_core_client_grpc_override_request_headers, + temporal_core_client_grpc_override_request_proto, + temporal_core_client_grpc_override_request_respond, + temporal_core_client_grpc_override_request_rpc, + temporal_core_client_grpc_override_request_service, +}; use crate::tests::utils::{ OwnedRpcCallOptions, RpcCallError, default_client_options, default_server_config, }; use context::Context; use prost::Message; -use 
std::sync::Arc; +use std::sync::{Arc, LazyLock, Mutex}; +use temporal_sdk_core_protos::temporal::api::failure::v1::Failure; use temporal_sdk_core_protos::temporal::api::workflowservice::v1::{ - GetSystemInfoRequest, GetSystemInfoResponse, + GetSystemInfoRequest, GetSystemInfoResponse, QueryWorkflowRequest, + StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, }; mod context; @@ -176,3 +186,169 @@ fn test_all_rpc_calls_exist() { )); }); } + +static CALLBACK_OVERRIDE_CALLS: LazyLock>> = + LazyLock::new(|| Mutex::new(Vec::new())); + +struct ClientOverrideError { + message: String, + details: Option>, +} + +impl ClientOverrideError { + pub fn new(message: String) -> Self { + Self { + message, + details: None, + } + } +} + +unsafe extern "C" fn callback_override( + req: *mut ClientGrpcOverrideRequest, + user_data: *mut libc::c_void, +) { + let mut calls = CALLBACK_OVERRIDE_CALLS.lock().unwrap(); + + // Simple header check to confirm headers are working + let headers = + temporal_core_client_grpc_override_request_headers(req).to_string_map_on_newlines(); + assert!(headers.get("content-type").unwrap().as_str() == "application/grpc"); + + // Confirm user data is as we expect + let user_data: &String = unsafe { &*(user_data as *const String) }; + assert!(user_data.as_str() == "some-user-data"); + + calls.push(format!( + "service: {}, rpc: {}", + temporal_core_client_grpc_override_request_service(req).to_string(), + temporal_core_client_grpc_override_request_rpc(req).to_string() + )); + let resp_raw = match temporal_core_client_grpc_override_request_rpc(req).to_str() { + "GetSystemInfo" => Ok(GetSystemInfoResponse::default().encode_to_vec()), + "StartWorkflowExecution" => match StartWorkflowExecutionRequest::decode( + temporal_core_client_grpc_override_request_proto(req).to_slice(), + ) { + Ok(req) => Ok(StartWorkflowExecutionResponse { + run_id: format!("run-id for {}", req.workflow_id), + ..Default::default() + } + .encode_to_vec()), + Err(err) => 
Err(ClientOverrideError::new(format!("Bad bytes: {err}"))), + }, + "QueryWorkflow" => match QueryWorkflowRequest::decode( + temporal_core_client_grpc_override_request_proto(req).to_slice(), + ) { + // Demonstrate fail details + Ok(_) => Err(ClientOverrideError { + message: "query-fail".to_string(), + details: Some( + Failure { + message: "intentional failure".to_string(), + ..Default::default() + } + .encode_to_vec(), + ), + }), + Err(err) => Err(ClientOverrideError::new(format!("Bad bytes: {err}"))), + }, + v => Err(ClientOverrideError::new(format!("Unknown RPC: {v}"))), + }; + // It is very important that we borrow not move resp_raw here. If this were a "match resp_raw" + // without the &, ownership of bytes moves into the match arm and can drop bytes after the + // match arm. This is why we have a "drop" below for resp_raw to ensure a developer doesn't + // accidentally move it. + let resp = match &resp_raw { + Ok(bytes) => ClientGrpcOverrideResponse { + status_code: 0, + headers: ByteArrayRef::empty(), + success_proto: bytes.as_slice().into(), + fail_message: ByteArrayRef::empty(), + fail_details: ByteArrayRef::empty(), + }, + Err(err) => ClientGrpcOverrideResponse { + status_code: tonic::Code::Internal.into(), + headers: ByteArrayRef::empty(), + success_proto: ByteArrayRef::empty(), + fail_message: err.message.as_str().into(), + fail_details: if let Some(details) = &err.details { + details.as_slice().into() + } else { + ByteArrayRef::empty() + }, + }, + }; + temporal_core_client_grpc_override_request_respond(req, resp); + drop(resp_raw); +} + +#[test] +fn test_simple_callback_override() { + Context::with(|context| { + let mut user_data = "some-user-data".to_owned(); + context.runtime_new().unwrap(); + // Create client which will invoke GetSystemInfo + context + .client_connect_with_override( + Box::new(default_client_options("127.0.0.1:4567")), + Some(callback_override), + &mut user_data as *mut String as *mut libc::c_void, + ) + .unwrap(); + + // Invoke
start workflow so we can confirm complex proto in/out + let start_resp_raw = context + .rpc_call(Box::new(OwnedRpcCallOptions { + service: RpcService::Workflow, + rpc: "StartWorkflowExecution".into(), + req: StartWorkflowExecutionRequest { + workflow_id: "my-workflow-id".into(), + ..Default::default() + } + .encode_to_vec(), + retry: false, + metadata: None, + timeout_millis: 0, + cancellation_token: None, + })) + .unwrap(); + let start_resp = StartWorkflowExecutionResponse::decode(&*start_resp_raw).unwrap(); + assert!(start_resp.run_id == "run-id for my-workflow-id"); + + // Try a query where a query failure will actually be delivered as failure details. + // However, we don't currently have temporal_sdk_core_protos::google::rpc::Status in + // the crate, so we'll just use the details directly even though a proper gRPC + // implementation will only provide a google.rpc.Status proto. + let query_err = context + .rpc_call(Box::new(OwnedRpcCallOptions { + service: RpcService::Workflow, + rpc: "QueryWorkflow".into(), + req: QueryWorkflowRequest::default().encode_to_vec(), + retry: false, + metadata: None, + timeout_millis: 0, + cancellation_token: None, + })) + .unwrap_err() + .downcast::() + .unwrap(); + assert!(query_err.status_code == tonic::Code::Internal as u32); + assert!(query_err.message == "query-fail"); + assert!( + Failure::decode(query_err.details.as_ref().unwrap().as_slice()) + .unwrap() + .message + == "intentional failure" + ); + + // Confirm we got the expected calls + assert!( + *CALLBACK_OVERRIDE_CALLS.lock().unwrap() + == vec![ + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: GetSystemInfo", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: StartWorkflowExecution", + "service: temporal.api.workflowservice.v1.WorkflowService, rpc: QueryWorkflow" + ] + ); + }); +}