diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 9bba55422156..fe670b81f246 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -642,8 +642,14 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) { switch frame := frame.(type) { case *http2.MetaHeadersFrame: if err := t.operateHeaders(ctx, frame, handle); err != nil { - t.Close(err) - break + // Any error processing client headers, e.g. invalid stream ID, + // is considered a protocol violation. + t.controlBuf.put(&goAway{ + code: http2.ErrCodeProtocol, + debugData: []byte(err.Error()), + closeConn: err, + }) + continue } case *http2.DataFrame: t.handleData(frame) diff --git a/test/end2end_test.go b/test/end2end_test.go index 001a1228a25b..321aeb52d70e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3901,6 +3901,66 @@ func (s) TestClientRequestBodyErrorCloseAfterLength(t *testing.T) { } } +// Tests gRPC server's behavior when a gRPC client sends a frame with an invalid +// streamID. Per [HTTP/2 spec]: Streams initiated by a client MUST use +// odd-numbered stream identifiers. This test sets up a test server and sends a +// header frame with stream ID of 2. The test asserts that a subsequent read on +// the transport sends a GoAwayFrame with error code: PROTOCOL_ERROR. +// +// [HTTP/2 spec]: https://httpwg.org/specs/rfc7540.html#StreamIdentifiers +func (s) TestClientInvalidStreamID(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer lis.Close() + s := grpc.NewServer() + defer s.Stop() + go s.Serve(lis) + + conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + st := newServerTesterFromConn(t, conn) + st.greet() + st.writeHeadersGRPC(2, "/grpc.testing.TestService/StreamingInputCall", true) + goAwayFrame := st.wantGoAway(http2.ErrCodeProtocol) + want := "received an illegal stream id: 2." + if got := string(goAwayFrame.DebugData()); !strings.Contains(got, want) { + t.Fatalf(" Received: %v, Expected error message to contain: %v.", got, want) + } +} + +// TestInvalidStreamIDSmallerThanPrevious tests the server sends a GOAWAY frame +// with error code: PROTOCOL_ERROR when the streamID of the current frame is +// lower than the previous frames. +func (s) TestInvalidStreamIDSmallerThanPrevious(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer lis.Close() + s := grpc.NewServer() + defer s.Stop() + go s.Serve(lis) + + conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + st := newServerTesterFromConn(t, conn) + st.greet() + st.writeHeadersGRPC(3, "/grpc.testing.TestService/StreamingInputCall", true) + st.wantAnyFrame() + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", true) + goAwayFrame := st.wantGoAway(http2.ErrCodeProtocol) + want := "received an illegal stream id: 1" + if got := string(goAwayFrame.DebugData()); !strings.Contains(got, want) { + t.Fatalf(" Received: %v, Expected error message to contain: %v.", got, want) + } +} + func testClientRequestBodyErrorCloseAfterLength(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("Server.processUnaryRPC failed to write status")