Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ package grpc
import (
"bytes"
"io"
"math"
"time"

"golang.org/x/net/context"
Expand Down Expand Up @@ -73,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}
}
for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
if err == io.EOF {
break
}
Expand Down
38 changes: 25 additions & 13 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package grpc
import (
"errors"
"fmt"
"math"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -87,23 +88,33 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
copts transport.ConnectOptions
}
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
copts transport.ConnectOptions
maxMsgSize int
}

const defaultClientMaxMsgSize = math.MaxInt32

// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)

// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
func WithMaxMsgSize(s int) DialOption {
return func(o *dialOptions) {
o.maxMsgSize = s
}
}

// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
func WithCodec(c Codec) DialOption {
return func(o *dialOptions) {
Expand Down Expand Up @@ -296,6 +307,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
conns: make(map[Address]*addrConn),
}
cc.ctx, cc.cancel = context.WithCancel(context.Background())
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
for _, opt := range opts {
opt(&cc.dopts)
}
Expand Down
43 changes: 22 additions & 21 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"bytes"
"errors"
"io"
"math"
"sync"
"time"

Expand Down Expand Up @@ -208,13 +207,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break
}
cs := &clientStream{
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
cancel: cancel,
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
maxMsgSize: cc.dopts.maxMsgSize,
cancel: cancel,

put: put,
t: t,
Expand Down Expand Up @@ -259,17 +259,18 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth

// clientStream implements a client side Stream.
type clientStream struct {
opts []CallOption
c callInfo
t transport.ClientTransport
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
cp Compressor
cbuf *bytes.Buffer
dc Decompressor
cancel context.CancelFunc
opts []CallOption
c callInfo
t transport.ClientTransport
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
cp Compressor
cbuf *bytes.Buffer
dc Decompressor
maxMsgSize int
cancel context.CancelFunc

tracing bool // set to EnableTracing when the clientStream is created.

Expand Down Expand Up @@ -382,7 +383,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
Client: true,
}
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
Expand All @@ -405,7 +406,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}
// Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
Expand Down
34 changes: 32 additions & 2 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
if te.maxMsgSize > 0 {
opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
}
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
Expand Down Expand Up @@ -1427,22 +1430,34 @@ func testExceedMsgLimit(t *testing.T, e env) {
tc := testpb.NewTestServiceClient(te.clientConn())

argSize := int32(te.maxMsgSize + 1)
const respSize = 1
const smallSize = 1

payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
if err != nil {
t.Fatal(err)
}

// test on server side for unary RPC
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
ResponseSize: proto.Int32(smallSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
}
// test on client side for unary RPC
req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
req.Payload = smallPayload
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
}

// test on server side for streaming RPC
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
Expand All @@ -1469,6 +1484,21 @@ func testExceedMsgLimit(t *testing.T, e env) {
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
}

// test on client side for streaming RPC
stream, err = tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1)
sreq.Payload = smallPayload
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
}

}

func TestPeerClientSide(t *testing.T) {
Expand Down