Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,15 +719,19 @@ func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.
}
}

func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" || recvCompress == encoding.Identity {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} else {
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
Expand All @@ -744,14 +748,16 @@ type payloadInfo struct {
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, nil, err
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
return nil, nil, st.Err()
}

Expand Down Expand Up @@ -825,8 +831,8 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
payInfo = &payloadInfo{}
}

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand Down
14 changes: 5 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream.
a.decompSet = true
}
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
if err != nil {
if err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil {
if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil {
return statusErr
Expand Down Expand Up @@ -1122,8 +1121,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
if err == nil {
if err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
Expand Down Expand Up @@ -1423,8 +1421,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream.
as.decompSet = true
}
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
if err != nil {
if err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil {
if err == io.EOF {
if statusErr := as.s.Status().Err(); statusErr != nil {
return statusErr
Expand All @@ -1444,8 +1441,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {

// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
if err == nil {
if err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
Expand Down Expand Up @@ -1715,7 +1711,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{}
}
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
if err == io.EOF {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
Expand Down
77 changes: 77 additions & 0 deletions test/compressor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
Expand All @@ -39,6 +40,82 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)

// TestUnsupportedEncodingResponse validates gRPC status codes
// for different client-server compression setups
// ensuring the correct behavior when compression is enabled or disabled on either side.
func (s) TestUnsupportedEncodingResponse(t *testing.T) {
tests := []struct {
name string
clientCompress bool
serverCompress bool
wantStatus codes.Code
}{
{
name: "client_server_compression",
clientCompress: true,
serverCompress: true,
wantStatus: codes.OK,
},
{
name: "client_compression",
clientCompress: true,
serverCompress: false,
wantStatus: codes.Unimplemented,
},
{
name: "server_compression",
clientCompress: false,
serverCompress: true,
wantStatus: codes.Internal,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{Payload: in.Payload}, nil
},
}
sopts := []grpc.ServerOption{}
if test.serverCompress {
// Using deprecated methods to selectively apply compression only on the server side.
// with encoding.registerCompressor(), the compressor is applied globally,
// affecting both the client and server.
sopts = append(sopts, grpc.RPCCompressor(newNopCompressor()), grpc.RPCDecompressor(newNopDecompressor()))
}
if err := ss.Start(sopts); err != nil {
t.Fatalf("Error starting server: %v", err)
}
defer ss.Stop()

dopts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
if test.clientCompress {
// UseCompressor() requires the compressor to be registered
// using encoding.RegisterCompressor() which applies compressor globally,
// Hence, using deprecated WithCompressor() and WithDecompressor()
// to apply compression only on client.
dopts = append(dopts, grpc.WithCompressor(newNopCompressor()), grpc.WithDecompressor(newNopDecompressor()))
}
if err := ss.StartClient(dopts...); err != nil {
t.Fatalf("Error starting client: %v", err)
}

payload := &testpb.SimpleRequest{
Payload: &testpb.Payload{
Body: []byte("test message"),
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err := ss.Client.UnaryCall(ctx, payload)
if got, want := status.Code(err), test.wantStatus; got != want {
t.Errorf("Client.UnaryCall() = %v, want %v", got, want)
}
})
}
}

func (s) TestCompressServerHasNoSupport(t *testing.T) {
for _, e := range listTestEnv() {
testCompressServerHasNoSupport(t, e)
Expand Down