Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,15 +719,19 @@ func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.
}
}

func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" || recvCompress == encoding.Identity {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} else {
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
Expand All @@ -744,14 +748,14 @@ type payloadInfo struct {
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, nil, err
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
return nil, nil, st.Err()
}

Expand Down Expand Up @@ -825,8 +829,8 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
payInfo = &payloadInfo{}
}

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand Down
10 changes: 5 additions & 5 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream.
a.decompSet = true
}
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp)
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false)
if err != nil {
if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil {
Expand Down Expand Up @@ -1122,7 +1122,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
}
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp)
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
Expand Down Expand Up @@ -1423,7 +1423,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream.
as.decompSet = true
}
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false)
if err != nil {
if err == io.EOF {
if statusErr := as.s.Status().Err(); statusErr != nil {
Expand All @@ -1444,7 +1444,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {

// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
Expand Down Expand Up @@ -1715,7 +1715,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{}
}
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil {
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
if err == io.EOF {
if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{}
Expand Down
108 changes: 108 additions & 0 deletions test/compression_cases_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package test

import (
"context"
"fmt"
"testing"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
testpb "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/status"
)

type mockCompressor struct {
grpc.Compressor
}

func newMockCompressor() grpc.Compressor {
return &mockCompressor{grpc.NewGZIPCompressor()}
}

func (c *mockCompressor) Type() string {
return "mock"
}

type mockDecompressor struct {
grpc.Decompressor
}

func newMockDecompressor() grpc.Decompressor {
return &mockDecompressor{grpc.NewGZIPDecompressor()}
}

func (d *mockDecompressor) Type() string {
return "mock"
}

func TestCompressionCases(t *testing.T) {
cases := []struct {
desc string
clientUseMock bool
serverUseMock bool
expectedStatus codes.Code
}{
{
desc: "Client and Server use mock compression",
clientUseMock: true,
serverUseMock: true,
expectedStatus: codes.OK,
},
{
desc: "Only Client use mock compression",
clientUseMock: true,
serverUseMock: false,
expectedStatus: codes.Unimplemented,
},
{
desc: "Only Server use mock compression",
clientUseMock: false,
serverUseMock: true,
expectedStatus: codes.Internal,
},
}
for i, tc := range cases {
fmt.Println("TESTCASE: ", i)

ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{
Payload: in.Payload,
}, nil
},
}
sopts := []grpc.ServerOption{}
if tc.serverUseMock {
sopts = append(sopts, grpc.RPCCompressor(newMockCompressor()), grpc.RPCDecompressor(newMockDecompressor()))
}
if err := ss.Start(sopts); err != nil {
t.Fatalf("Error starting server: %v", err)
}

defer ss.Stop()
dOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
if tc.clientUseMock {
dOpts = append(dOpts, grpc.WithCompressor(newMockCompressor()), grpc.WithDecompressor(newMockDecompressor()))
}
cc, err := grpc.Dial(ss.Address, dOpts...)
if err != nil {
t.Fatalf("Failed to dial server: %v", err)
}
defer cc.Close()
ss.Client = testpb.NewTestServiceClient(cc)

payload := &testpb.SimpleRequest{
Payload: &testpb.Payload{
Body: []byte("test message"),
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err = ss.Client.UnaryCall(ctx, payload)
if st, _ := status.FromError(err); st.Code() != tc.expectedStatus {
t.Fatalf("got %v want %v", st.Code(), tc.expectedStatus)
}
}
}