@@ -18,12 +18,122 @@ package com.squareup.wire.internal
1818import com.squareup.wire.GrpcMethod
1919import com.squareup.wire.GrpcServerStreamingCall
2020import com.squareup.wire.GrpcStreamingCall
21+ import com.squareup.wire.MessageSink
2122import com.squareup.wire.MessageSource
23+ import com.squareup.wire.WireGrpcClient
24+ import java.util.concurrent.TimeUnit
2225import kotlinx.coroutines.CoroutineScope
26+ import kotlinx.coroutines.GlobalScope
27+ import kotlinx.coroutines.channels.Channel
2328import kotlinx.coroutines.channels.ReceiveChannel
29+ import kotlinx.coroutines.channels.SendChannel
30+ import kotlinx.coroutines.launch
31+ import kotlinx.coroutines.runBlocking
32+ import okio.ForwardingTimeout
33+ import okio.IOException
2434import okio.Timeout
2535
36+ /* *
37+ * A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming
38+ * response. Using a non-duplex request body ensures the complete request (including END_STREAM) is
39+ * sent to the server before responses are read, avoiding delays on servers that wait for the
40+ * client's half-close before starting to stream responses.
41+ */
2642internal class RealGrpcServerStreamingCall <S : Any , R : Any >(
43+ private val grpcClient : WireGrpcClient ,
44+ override val method : GrpcMethod <S , R >,
45+ ) : GrpcServerStreamingCall<S, R> {
46+
47+ private var call: okhttp3.Call ? = null
48+ private var canceled = false
49+
50+ override val timeout: Timeout = ForwardingTimeout (Timeout ())
51+
52+ init {
53+ timeout.clearTimeout()
54+ timeout.clearDeadline()
55+ }
56+
57+ override var requestMetadata: Map <String , String > = mapOf ()
58+
59+ override var responseMetadata: Map <String , String >? = null
60+ internal set
61+
62+ override fun cancel () {
63+ canceled = true
64+ call?.cancel()
65+ }
66+
67+ override fun isCanceled (): Boolean = canceled || call?.isCanceled() == true
68+
69+ override fun isExecuted (): Boolean = call?.isExecuted() ? : false
70+
71+ override fun clone (): GrpcServerStreamingCall <S , R > {
72+ val result = RealGrpcServerStreamingCall (grpcClient, method)
73+ val oldTimeout = this .timeout
74+ result.timeout.also { newTimeout ->
75+ newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit .NANOSECONDS )
76+ if (oldTimeout.hasDeadline()) {
77+ newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime())
78+ } else {
79+ newTimeout.clearDeadline()
80+ }
81+ }
82+ result.requestMetadata + = this .requestMetadata
83+ return result
84+ }
85+
86+ override suspend fun executeIn (scope : CoroutineScope , request : S ): ReceiveChannel <R > {
87+ val responseChannel = Channel <R >(1 )
88+ val call = initCall(request)
89+
90+ responseChannel.invokeOnClose { cause ->
91+ if (cause != null ) {
92+ call.cancel()
93+ }
94+ }
95+
96+ call.enqueue(
97+ responseChannel.readFromResponseBodyCallback(
98+ onResponseMetadata = { this .responseMetadata = it },
99+ responseAdapter = method.responseAdapter,
100+ ),
101+ )
102+
103+ return responseChannel
104+ }
105+
106+ override fun executeBlocking (request : S ): MessageSource <R > {
107+ val call = initCall(request)
108+ val messageSource = BlockingMessageSource (
109+ onResponseMetadata = { this .responseMetadata = it },
110+ responseAdapter = method.responseAdapter,
111+ call = call,
112+ )
113+ call.enqueue(messageSource.readFromResponseBodyCallback())
114+ return messageSource
115+ }
116+
117+ private fun initCall (request : S ): okhttp3.Call {
118+ check(this .call == null ) { " already executed" }
119+ val requestBody = newRequestBody(
120+ minMessageToCompress = grpcClient.minMessageToCompress,
121+ requestAdapter = method.requestAdapter,
122+ onlyMessage = request,
123+ )
124+ val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout)
125+ this .call = result
126+ if (canceled) result.cancel()
127+ (timeout as ForwardingTimeout ).setDelegate(result.timeout())
128+ return result
129+ }
130+ }
131+
132+ /* *
133+ * Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via
134+ * [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls.
135+ */
136+ internal class GrpcStreamingCallServerStreamingAdapter <S : Any , R : Any >(
27137 private val callDelegate : GrpcStreamingCall <S , R >,
28138 override val method : GrpcMethod <S , R >,
29139) : GrpcServerStreamingCall<S, R> {
@@ -48,7 +158,7 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
48158
49159 override fun isExecuted () = callDelegate.isExecuted()
50160
51- override fun clone () = RealGrpcServerStreamingCall (callDelegate.clone(), method)
161+ override fun clone () = GrpcStreamingCallServerStreamingAdapter (callDelegate.clone(), method)
52162
53163 override suspend fun executeIn (scope : CoroutineScope , request : S ): ReceiveChannel <R > {
54164 val (sendChannel, receiveChannel) = callDelegate.executeIn(scope)
@@ -67,4 +177,125 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
67177 }
68178}
69179
70- internal fun <S : Any , R : Any > GrpcStreamingCall <S , R >.asGrpcServerStreamingCall () = RealGrpcServerStreamingCall (this , method)
180+ internal fun <S : Any , R : Any > GrpcStreamingCall <S , R >.asGrpcServerStreamingCall () = GrpcStreamingCallServerStreamingAdapter (this , method)
181+
182+ /* *
183+ * Wraps a [GrpcServerStreamingCall] as the legacy [GrpcStreamingCall] API. This is used by
184+ * generated clients when explicit streaming call types are disabled.
185+ */
186+ internal class GrpcServerStreamingCallStreamingAdapter <S : Any , R : Any >(
187+ private val callDelegate : GrpcServerStreamingCall <S , R >,
188+ override val method : GrpcMethod <S , R >,
189+ ) : GrpcStreamingCall<S, R> {
190+ private var executed = false
191+
192+ override val timeout: Timeout
193+ get() = callDelegate.timeout
194+
195+ override var requestMetadata: Map <String , String >
196+ get() = callDelegate.requestMetadata
197+ set(value) {
198+ callDelegate.requestMetadata = value
199+ }
200+
201+ override val responseMetadata: Map <String , String >?
202+ get() = callDelegate.responseMetadata
203+
204+ override fun cancel () {
205+ callDelegate.cancel()
206+ }
207+
208+ override fun isCanceled () = callDelegate.isCanceled()
209+
210+ @Suppress(" OPT_IN_USAGE" , " OVERRIDE_DEPRECATION" )
211+ override fun execute (): Pair <SendChannel <S >, ReceiveChannel<R>> = executeIn(GlobalScope )
212+
213+ override fun executeIn (scope : CoroutineScope ): Pair <SendChannel <S >, ReceiveChannel<R>> = executeWithChannels(scope)
214+
215+ @Suppress(" OPT_IN_USAGE" )
216+ override fun executeBlocking (): Pair <MessageSink <S >, MessageSource<R>> {
217+ val (requestChannel, responseChannel) = executeWithChannels(GlobalScope )
218+ return requestChannel.toMessageSink() to responseChannel.toMessageSource()
219+ }
220+
221+ override fun isExecuted () = executed || callDelegate.isExecuted()
222+
223+ override fun clone () = GrpcServerStreamingCallStreamingAdapter (callDelegate.clone(), method)
224+
225+ private fun executeWithChannels (scope : CoroutineScope ): Pair <Channel <S >, Channel<R>> {
226+ check(! executed) { " already executed" }
227+ executed = true
228+
229+ val requestChannel = Channel <S >(1 )
230+ val responseChannel = Channel <R >(1 )
231+ var delegateResponseChannel: ReceiveChannel <R >? = null
232+
233+ responseChannel.invokeOnClose { cause ->
234+ if (cause != null ) {
235+ requestChannel.cancel()
236+ delegateResponseChannel?.cancel()
237+ callDelegate.cancel()
238+ }
239+ }
240+
241+ scope.launch {
242+ try {
243+ val requestResult = requestChannel.receiveCatching()
244+ requestResult.exceptionOrNull()?.let { throw it }
245+ val request = requestResult.getOrNull()
246+ ? : throw ProtocolException (" expected 1 message but got none" )
247+ requestChannel.close()
248+ val responses = callDelegate.executeIn(scope, request)
249+ delegateResponseChannel = responses
250+ for (response in responses) {
251+ responseChannel.send(response)
252+ }
253+ responseChannel.close()
254+ } catch (e: Throwable ) {
255+ responseChannel.close(e)
256+ }
257+ }
258+
259+ return requestChannel to responseChannel
260+ }
261+ }
262+
263+ internal fun <S : Any , R : Any > GrpcServerStreamingCall <S , R >.asGrpcStreamingCall () = GrpcServerStreamingCallStreamingAdapter (this , method)
264+
265+ private fun <E : Any > Channel<E>.toMessageSource () = object : MessageSource <E > {
266+ override fun read (): E ? = runBlocking {
267+ try {
268+ val result = receiveCatching()
269+ result.exceptionOrNull()?.let { throw it }
270+ result.getOrNull()
271+ } catch (e: Throwable ) {
272+ throw e.toIOException()
273+ }
274+ }
275+
276+ override fun close () {
277+ cancel()
278+ }
279+ }
280+
281+ private fun <E : Any > Channel<E>.toMessageSink () = object : MessageSink <E > {
282+ override fun write (message : E ) {
283+ runBlocking {
284+ try {
285+ send(message)
286+ } catch (e: Throwable ) {
287+ throw e.toIOException()
288+ }
289+ }
290+ }
291+
292+ override fun cancel () {
293+ this @toMessageSink.cancel()
294+ }
295+
296+ override fun close () {
297+ this @toMessageSink.close()
298+ }
299+ }
300+
301+ private fun Throwable.toIOException () = this as ? IOException ? : IOException (this )
0 commit comments