Skip to content

Commit ba9575f

Browse files
committed
Avoid race in duplex pipe for streaming calls
1 parent fb8eefc commit ba9575f

6 files changed

Lines changed: 172 additions & 10 deletions

File tree

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/GrpcClient.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
package com.squareup.wire
1717

1818
import com.squareup.wire.internal.RealGrpcCall
19+
import com.squareup.wire.internal.RealGrpcServerStreamingCall
1920
import com.squareup.wire.internal.RealGrpcStreamingCall
2021
import com.squareup.wire.internal.asGrpcClientStreamingCall
21-
import com.squareup.wire.internal.asGrpcServerStreamingCall
2222
import java.util.concurrent.TimeUnit
2323
import kotlin.reflect.KClass
2424
import okhttp3.Call
@@ -194,6 +194,6 @@ internal class WireGrpcClient internal constructor(
194194
}
195195

196196
override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> {
197-
return RealGrpcStreamingCall(this, method).asGrpcServerStreamingCall()
197+
return RealGrpcServerStreamingCall(this, method)
198198
}
199199
}

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/BlockingMessageSource.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import okio.IOException
3333
* * Complete: enqueued when the stream completes normally.
3434
*/
3535
internal class BlockingMessageSource<R : Any>(
36-
val grpcCall: RealGrpcStreamingCall<*, R>,
36+
val onResponseMetadata: (Map<String, String>) -> Unit,
3737
val responseAdapter: ProtoAdapter<R>,
3838
val call: Call,
3939
) : MessageSource<R> {
@@ -67,7 +67,7 @@ internal class BlockingMessageSource<R : Any>(
6767

6868
override fun onResponse(call: Call, response: Response) {
6969
try {
70-
grpcCall.responseMetadata = response.headers.toMap()
70+
onResponseMetadata(response.headers.toMap())
7171
response.use {
7272
response.messageSource(responseAdapter).use { reader ->
7373
while (true) {

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcServerStreamingCall.kt

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,115 @@ import com.squareup.wire.GrpcMethod
1919
import com.squareup.wire.GrpcServerStreamingCall
2020
import com.squareup.wire.GrpcStreamingCall
2121
import com.squareup.wire.MessageSource
22+
import com.squareup.wire.WireGrpcClient
23+
import java.util.concurrent.TimeUnit
2224
import kotlinx.coroutines.CoroutineScope
25+
import kotlinx.coroutines.channels.Channel
2326
import kotlinx.coroutines.channels.ReceiveChannel
27+
import okio.ForwardingTimeout
2428
import okio.Timeout
2529

30+
/**
31+
* A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming
32+
* response. Using a non-duplex request body ensures the complete request (including END_STREAM) is
33+
* sent to the server before responses are read, avoiding delays on servers that wait for the
34+
* client's half-close before starting to stream responses.
35+
*/
2636
internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
37+
private val grpcClient: WireGrpcClient,
38+
override val method: GrpcMethod<S, R>,
39+
) : GrpcServerStreamingCall<S, R> {
40+
41+
private var call: okhttp3.Call? = null
42+
private var canceled = false
43+
44+
override val timeout: Timeout = ForwardingTimeout(Timeout())
45+
46+
init {
47+
timeout.clearTimeout()
48+
timeout.clearDeadline()
49+
}
50+
51+
override var requestMetadata: Map<String, String> = mapOf()
52+
53+
override var responseMetadata: Map<String, String>? = null
54+
internal set
55+
56+
override fun cancel() {
57+
canceled = true
58+
call?.cancel()
59+
}
60+
61+
override fun isCanceled(): Boolean = canceled || call?.isCanceled() == true
62+
63+
override fun isExecuted(): Boolean = call?.isExecuted() ?: false
64+
65+
override fun clone(): GrpcServerStreamingCall<S, R> {
66+
val result = RealGrpcServerStreamingCall(grpcClient, method)
67+
val oldTimeout = this.timeout
68+
result.timeout.also { newTimeout ->
69+
newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit.NANOSECONDS)
70+
if (oldTimeout.hasDeadline()) {
71+
newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime())
72+
} else {
73+
newTimeout.clearDeadline()
74+
}
75+
}
76+
result.requestMetadata += this.requestMetadata
77+
return result
78+
}
79+
80+
override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
81+
val responseChannel = Channel<R>(1)
82+
val call = initCall(request)
83+
84+
responseChannel.invokeOnClose {
85+
if (responseChannel.isClosedForReceive) {
86+
call.cancel()
87+
}
88+
}
89+
90+
call.enqueue(
91+
responseChannel.readFromResponseBodyCallback(
92+
onResponseMetadata = { this.responseMetadata = it },
93+
responseAdapter = method.responseAdapter,
94+
),
95+
)
96+
97+
return responseChannel
98+
}
99+
100+
override fun executeBlocking(request: S): MessageSource<R> {
101+
val call = initCall(request)
102+
val messageSource = BlockingMessageSource(
103+
onResponseMetadata = { this.responseMetadata = it },
104+
responseAdapter = method.responseAdapter,
105+
call = call,
106+
)
107+
call.enqueue(messageSource.readFromResponseBodyCallback())
108+
return messageSource
109+
}
110+
111+
private fun initCall(request: S): okhttp3.Call {
112+
check(this.call == null) { "already executed" }
113+
val requestBody = newRequestBody(
114+
minMessageToCompress = grpcClient.minMessageToCompress,
115+
requestAdapter = method.requestAdapter,
116+
onlyMessage = request,
117+
)
118+
val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout)
119+
this.call = result
120+
if (canceled) result.cancel()
121+
(timeout as ForwardingTimeout).setDelegate(result.timeout())
122+
return result
123+
}
124+
}
125+
126+
/**
127+
* Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via
128+
* [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls.
129+
*/
130+
internal class GrpcStreamingCallServerStreamingAdapter<S : Any, R : Any>(
27131
private val callDelegate: GrpcStreamingCall<S, R>,
28132
override val method: GrpcMethod<S, R>,
29133
) : GrpcServerStreamingCall<S, R> {
@@ -46,7 +150,7 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
46150

47151
override fun isExecuted() = callDelegate.isExecuted()
48152

49-
override fun clone() = RealGrpcServerStreamingCall(callDelegate.clone(), method)
153+
override fun clone() = GrpcStreamingCallServerStreamingAdapter(callDelegate.clone(), method)
50154

51155
override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
52156
val (sendChannel, receiveChannel) = callDelegate.executeIn(scope)
@@ -65,5 +169,5 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
65169
}
66170
}
67171

68-
internal fun <S : Any, R : Any>GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() =
69-
RealGrpcServerStreamingCall(this, method)
172+
internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() =
173+
GrpcStreamingCallServerStreamingAdapter(this, method)

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/RealGrpcStreamingCall.kt

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,23 @@ internal class RealGrpcStreamingCall<S : Any, R : Any>(
9090
callForCancel = call,
9191
)
9292
}
93-
call.enqueue(responseChannel.readFromResponseBodyCallback(this, method.responseAdapter))
93+
call.enqueue(
94+
responseChannel.readFromResponseBodyCallback(
95+
onResponseMetadata = { this.responseMetadata = it },
96+
responseAdapter = method.responseAdapter,
97+
),
98+
)
9499

95100
return requestChannel to responseChannel
96101
}
97102

98103
override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
99104
val call = initCall()
100-
val messageSource = BlockingMessageSource(this, method.responseAdapter, call)
105+
val messageSource = BlockingMessageSource(
106+
onResponseMetadata = { this.responseMetadata = it },
107+
responseAdapter = method.responseAdapter,
108+
call = call,
109+
)
101110
val messageSink = requestBody.messageSink(
102111
minMessageToCompress = grpcClient.minMessageToCompress,
103112
requestAdapter = method.requestAdapter,

wire-grpc-client/src/jvmMain/kotlin/com/squareup/wire/internal/grpc.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ internal fun <S : Any> PipeDuplexRequestBody.messageSink(
8787
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
8888
grpcCall: RealGrpcStreamingCall<*, R>,
8989
responseAdapter: ProtoAdapter<R>,
90+
): Callback = readFromResponseBodyCallback(
91+
onResponseMetadata = { grpcCall.responseMetadata = it },
92+
responseAdapter = responseAdapter,
93+
)
94+
95+
/** Sends the response messages to the channel. */
96+
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
97+
onResponseMetadata: (Map<String, String>) -> Unit,
98+
responseAdapter: ProtoAdapter<R>,
9099
): Callback {
91100
return object : Callback {
92101
override fun onFailure(call: Call, e: IOException) {
@@ -95,7 +104,7 @@ internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
95104
}
96105

97106
override fun onResponse(call: Call, response: Response) {
98-
grpcCall.responseMetadata = response.headers.toMap()
107+
onResponseMetadata(response.headers.toMap())
99108
runBlocking {
100109
response.use {
101110
val messageSource = try {

wire-grpc-tests/src/test/java/com/squareup/wire/GrpcOnMockWebServerTest.kt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,21 @@ package com.squareup.wire
1818
import assertk.assertThat
1919
import assertk.assertions.containsExactly
2020
import assertk.assertions.isEqualTo
21+
import assertk.assertions.isNull
2122
import com.squareup.wire.mockwebserver.GrpcDispatcher
2223
import java.util.concurrent.TimeUnit
2324
import java.util.concurrent.atomic.AtomicReference
2425
import kotlinx.coroutines.ExperimentalCoroutinesApi
2526
import kotlinx.coroutines.ObsoleteCoroutinesApi
2627
import kotlinx.coroutines.runBlocking
2728
import okhttp3.Call
29+
import okhttp3.Headers.Companion.headersOf
2830
import okhttp3.Interceptor
2931
import okhttp3.OkHttpClient
3032
import okhttp3.Protocol
33+
import okhttp3.mockwebserver.MockResponse
3134
import okhttp3.mockwebserver.MockWebServer
35+
import okio.Buffer
3236
import org.junit.Before
3337
import org.junit.Rule
3438
import org.junit.Test
@@ -82,6 +86,42 @@ class GrpcOnMockWebServerTest {
8286
routeGuideService = grpcClient.create(RouteGuideClient::class)
8387
}
8488

89+
@Test
90+
fun serverStreamingListFeatures() {
91+
// MockWebServer only dispatches after receiving the complete request body including
92+
// END_STREAM — matching the server behavior that caused responses to hang until timeout
93+
// when GrpcServerStreamingCall used a duplex request body (https://github.com/square/wire/issues/3370).
94+
val responseBody = Buffer()
95+
for (feature in listOf(Feature(name = "peak"), Feature(name = "valley"))) {
96+
val encoded = Feature.ADAPTER.encodeByteString(feature)
97+
responseBody.writeByte(0) // not compressed
98+
responseBody.writeInt(encoded.size)
99+
responseBody.write(encoded)
100+
}
101+
val grpcDispatcher = mockWebServer.dispatcher
102+
mockWebServer.dispatcher = object : okhttp3.mockwebserver.Dispatcher() {
103+
override fun dispatch(request: okhttp3.mockwebserver.RecordedRequest): MockResponse {
104+
if (request.path == "/routeguide.RouteGuide/ListFeatures") {
105+
return MockResponse()
106+
.setHeader("Content-Type", "application/grpc")
107+
.setTrailers(headersOf("grpc-status", "0"))
108+
.setBody(responseBody)
109+
}
110+
return grpcDispatcher.dispatch(request)
111+
}
112+
}
113+
114+
runBlocking {
115+
val responses = routeGuideService.ListFeatures().executeIn(
116+
this,
117+
Rectangle(lo = Point(latitude = 1, longitude = 2), hi = Point(latitude = 3, longitude = 4)),
118+
)
119+
assertThat(responses.receive()).isEqualTo(Feature(name = "peak"))
120+
assertThat(responses.receive()).isEqualTo(Feature(name = "valley"))
121+
assertThat(responses.receiveCatching().getOrNull()).isNull()
122+
}
123+
}
124+
85125
@Test
86126
fun requestResponseSuspend() {
87127
runBlocking {

0 commit comments

Comments
 (0)