Skip to content

Commit 8cb9846

Browse files
authored
grpc: Add a pointer of server to ctx passed into stats handler (#6750)
1 parent 8190d88 commit 8cb9846

File tree

4 files changed

+174
-0
lines changed

4 files changed

+174
-0
lines changed

internal/internal.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ var (
7373
// xDS-enabled server invokes this method on a grpc.Server when a particular
7474
// listener moves to "not-serving" mode.
7575
DrainServerTransports any // func(*grpc.Server, string)
76+
// IsRegisteredMethod returns whether the passed in method is registered as
77+
// a method on the server.
78+
IsRegisteredMethod any // func(*grpc.Server, string) bool
79+
// ServerFromContext returns the server from the context.
80+
ServerFromContext any // func(context.Context) *grpc.Server
7681
// AddGlobalServerOptions adds an array of ServerOption that will be
7782
// effective globally for newly created servers. The priority will be: 1.
7883
// user-provided; 2. this method; 3. default values.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
*
3+
* Copyright 2023 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package testutils
20+
21+
import (
22+
"context"
23+
24+
"google.golang.org/grpc/stats"
25+
)
26+
27+
// StubStatsHandler is a stats handler that is easy to customize within
28+
// individual test cases. It is a stubbable implementation of
29+
// google.golang.org/grpc/stats.Handler for testing purposes.
30+
type StubStatsHandler struct {
31+
TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context
32+
HandleRPCF func(ctx context.Context, info stats.RPCStats)
33+
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context
34+
HandleConnF func(ctx context.Context, info stats.ConnStats)
35+
}
36+
37+
// TagRPC calls the StubStatsHandler's TagRPCF, if set.
38+
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
39+
if ssh.TagRPCF != nil {
40+
return ssh.TagRPCF(ctx, info)
41+
}
42+
return ctx
43+
}
44+
45+
// HandleRPC calls the StubStatsHandler's HandleRPCF, if set.
46+
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
47+
if ssh.HandleRPCF != nil {
48+
ssh.HandleRPCF(ctx, rs)
49+
}
50+
}
51+
52+
// TagConn calls the StubStatsHandler's TagConnF, if set.
53+
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
54+
if ssh.TagConnF != nil {
55+
return ssh.TagConnF(ctx, info)
56+
}
57+
return ctx
58+
}
59+
60+
// HandleConn calls the StubStatsHandler's HandleConnF, if set.
61+
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
62+
if ssh.HandleConnF != nil {
63+
ssh.HandleConnF(ctx, cs)
64+
}
65+
}

server.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ func init() {
7070
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
7171
return srv.opts.creds
7272
}
73+
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
74+
return srv.isRegisteredMethod(method)
75+
}
76+
internal.ServerFromContext = serverFromContext
7377
internal.DrainServerTransports = func(srv *Server, addr string) {
7478
srv.drainServerTransports(addr)
7579
}
@@ -1707,6 +1711,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
17071711

17081712
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
17091713
ctx := stream.Context()
1714+
ctx = contextWithServer(ctx, s)
17101715
var ti *traceInfo
17111716
if EnableTracing {
17121717
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
@@ -1953,6 +1958,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
19531958
return codec
19541959
}
19551960

1961+
type serverKey struct{}
1962+
1963+
// serverFromContext gets the Server from the context.
1964+
func serverFromContext(ctx context.Context) *Server {
1965+
s, _ := ctx.Value(serverKey{}).(*Server)
1966+
return s
1967+
}
1968+
1969+
// contextWithServer sets the Server in the context.
1970+
func contextWithServer(ctx context.Context, server *Server) context.Context {
1971+
return context.WithValue(ctx, serverKey{}, server)
1972+
}
1973+
1974+
// isRegisteredMethod returns whether the passed in method is registered as a
1975+
// method on the server. /service/method and service/method will match if the
1976+
// service and method are registered on the server.
1977+
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
1978+
if serviceMethod != "" && serviceMethod[0] == '/' {
1979+
serviceMethod = serviceMethod[1:]
1980+
}
1981+
pos := strings.LastIndex(serviceMethod, "/")
1982+
if pos == -1 { // Invalid method name syntax.
1983+
return false
1984+
}
1985+
service := serviceMethod[:pos]
1986+
method := serviceMethod[pos+1:]
1987+
srv, knownService := s.services[service]
1988+
if knownService {
1989+
if _, ok := srv.methods[method]; ok {
1990+
return true
1991+
}
1992+
if _, ok := srv.streams[method]; ok {
1993+
return true
1994+
}
1995+
}
1996+
return false
1997+
}
1998+
19561999
// SetHeader sets the header metadata to be sent from the server to the client.
19572000
// The context provided must be the context passed to the server's handler.
19582001
//

stats/stats_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ import (
3131
"github.com/golang/protobuf/proto"
3232
"google.golang.org/grpc"
3333
"google.golang.org/grpc/credentials/insecure"
34+
"google.golang.org/grpc/internal"
3435
"google.golang.org/grpc/internal/grpctest"
36+
"google.golang.org/grpc/internal/stubserver"
37+
"google.golang.org/grpc/internal/testutils"
3538
"google.golang.org/grpc/metadata"
3639
"google.golang.org/grpc/stats"
3740
"google.golang.org/grpc/status"
@@ -1457,3 +1460,61 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) {
14571460
t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
14581461
}
14591462
}
1463+
1464+
// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
1465+
// gets access to a Server on the server side, and thus the method that the
1466+
// server owns which specifies whether a method is made or not. The test sets up
1467+
// a server with a unary call and full duplex call configured, and makes an RPC.
1468+
// Within the stats handler, asking the server whether unary or duplex method
1469+
// names are registered should return true, and any other query should return
1470+
// false.
1471+
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
1472+
wg := sync.WaitGroup{}
1473+
wg.Add(1)
1474+
stubStatsHandler := &testutils.StubStatsHandler{
1475+
TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
1476+
// OpenTelemetry instrumentation needs the passed in Server to determine if
1477+
// methods are registered in different handle calls in to record metrics.
1478+
// This tag RPC call context gets passed into every handle call, so can
1479+
// assert once here, since it maps to all the handle RPC calls that come
1480+
// after. These internal calls will be how the OpenTelemetry instrumentation
1481+
// component accesses this server and the subsequent helper on the server.
1482+
server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
1483+
if server == nil {
1484+
t.Errorf("stats handler received ctx has no server present")
1485+
}
1486+
isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
1487+
// /s/m and s/m are valid.
1488+
if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
1489+
t.Errorf("UnaryCall should be a registered method according to server")
1490+
}
1491+
if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
1492+
t.Errorf("FullDuplexCall should be a registered method according to server")
1493+
}
1494+
if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
1495+
t.Errorf("DoesNotExistCall should not be a registered method according to server")
1496+
}
1497+
if isRegisteredMethod(server, "/unknownService/UnaryCall") {
1498+
t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
1499+
}
1500+
wg.Done()
1501+
return ctx
1502+
},
1503+
}
1504+
ss := &stubserver.StubServer{
1505+
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
1506+
return &testpb.SimpleResponse{}, nil
1507+
},
1508+
}
1509+
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
1510+
t.Fatalf("Error starting endpoint server: %v", err)
1511+
}
1512+
defer ss.Stop()
1513+
1514+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
1515+
defer cancel()
1516+
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
1517+
t.Fatalf("Unexpected error from UnaryCall: %v", err)
1518+
}
1519+
wg.Wait()
1520+
}

0 commit comments

Comments
 (0)