diff --git a/stream.go b/stream.go index ca6948926f93..ccde41f8dbcb 100644 --- a/stream.go +++ b/stream.go @@ -1138,6 +1138,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { if statusErr := a.transportStream.Status().Err(); statusErr != nil { return statusErr } + // Received no msg and status OK for non-server streaming rpcs. + if !cs.desc.ServerStreams { + return status.Error(codes.Internal, "cardinality violation: received no response message from non-streaming RPC") + } return io.EOF // indicates successful end of stream. } @@ -1171,7 +1175,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } else if err != nil { return toRPCErr(err) } - return status.Errorf(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") + return status.Error(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") } func (a *csAttempt) finish(err error) { @@ -1478,6 +1482,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { if statusErr := as.transportStream.Status().Err(); statusErr != nil { return statusErr } + // Received no msg and status OK for non-server streaming rpcs. + if !as.desc.ServerStreams { + return status.Error(codes.Internal, "cardinality violation: received no response message from non-streaming RPC") + } return io.EOF // indicates successful end of stream. } return toRPCErr(err) @@ -1495,7 +1503,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { } else if err != nil { return toRPCErr(err) } - return status.Errorf(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") + return status.Error(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") } func (as *addrConnStream) finish(err error) { diff --git a/test/end2end_test.go b/test/end2end_test.go index 584c90ca3b15..ab2517d5f18e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3589,9 +3589,6 @@ func testClientStreamingError(t *testing.T, e env) { // Tests that a client receives a cardinality violation error for client-streaming // RPCs if the server doesn't send a message before returning status OK. func (s) TestClientStreamingCardinalityViolation_ServerHandlerMissingSendAndClose(t *testing.T) { - // TODO : https://github.com/grpc/grpc-go/issues/8119 - remove `t.Skip()` - // after this is fixed. - t.Skip() ss := &stubserver.StubServer{ StreamingInputCallF: func(_ testgrpc.TestService_StreamingInputCallServer) error { // Returning status OK without sending a response message.This is a @@ -3740,8 +3737,113 @@ func (s) TestClientStreaming_ReturnErrorAfterSendAndClose(t *testing.T) { } } +// Tests that a client receives a cardinality violation error for unary +// RPCs if the server doesn't send a message before returning status OK. +func (s) TestUnaryRPC_ServerSendsOnlyTrailersWithOK(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + ss := grpc.UnknownServiceHandler(func(any, grpc.ServerStream) error { + return nil + }) + + s := grpc.NewServer(ss) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests that client will receive cardinality violations when calling +// RecvMsg() multiple times for non-streaming response streams. +func (s) TestUnaryRPC_ClientCallRecvMsgTwice(t *testing.T) { + e := tcpTLSEnv + te := newTest(t, e) + defer te.tearDown() + + te.startServer(&testServer{security: e.security}) + + cc := te.clientConn() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + desc := &grpc.StreamDesc{ + StreamName: "UnaryCall", + ServerStreams: false, + ClientStreams: false, + } + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/UnaryCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.SimpleRequest{}); err != nil { + t.Fatalf("stream.SendMsg(_) = %v, want ", err) + } + + resp := &testpb.SimpleResponse{} + if err := stream.RecvMsg(resp); err != nil { + t.Fatalf("stream.RecvMsg() = %v , want ", err) + } + + if err = stream.RecvMsg(resp); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests that client will receive cardinality violations when calling +// RecvMsg() multiple times for non-streaming response streams. +func (s) TestClientStreaming_ClientCallRecvMsgTwice(t *testing.T) { + ss := stubserver.StubServer{ + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + if err := stream.SendAndClose(&testpb.StreamingInputCallResponse{}); err != nil { + t.Errorf("stream.SendAndClose(_) = %v, want ", err) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := ss.Client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf(".StreamingInputCall(_) = _, %v, want ", err) + } + if err := stream.Send(&testpb.StreamingInputCallRequest{}); err != nil { + t.Fatalf("stream.Send(_) = %v, want ", err) + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("stream.CloseSend() = %v, want ", err) + } + resp := new(testpb.StreamingInputCallResponse) + if err := stream.RecvMsg(resp); err != nil { + t.Fatalf("stream.RecvMsg() = %v , want ", err) + } + if err = stream.RecvMsg(resp); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + // Tests that a client receives a cardinality violation error for client-streaming -// RPCs if the server call SendMsg multiple times. +// RPCs if the server call SendMsg() multiple times. func (s) TestClientStreaming_ServerHandlerSendMsgAfterSendMsg(t *testing.T) { ss := stubserver.StubServer{ StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error {