Skip to content

Commit bb12bfc

Browse files
committed
Avoid race in duplex pipe for streaming calls
1 parent 5315939 commit bb12bfc

10 files changed

Lines changed: 373 additions & 18 deletions

File tree

wire-grpc-client/api/wire-grpc-client.api

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,14 @@ public final class com/squareup/wire/GrpcHttpUrlKt {
8181

8282
public final class com/squareup/wire/GrpcMethod {
8383
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;)V
84+
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;Z)V
85+
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;ZZ)V
86+
public synthetic fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V
8487
public final fun getPath ()Ljava/lang/String;
8588
public final fun getRequestAdapter ()Lcom/squareup/wire/ProtoAdapter;
89+
public final fun getRequestStreaming ()Z
8690
public final fun getResponseAdapter ()Lcom/squareup/wire/ProtoAdapter;
91+
public final fun getResponseStreaming ()Z
8792
}
8893

8994
public abstract interface class com/squareup/wire/GrpcServerStreamingCall {

wire-grpc-client/src/commonMain/kotlin/com/squareup/wire/GrpcMethod.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
*/
1616
package com.squareup.wire
1717

18-
class GrpcMethod<S : Any, R : Any>(
18+
import kotlin.jvm.JvmOverloads
19+
20+
class GrpcMethod<S : Any, R : Any> @JvmOverloads constructor(
1921
val path: String,
2022
val requestAdapter: ProtoAdapter<S>,
2123
val responseAdapter: ProtoAdapter<R>,
24+
val requestStreaming: Boolean = false,
25+
val responseStreaming: Boolean = false,
2226
)

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
package com.squareup.wire
1717

1818
import com.squareup.wire.internal.RealGrpcCall
19+
import com.squareup.wire.internal.RealGrpcServerStreamingCall
1920
import com.squareup.wire.internal.RealGrpcStreamingCall
2021
import com.squareup.wire.internal.asGrpcClientStreamingCall
21-
import com.squareup.wire.internal.asGrpcServerStreamingCall
22+
import com.squareup.wire.internal.asGrpcStreamingCall
2223
import java.util.concurrent.TimeUnit
2324
import kotlin.reflect.KClass
2425
import okhttp3.Call
@@ -183,9 +184,13 @@ internal class WireGrpcClient internal constructor(
183184
) : GrpcClient() {
184185
override fun <S : Any, R : Any> newCall(method: GrpcMethod<S, R>): GrpcCall<S, R> = RealGrpcCall(this, method)
185186

186-
override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcStreamingCall<S, R> = RealGrpcStreamingCall(this, method)
187+
override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcStreamingCall<S, R> = if (!method.requestStreaming && method.responseStreaming) {
188+
RealGrpcServerStreamingCall(this, method).asGrpcStreamingCall()
189+
} else {
190+
RealGrpcStreamingCall(this, method)
191+
}
187192

188193
override fun <S : Any, R : Any> newClientStreamingCall(method: GrpcMethod<S, R>): GrpcClientStreamingCall<S, R> = RealGrpcStreamingCall(this, method).asGrpcClientStreamingCall()
189194

190-
override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> = RealGrpcStreamingCall(this, method).asGrpcServerStreamingCall()
195+
override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> = RealGrpcServerStreamingCall(this, method)
191196
}

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import okio.IOException
3333
* * Complete: enqueued when the stream completes normally.
3434
*/
3535
internal class BlockingMessageSource<R : Any>(
36-
val grpcCall: RealGrpcStreamingCall<*, R>,
36+
val onResponseMetadata: (Map<String, String>) -> Unit,
3737
val responseAdapter: ProtoAdapter<R>,
3838
val call: Call,
3939
) : MessageSource<R> {
@@ -66,7 +66,7 @@ internal class BlockingMessageSource<R : Any>(
6666

6767
override fun onResponse(call: Call, response: Response) {
6868
try {
69-
grpcCall.responseMetadata = response.headers.toMap()
69+
onResponseMetadata(response.headers.toMap())
7070
response.use {
7171
response.messageSource(responseAdapter).use { reader ->
7272
while (true) {

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt

Lines changed: 233 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,122 @@ package com.squareup.wire.internal
1818
import com.squareup.wire.GrpcMethod
1919
import com.squareup.wire.GrpcServerStreamingCall
2020
import com.squareup.wire.GrpcStreamingCall
21+
import com.squareup.wire.MessageSink
2122
import com.squareup.wire.MessageSource
23+
import com.squareup.wire.WireGrpcClient
24+
import java.util.concurrent.TimeUnit
2225
import kotlinx.coroutines.CoroutineScope
26+
import kotlinx.coroutines.GlobalScope
27+
import kotlinx.coroutines.channels.Channel
2328
import kotlinx.coroutines.channels.ReceiveChannel
29+
import kotlinx.coroutines.channels.SendChannel
30+
import kotlinx.coroutines.launch
31+
import kotlinx.coroutines.runBlocking
32+
import okio.ForwardingTimeout
33+
import okio.IOException
2434
import okio.Timeout
2535

36+
/**
37+
* A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming
38+
* response. Using a non-duplex request body ensures the complete request (including END_STREAM) is
39+
* sent to the server before responses are read, avoiding delays on servers that wait for the
40+
* client's half-close before starting to stream responses.
41+
*/
2642
internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
43+
private val grpcClient: WireGrpcClient,
44+
override val method: GrpcMethod<S, R>,
45+
) : GrpcServerStreamingCall<S, R> {
46+
47+
private var call: okhttp3.Call? = null
48+
private var canceled = false
49+
50+
override val timeout: Timeout = ForwardingTimeout(Timeout())
51+
52+
init {
53+
timeout.clearTimeout()
54+
timeout.clearDeadline()
55+
}
56+
57+
override var requestMetadata: Map<String, String> = mapOf()
58+
59+
override var responseMetadata: Map<String, String>? = null
60+
internal set
61+
62+
override fun cancel() {
63+
canceled = true
64+
call?.cancel()
65+
}
66+
67+
override fun isCanceled(): Boolean = canceled || call?.isCanceled() == true
68+
69+
override fun isExecuted(): Boolean = call?.isExecuted() ?: false
70+
71+
override fun clone(): GrpcServerStreamingCall<S, R> {
72+
val result = RealGrpcServerStreamingCall(grpcClient, method)
73+
val oldTimeout = this.timeout
74+
result.timeout.also { newTimeout ->
75+
newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit.NANOSECONDS)
76+
if (oldTimeout.hasDeadline()) {
77+
newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime())
78+
} else {
79+
newTimeout.clearDeadline()
80+
}
81+
}
82+
result.requestMetadata += this.requestMetadata
83+
return result
84+
}
85+
86+
override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
87+
val responseChannel = Channel<R>(1)
88+
val call = initCall(request)
89+
90+
responseChannel.invokeOnClose { cause ->
91+
if (cause != null) {
92+
call.cancel()
93+
}
94+
}
95+
96+
call.enqueue(
97+
responseChannel.readFromResponseBodyCallback(
98+
onResponseMetadata = { this.responseMetadata = it },
99+
responseAdapter = method.responseAdapter,
100+
),
101+
)
102+
103+
return responseChannel
104+
}
105+
106+
override fun executeBlocking(request: S): MessageSource<R> {
107+
val call = initCall(request)
108+
val messageSource = BlockingMessageSource(
109+
onResponseMetadata = { this.responseMetadata = it },
110+
responseAdapter = method.responseAdapter,
111+
call = call,
112+
)
113+
call.enqueue(messageSource.readFromResponseBodyCallback())
114+
return messageSource
115+
}
116+
117+
private fun initCall(request: S): okhttp3.Call {
118+
check(this.call == null) { "already executed" }
119+
val requestBody = newRequestBody(
120+
minMessageToCompress = grpcClient.minMessageToCompress,
121+
requestAdapter = method.requestAdapter,
122+
onlyMessage = request,
123+
)
124+
val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout)
125+
this.call = result
126+
if (canceled) result.cancel()
127+
(timeout as ForwardingTimeout).setDelegate(result.timeout())
128+
return result
129+
}
130+
}
131+
132+
/**
133+
* Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via
134+
* [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls.
135+
*/
136+
internal class GrpcStreamingCallServerStreamingAdapter<S : Any, R : Any>(
27137
private val callDelegate: GrpcStreamingCall<S, R>,
28138
override val method: GrpcMethod<S, R>,
29139
) : GrpcServerStreamingCall<S, R> {
@@ -48,7 +158,7 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
48158

49159
override fun isExecuted() = callDelegate.isExecuted()
50160

51-
override fun clone() = RealGrpcServerStreamingCall(callDelegate.clone(), method)
161+
override fun clone() = GrpcStreamingCallServerStreamingAdapter(callDelegate.clone(), method)
52162

53163
override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
54164
val (sendChannel, receiveChannel) = callDelegate.executeIn(scope)
@@ -67,4 +177,125 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
67177
}
68178
}
69179

70-
internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() = RealGrpcServerStreamingCall(this, method)
180+
internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() = GrpcStreamingCallServerStreamingAdapter(this, method)
181+
182+
/**
183+
* Wraps a [GrpcServerStreamingCall] as the legacy [GrpcStreamingCall] API. This is used by
184+
* generated clients when explicit streaming call types are disabled.
185+
*/
186+
internal class GrpcServerStreamingCallStreamingAdapter<S : Any, R : Any>(
187+
private val callDelegate: GrpcServerStreamingCall<S, R>,
188+
override val method: GrpcMethod<S, R>,
189+
) : GrpcStreamingCall<S, R> {
190+
private var executed = false
191+
192+
override val timeout: Timeout
193+
get() = callDelegate.timeout
194+
195+
override var requestMetadata: Map<String, String>
196+
get() = callDelegate.requestMetadata
197+
set(value) {
198+
callDelegate.requestMetadata = value
199+
}
200+
201+
override val responseMetadata: Map<String, String>?
202+
get() = callDelegate.responseMetadata
203+
204+
override fun cancel() {
205+
callDelegate.cancel()
206+
}
207+
208+
override fun isCanceled() = callDelegate.isCanceled()
209+
210+
@Suppress("OPT_IN_USAGE", "OVERRIDE_DEPRECATION")
211+
override fun execute(): Pair<SendChannel<S>, ReceiveChannel<R>> = executeIn(GlobalScope)
212+
213+
override fun executeIn(scope: CoroutineScope): Pair<SendChannel<S>, ReceiveChannel<R>> = executeWithChannels(scope)
214+
215+
@Suppress("OPT_IN_USAGE")
216+
override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
217+
val (requestChannel, responseChannel) = executeWithChannels(GlobalScope)
218+
return requestChannel.toMessageSink() to responseChannel.toMessageSource()
219+
}
220+
221+
override fun isExecuted() = executed || callDelegate.isExecuted()
222+
223+
override fun clone() = GrpcServerStreamingCallStreamingAdapter(callDelegate.clone(), method)
224+
225+
private fun executeWithChannels(scope: CoroutineScope): Pair<Channel<S>, Channel<R>> {
226+
check(!executed) { "already executed" }
227+
executed = true
228+
229+
val requestChannel = Channel<S>(1)
230+
val responseChannel = Channel<R>(1)
231+
var delegateResponseChannel: ReceiveChannel<R>? = null
232+
233+
responseChannel.invokeOnClose { cause ->
234+
if (cause != null) {
235+
requestChannel.cancel()
236+
delegateResponseChannel?.cancel()
237+
callDelegate.cancel()
238+
}
239+
}
240+
241+
scope.launch {
242+
try {
243+
val requestResult = requestChannel.receiveCatching()
244+
requestResult.exceptionOrNull()?.let { throw it }
245+
val request = requestResult.getOrNull()
246+
?: throw ProtocolException("expected 1 message but got none")
247+
requestChannel.close()
248+
val responses = callDelegate.executeIn(scope, request)
249+
delegateResponseChannel = responses
250+
for (response in responses) {
251+
responseChannel.send(response)
252+
}
253+
responseChannel.close()
254+
} catch (e: Throwable) {
255+
responseChannel.close(e)
256+
}
257+
}
258+
259+
return requestChannel to responseChannel
260+
}
261+
}
262+
263+
internal fun <S : Any, R : Any> GrpcServerStreamingCall<S, R>.asGrpcStreamingCall() = GrpcServerStreamingCallStreamingAdapter(this, method)
264+
265+
private fun <E : Any> Channel<E>.toMessageSource() = object : MessageSource<E> {
266+
override fun read(): E? = runBlocking {
267+
try {
268+
val result = receiveCatching()
269+
result.exceptionOrNull()?.let { throw it }
270+
result.getOrNull()
271+
} catch (e: Throwable) {
272+
throw e.toIOException()
273+
}
274+
}
275+
276+
override fun close() {
277+
cancel()
278+
}
279+
}
280+
281+
private fun <E : Any> Channel<E>.toMessageSink() = object : MessageSink<E> {
282+
override fun write(message: E) {
283+
runBlocking {
284+
try {
285+
send(message)
286+
} catch (e: Throwable) {
287+
throw e.toIOException()
288+
}
289+
}
290+
}
291+
292+
override fun cancel() {
293+
this@toMessageSink.cancel()
294+
}
295+
296+
override fun close() {
297+
this@toMessageSink.close()
298+
}
299+
}
300+
301+
private fun Throwable.toIOException() = this as? IOException ?: IOException(this)

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,23 @@ internal class RealGrpcStreamingCall<S : Any, R : Any>(
9090
callForCancel = call,
9191
)
9292
}
93-
call.enqueue(responseChannel.readFromResponseBodyCallback(this, method.responseAdapter))
93+
call.enqueue(
94+
responseChannel.readFromResponseBodyCallback(
95+
onResponseMetadata = { this.responseMetadata = it },
96+
responseAdapter = method.responseAdapter,
97+
),
98+
)
9499

95100
return requestChannel to responseChannel
96101
}
97102

98103
override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
99104
val call = initCall()
100-
val messageSource = BlockingMessageSource(this, method.responseAdapter, call)
105+
val messageSource = BlockingMessageSource(
106+
onResponseMetadata = { this.responseMetadata = it },
107+
responseAdapter = method.responseAdapter,
108+
call = call,
109+
)
101110
val messageSink = requestBody.messageSink(
102111
minMessageToCompress = grpcClient.minMessageToCompress,
103112
requestAdapter = method.requestAdapter,

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ internal fun <S : Any> PipeDuplexRequestBody.messageSink(
8383
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
8484
grpcCall: RealGrpcStreamingCall<*, R>,
8585
responseAdapter: ProtoAdapter<R>,
86+
): Callback = readFromResponseBodyCallback(
87+
onResponseMetadata = { grpcCall.responseMetadata = it },
88+
responseAdapter = responseAdapter,
89+
)
90+
91+
/** Sends the response messages to the channel. */
92+
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
93+
onResponseMetadata: (Map<String, String>) -> Unit,
94+
responseAdapter: ProtoAdapter<R>,
8695
): Callback {
8796
return object : Callback {
8897
override fun onFailure(call: Call, e: IOException) {
@@ -91,7 +100,7 @@ internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
91100
}
92101

93102
override fun onResponse(call: Call, response: Response) {
94-
grpcCall.responseMetadata = response.headers.toMap()
103+
onResponseMetadata(response.headers.toMap())
95104
runBlocking {
96105
response.use {
97106
val messageSource = try {

0 commit comments

Comments
 (0)