Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions grpcreflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,31 +314,39 @@ type ExtensionResolver interface {
}

type fileDescriptorNameSet struct {
names map[protoreflect.FullName]struct{}
names map[string]struct{}
}

func (s *fileDescriptorNameSet) Insert(fd protoreflect.FileDescriptor) {
func (s *fileDescriptorNameSet) Insert(path string) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: fileDescriptorNameSet use to take an fd. It looks like it changed to a path for testing. Maybe could keep the fd and the tests reference the fd too to make it clear that FullName wasn't unique. Otherwise this is a generic string set type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this could be a set of type pather interface { Path() string }? I do like the fact that the existing code keeps the set type a little higher-level - it'd be nice to preserve that. In tests, we could implement the pather interface with type path string. That also leaves a clear place to document that FullName isn't unique.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. I just changed the signatures back to protoreflect.FileDescriptor and then added a dummyFile implementation that the test can use for this.

if s.names == nil {
s.names = make(map[protoreflect.FullName]struct{}, 1)
s.names = make(map[string]struct{}, 1)
}
s.names[fd.FullName()] = struct{}{}
s.names[path] = struct{}{}
}

func (s *fileDescriptorNameSet) Contains(fd protoreflect.FileDescriptor) bool {
_, ok := s.names[fd.FullName()]
func (s *fileDescriptorNameSet) Contains(path string) bool {
_, ok := s.names[path]
return ok
}

func fileDescriptorWithDependencies(fd protoreflect.FileDescriptor, sent *fileDescriptorNameSet) ([][]byte, error) {
func fileDescriptorWithDependencies(rootFile protoreflect.FileDescriptor, sent *fileDescriptorNameSet) ([][]byte, error) {
if rootFile.IsPlaceholder() {
// A placeholder is used when a dependency is missing. If a placeholder is all we have
// then we don't actually have anything.
return nil, protoregistry.NotFound
}
results := make([][]byte, 0, 1)
queue := []protoreflect.FileDescriptor{fd}
queue := []protoreflect.FileDescriptor{rootFile}
for len(queue) > 0 {
curr := queue[0]
queue = queue[1:]
if len(results) == 0 || !sent.Contains(curr) { // always send root fd
if curr.IsPlaceholder() {
continue // don't bother serializing placeholders
}
if len(results) == 0 || !sent.Contains(curr.Path()) { // always send root fd
// Mark as sent immediately. If we hit an error marshaling below, there's
// no point trying again later.
sent.Insert(curr)
sent.Insert(curr.Path())
encoded, err := proto.Marshal(protodesc.ToFileDescriptorProto(curr))
if err != nil {
return nil, err
Expand Down
203 changes: 198 additions & 5 deletions grpcreflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@ import (
"context"
"net/http"
"net/http/httptest"
"reflect"
"sort"
"testing"

"connectrpc.com/connect"
_ "connectrpc.com/grpcreflect/internal/gen/go/connect/reflecttest/v1"
reflectionv1 "connectrpc.com/grpcreflect/internal/gen/go/connectext/grpc/reflection/v1"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/descriptorpb"
)

const actualServiceName = "connectext.grpc.reflection.v1.ServerReflection"
Expand Down Expand Up @@ -87,6 +93,7 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
assertFileDescriptorResponseContains := func(
tb testing.TB,
req *reflectionv1.ServerReflectionRequest,
numFiles int,
substring string,
) {
tb.Helper()
Expand All @@ -102,8 +109,8 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
tb.Fatal("got nil FileDescriptorResponse")
return // convinces staticcheck that remaining code is unreachable
}
if len(fds.FileDescriptorProto) != 1 {
tb.Fatalf("got %d FileDescriptorProtos, expected 1", len(fds.FileDescriptorProto))
if len(fds.FileDescriptorProto) != numFiles {
tb.Fatalf("got %d FileDescriptorProtos, expected %d", len(fds.FileDescriptorProto), numFiles)
}
if !bytes.Contains(fds.FileDescriptorProto[0], []byte(substring)) {
tb.Fatalf(
Expand Down Expand Up @@ -171,7 +178,7 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
FileByFilename: "connectext/grpc/reflection/v1/reflection.proto",
},
}
assertFileDescriptorResponseContains(t, req, reflectionRequestFQN)
assertFileDescriptorResponseContains(t, req, 1, reflectionRequestFQN)
})
t.Run("file_by_filename_missing", func(t *testing.T) {
t.Parallel()
Expand All @@ -191,7 +198,7 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
FileContainingSymbol: reflectionRequestFQN,
},
}
assertFileDescriptorResponseContains(t, req, "reflection.proto")
assertFileDescriptorResponseContains(t, req, 1, "reflection.proto")
})
t.Run("file_containing_symbol_missing", func(t *testing.T) {
t.Parallel()
Expand All @@ -214,7 +221,8 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
},
},
}
assertFileDescriptorResponseContains(t, req, "reflecttest_ext.proto")
// We expect two files here: both reflecttest_ext.proto and its dependency, reflecttest.proto
assertFileDescriptorResponseContains(t, req, 2, "reflecttest_ext.proto")
})
t.Run("file_containing_extension_missing", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -293,3 +301,188 @@ func testReflector(t *testing.T, reflector *Reflector, servicePath string) {
assertFileDescriptorResponseNotFound(t, req)
})
}

func TestFileDescriptorWithDependencies(t *testing.T) {
t.Parallel()

depFile, err := protodesc.NewFile(
&descriptorpb.FileDescriptorProto{
Name: proto.String("dep.proto"),
}, nil,
)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

deps := &protoregistry.Files{}
if err := deps.RegisterFile(depFile); err != nil {
t.Fatalf("unexpected error: %s", err)
}

rootFileProto := &descriptorpb.FileDescriptorProto{
Name: proto.String("root.proto"),
Dependency: []string{
"google/protobuf/descriptor.proto",
"connect/reflecttest/v1/reflecttest_ext.proto",
"dep.proto",
},
}

// dep.proto is in deps; the other imports come from protoregistry.GlobalFiles
resolver := &combinedResolver{first: protoregistry.GlobalFiles, second: deps}
rootFile, err := protodesc.NewFile(rootFileProto, resolver)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

// Create a file hierarchy that contains a placeholder for dep.proto
placeholderDep := placeholderFile{depFile}
placeholderDeps := &protoregistry.Files{}
if err := placeholderDeps.RegisterFile(placeholderDep); err != nil {
t.Fatalf("unexpected error: %s", err)
}
resolver = &combinedResolver{first: protoregistry.GlobalFiles, second: placeholderDeps}

rootFileHasPlaceholderDep, err := protodesc.NewFile(rootFileProto, resolver)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

rootFileIsPlaceholder := placeholderFile{rootFile}

// Full transitive dependency graph of root.proto includes five files:
// - root.proto
// - google/protobuf/descriptor.proto
// - connect/reflecttest/v1/reflecttest_ext.proto
// - connect/reflecttest/v1/reflecttest.proto
// - dep.proto

testCases := []struct {
name string
sent []string
root protoreflect.FileDescriptor
expect []string
}{
{
name: "send_all",
root: rootFile,
// expect full transitive closure
expect: []string{
"root.proto",
"google/protobuf/descriptor.proto",
"connect/reflecttest/v1/reflecttest_ext.proto",
"connect/reflecttest/v1/reflecttest.proto",
"dep.proto",
},
},
{
name: "already_sent",
sent: []string{
"root.proto",
"google/protobuf/descriptor.proto",
"connect/reflecttest/v1/reflecttest_ext.proto",
"connect/reflecttest/v1/reflecttest.proto",
"dep.proto",
},
root: rootFile,
// expect only the root to be re-sent
expect: []string{"root.proto"},
},
{
name: "some_already_sent",
sent: []string{
"connect/reflecttest/v1/reflecttest_ext.proto",
"connect/reflecttest/v1/reflecttest.proto",
},
root: rootFile,
expect: []string{
"root.proto",
"google/protobuf/descriptor.proto",
"dep.proto",
},
},
{
name: "root_is_placeholder",
root: rootFileIsPlaceholder,
// expect error, no files
},
{
name: "placeholder_skipped",
root: rootFileHasPlaceholderDep,
// dep.proto is a placeholder so is skipped
expect: []string{
"root.proto",
"google/protobuf/descriptor.proto",
"connect/reflecttest/v1/reflecttest_ext.proto",
"connect/reflecttest/v1/reflecttest.proto",
},
},
{
name: "placeholder_skipped_and_some_sent",
sent: []string{
"connect/reflecttest/v1/reflecttest_ext.proto",
"connect/reflecttest/v1/reflecttest.proto",
},
root: rootFileHasPlaceholderDep,
expect: []string{
"root.proto",
"google/protobuf/descriptor.proto",
},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

sent := &fileDescriptorNameSet{}
for _, path := range testCase.sent {
sent.Insert(path)
}

descriptors, err := fileDescriptorWithDependencies(testCase.root, sent)
if len(testCase.expect) == 0 {
// if we're not expecting any files then we're expecting an error
if err == nil {
t.Fatalf("expecting an error; instead got %d files", len(descriptors))
}
return
}

checkDescriptorResults(t, descriptors, testCase.expect)
})
}
}

func checkDescriptorResults(t *testing.T, descriptors [][]byte, expect []string) {
t.Helper()
if len(descriptors) != len(expect) {
t.Errorf("expected result to contain %d descriptor(s); instead got %d", len(expect), len(descriptors))
}
names := map[string]struct{}{}
for i, desc := range descriptors {
var descProto descriptorpb.FileDescriptorProto
if err := proto.Unmarshal(desc, &descProto); err != nil {
t.Fatalf("could not unmarshal descriptor result #%d", i+1)
}
names[descProto.GetName()] = struct{}{}
}
actual := make([]string, 0, len(names))
for name := range names {
actual = append(actual, name)
}
sort.Strings(actual)
sort.Strings(expect)
if !reflect.DeepEqual(actual, expect) {
t.Fatalf("expected file descriptors for %v; instead got %v", expect, actual)
}
}

type placeholderFile struct {
protoreflect.FileDescriptor
}

func (placeholderFile) IsPlaceholder() bool {
return true
}