Skip to content

Commit 6d0aaae

Browse files
authored
grpc: make client report Internal status when server response contains unsupported encoding (#7461)
1 parent 338595c commit 6d0aaae

File tree

4 files changed

+95
-16
lines changed

4 files changed

+95
-16
lines changed

rpc_util.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -719,15 +719,19 @@ func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.
719719
}
720720
}
721721

722-
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
722+
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
723723
switch pf {
724724
case compressionNone:
725725
case compressionMade:
726726
if recvCompress == "" || recvCompress == encoding.Identity {
727727
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
728728
}
729729
if !haveCompressor {
730-
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
730+
if isServer {
731+
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
732+
} else {
733+
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
734+
}
731735
}
732736
default:
733737
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
@@ -744,14 +748,16 @@ type payloadInfo struct {
744748
//
745749
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
746750
// the buffer is no longer needed.
747-
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
751+
// TODO: Refactor this function to reduce the number of arguments.
752+
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
753+
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
748754
) (uncompressedBuf []byte, cancel func(), err error) {
749755
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
750756
if err != nil {
751757
return nil, nil, err
752758
}
753759

754-
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
760+
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
755761
return nil, nil, st.Err()
756762
}
757763

@@ -825,8 +831,8 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
825831
// For the two compressor parameters, both should not be set, but if they are,
826832
// dc takes precedence over compressor.
827833
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
828-
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
829-
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
834+
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
835+
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
830836
if err != nil {
831837
return err
832838
}

server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
13361336
payInfo = &payloadInfo{}
13371337
}
13381338

1339-
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
1339+
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
13401340
if err != nil {
13411341
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
13421342
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)

stream.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,8 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
10831083
// Only initialize this state once per stream.
10841084
a.decompSet = true
10851085
}
1086-
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
1087-
if err != nil {
1086+
if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil {
10881087
if err == io.EOF {
10891088
if statusErr := a.s.Status().Err(); statusErr != nil {
10901089
return statusErr
@@ -1122,8 +1121,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
11221121
}
11231122
// Special handling for non-server-stream rpcs.
11241123
// This recv expects EOF or errors, so we don't collect inPayload.
1125-
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
1126-
if err == nil {
1124+
if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == nil {
11271125
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
11281126
}
11291127
if err == io.EOF {
@@ -1423,8 +1421,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
14231421
// Only initialize this state once per stream.
14241422
as.decompSet = true
14251423
}
1426-
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
1427-
if err != nil {
1424+
if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil {
14281425
if err == io.EOF {
14291426
if statusErr := as.s.Status().Err(); statusErr != nil {
14301427
return statusErr
@@ -1444,8 +1441,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
14441441

14451442
// Special handling for non-server-stream rpcs.
14461443
// This recv expects EOF or errors, so we don't collect inPayload.
1447-
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
1448-
if err == nil {
1444+
if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == nil {
14491445
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
14501446
}
14511447
if err == io.EOF {
@@ -1715,7 +1711,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
17151711
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
17161712
payInfo = &payloadInfo{}
17171713
}
1718-
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
1714+
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
17191715
if err == io.EOF {
17201716
if len(ss.binlogs) != 0 {
17211717
chc := &binarylog.ClientHalfClose{}

test/compressor_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030

3131
"google.golang.org/grpc"
3232
"google.golang.org/grpc/codes"
33+
"google.golang.org/grpc/credentials/insecure"
3334
"google.golang.org/grpc/encoding"
3435
"google.golang.org/grpc/internal/stubserver"
3536
"google.golang.org/grpc/metadata"
@@ -39,6 +40,82 @@ import (
3940
testpb "google.golang.org/grpc/interop/grpc_testing"
4041
)
4142

43+
// TestUnsupportedEncodingResponse validates gRPC status codes
44+
// for different client-server compression setups
45+
// ensuring the correct behavior when compression is enabled or disabled on either side.
46+
func (s) TestUnsupportedEncodingResponse(t *testing.T) {
47+
tests := []struct {
48+
name string
49+
clientCompress bool
50+
serverCompress bool
51+
wantStatus codes.Code
52+
}{
53+
{
54+
name: "client_server_compression",
55+
clientCompress: true,
56+
serverCompress: true,
57+
wantStatus: codes.OK,
58+
},
59+
{
60+
name: "client_compression",
61+
clientCompress: true,
62+
serverCompress: false,
63+
wantStatus: codes.Unimplemented,
64+
},
65+
{
66+
name: "server_compression",
67+
clientCompress: false,
68+
serverCompress: true,
69+
wantStatus: codes.Internal,
70+
},
71+
}
72+
73+
for _, test := range tests {
74+
t.Run(test.name, func(t *testing.T) {
75+
ss := &stubserver.StubServer{
76+
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
77+
return &testpb.SimpleResponse{Payload: in.Payload}, nil
78+
},
79+
}
80+
sopts := []grpc.ServerOption{}
81+
if test.serverCompress {
82+
// Using deprecated methods to selectively apply compression
83+
// only on the server side. With encoding.registerCompressor(),
84+
// the compressor is applied globally, affecting client and server
85+
sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor()))
86+
}
87+
if err := ss.StartServer(sopts...); err != nil {
88+
t.Fatalf("Error starting server: %v", err)
89+
}
90+
defer ss.Stop()
91+
92+
dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
93+
if test.clientCompress {
94+
// UseCompressor() requires the compressor to be registered
95+
// using encoding.RegisterCompressor() which applies compressor globally,
96+
// Hence, using deprecated WithCompressor() and WithDecompressor()
97+
// to apply compression only on client.
98+
dopts = append(dopts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor()))
99+
}
100+
if err := ss.StartClient(dopts...); err != nil {
101+
t.Fatalf("Error starting client: %v", err)
102+
}
103+
104+
payload := &testpb.SimpleRequest{
105+
Payload: &testpb.Payload{
106+
Body: []byte("test message"),
107+
},
108+
}
109+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
110+
defer cancel()
111+
_, err := ss.Client.UnaryCall(ctx, payload)
112+
if got, want := status.Code(err), test.wantStatus; got != want {
113+
t.Errorf("Client.UnaryCall() = %v, want %v", got, want)
114+
}
115+
})
116+
}
117+
}
118+
42119
func (s) TestCompressServerHasNoSupport(t *testing.T) {
43120
for _, e := range listTestEnv() {
44121
testCompressServerHasNoSupport(t, e)

0 commit comments

Comments
 (0)