From a94b0948a38dd739ec9671c859a18067f1b88838 Mon Sep 17 00:00:00 2001 From: Mahak Mukhi Date: Thu, 9 Mar 2017 16:58:23 -0800 Subject: [PATCH 1/3] Client should have a check on maximum size of received message size. --- call.go | 3 +-- clientconn.go | 40 +++++++++++++++++++++++++++------------- stream.go | 43 ++++++++++++++++++++++--------------------- test/end2end_test.go | 33 +++++++++++++++++++++++++++++++-- 4 files changed, 81 insertions(+), 38 deletions(-) diff --git a/call.go b/call.go index 81b52be294b6..c1588c6375c7 100644 --- a/call.go +++ b/call.go @@ -36,7 +36,6 @@ package grpc import ( "bytes" "io" - "math" "time" "golang.org/x/net/context" @@ -73,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } } for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil { if err == io.EOF { break } diff --git a/clientconn.go b/clientconn.go index b8e31988afa7..1bf824bc7c44 100644 --- a/clientconn.go +++ b/clientconn.go @@ -36,6 +36,7 @@ package grpc import ( "errors" "fmt" + "math" "net" "strings" "sync" @@ -87,23 +88,33 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - unaryInt UnaryClientInterceptor - streamInt StreamClientInterceptor - codec Codec - cp Compressor - dc Decompressor - bs backoffStrategy - balancer Balancer - block bool - insecure bool - timeout time.Duration - scChan <-chan ServiceConfig - copts transport.ConnectOptions -} + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + codec Codec + cp Compressor + dc Decompressor + bs backoffStrategy + balancer Balancer + block bool + insecure bool + timeout time.Duration + scChan <-chan ServiceConfig + copts transport.ConnectOptions + maxMsgSize int +} + +const defaultClientMaxMsgSize = math.MaxInt32 // DialOption configures how we set up the connection. type DialOption func(*dialOptions) +// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. +func WithMaxMsgSize(s int) DialOption { + return func(o *dialOptions) { + o.maxMsgSize = s + } +} + // WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. func WithCodec(c Codec) DialOption { return func(o *dialOptions) { @@ -304,6 +315,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) defer cancel() } + if cc.dopts.maxMsgSize == 0 { + cc.dopts.maxMsgSize = defaultClientMaxMsgSize + } defer func() { select { diff --git a/stream.go b/stream.go index bb468dc37e63..0ef2077ce254 100644 --- a/stream.go +++ b/stream.go @@ -37,7 +37,6 @@ import ( "bytes" "errors" "io" - "math" "sync" "time" @@ -208,13 +207,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth break } cs := &clientStream{ - opts: opts, - c: c, - desc: desc, - codec: cc.dopts.codec, - cp: cc.dopts.cp, - dc: cc.dopts.dc, - cancel: cancel, + opts: opts, + c: c, + desc: desc, + codec: cc.dopts.codec, + cp: cc.dopts.cp, + dc: cc.dopts.dc, + maxMsgSize: cc.dopts.maxMsgSize, + cancel: cancel, put: put, t: t, @@ -259,17 +259,18 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { - opts []CallOption - c callInfo - t transport.ClientTransport - s *transport.Stream - p *parser - desc *StreamDesc - codec Codec - cp Compressor - cbuf *bytes.Buffer - dc Decompressor - cancel context.CancelFunc + opts []CallOption + c callInfo + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + codec Codec + cp Compressor + cbuf *bytes.Buffer + dc Decompressor + maxMsgSize int + cancel context.CancelFunc tracing bool // set to EnableTracing when the clientStream is created. @@ -382,7 +383,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { Client: true, } } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -405,7 +406,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } // Special handling for client streaming rpc. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) diff --git a/test/end2end_test.go b/test/end2end_test.go index d743623f97ee..8aee5a1f3704 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -570,6 +570,9 @@ func (te *test) clientConn() *grpc.ClientConn { if te.streamClientInt != nil { opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt)) } + if te.maxMsgSize > 0 { + opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize)) + } switch te.e.security { case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") @@ -1427,22 +1430,33 @@ func testExceedMsgLimit(t *testing.T, e env) { tc := testpb.NewTestServiceClient(te.clientConn()) argSize := int32(te.maxMsgSize + 1) - const respSize = 1 + const smallSize = 1 payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) if err != nil { t.Fatal(err) } + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + // test on server side for unary RPC req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), - ResponseSize: proto.Int32(respSize), + ResponseSize: proto.Int32(smallSize), Payload: payload, } if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) } + // test on client side for unary RPC + req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1) + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) + } + // test on server side for streaming RPC stream, err := tc.FullDuplexCall(te.ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -1469,6 +1483,21 @@ func testExceedMsgLimit(t *testing.T, e env) { if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) } + + // test on client side for streaming RPC + stream, err = tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1) + sreq.Payload = smallPayload + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) + } + } func TestPeerClientSide(t *testing.T) { From b9b6d48985c31f950af2020baad41b1d32f04450 Mon Sep 17 00:00:00 2001 From: Mahak Mukhi Date: Thu, 9 Mar 2017 17:01:46 -0800 Subject: [PATCH 2/3] test debug --- test/end2end_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/test/end2end_test.go b/test/end2end_test.go index 8aee5a1f3704..37964919945e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -1452,6 +1452,7 @@ func testExceedMsgLimit(t *testing.T, e env) { } // test on client side for unary RPC req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1) + req.Payload = smallPayload if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) } From 0e71619115051aa6df2635efbee7f80f389dd14b Mon Sep 17 00:00:00 2001 From: Mahak Mukhi Date: Thu, 9 Mar 2017 17:48:37 -0800 Subject: [PATCH 3/3] making client consistent with server --- clientconn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/clientconn.go b/clientconn.go index 1bf824bc7c44..1ba592c500cf 100644 --- a/clientconn.go +++ b/clientconn.go @@ -307,6 +307,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * conns: make(map[Address]*addrConn), } cc.ctx, cc.cancel = context.WithCancel(context.Background()) + cc.dopts.maxMsgSize = defaultClientMaxMsgSize for _, opt := range opts { opt(&cc.dopts) } @@ -315,9 +316,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) defer cancel() } - if cc.dopts.maxMsgSize == 0 { - cc.dopts.maxMsgSize = defaultClientMaxMsgSize - } defer func() { select {