diff --git a/server/util/proto/proto.go b/server/util/proto/proto.go index b100c084be5..e6f899a44c7 100644 --- a/server/util/proto/proto.go +++ b/server/util/proto/proto.go @@ -20,14 +20,18 @@ var Bool = gproto.Bool type Message = gproto.Message type MarshalOptions = gproto.MarshalOptions -type vtprotoMessage interface { +type VTProtoMessage interface { MarshalVT() ([]byte, error) UnmarshalVT([]byte) error CloneMessageVT() Message + + // For vtprotoCodecV2 + MarshalToSizedBufferVT(data []byte) (int, error) + SizeVT() int } func Marshal(v Message) ([]byte, error) { - vt, ok := v.(vtprotoMessage) + vt, ok := v.(VTProtoMessage) if ok { return vt.MarshalVT() } @@ -35,7 +39,7 @@ func Marshal(v Message) ([]byte, error) { } func Unmarshal(b []byte, v Message) error { - vt, ok := v.(vtprotoMessage) + vt, ok := v.(VTProtoMessage) if ok { return vt.UnmarshalVT(b) } @@ -44,7 +48,7 @@ func Unmarshal(b []byte, v Message) error { } func Clone(v Message) Message { - vt, ok := v.(vtprotoMessage) + vt, ok := v.(VTProtoMessage) if ok { return vt.CloneMessageVT() } diff --git a/server/util/vtprotocodec/BUILD b/server/util/vtprotocodec/BUILD index 591d91fae08..c0d83cfdff2 100644 --- a/server/util/vtprotocodec/BUILD +++ b/server/util/vtprotocodec/BUILD @@ -9,5 +9,6 @@ go_library( "//server/util/proto", "@org_golang_google_grpc//encoding", "@org_golang_google_grpc//encoding/proto", + "@org_golang_google_grpc//mem", ], ) diff --git a/server/util/vtprotocodec/vtprotocodec.go b/server/util/vtprotocodec/vtprotocodec.go index 3c7646be260..7569af333d2 100644 --- a/server/util/vtprotocodec/vtprotocodec.go +++ b/server/util/vtprotocodec/vtprotocodec.go @@ -1,42 +1,68 @@ package vtprotocodec import ( - "fmt" - "github.com/buildbuddy-io/buildbuddy/server/util/proto" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" + _ "google.golang.org/grpc/encoding/proto" // for default proto registration purposes ) const Name = "proto" -// vtprotoCodec represents a codec able to encode and decode vt enabled -// proto messages. -type vtprotoCodec struct{} +// CodecV2 implements encoding.CodecV2 and uses vtproto and default buffer pool +// to encode/decode proto messages. The implementation is heavily inspired by +// https://github.com/planetscale/vtprotobuf/pull/138 +// and https://github.com/vitessio/vitess/pull/16790. +type CodecV2 struct { + fallback encoding.CodecV2 +} -func (vtprotoCodec) Marshal(v any) ([]byte, error) { - vv, ok := v.(proto.Message) - if !ok { - return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) - } - return proto.Marshal(vv) +func (CodecV2) Name() string { + return Name } -func (vtprotoCodec) Unmarshal(data []byte, v any) error { - vv, ok := v.(proto.Message) +func (c *CodecV2) Marshal(v any) (mem.BufferSlice, error) { + m, ok := v.(proto.VTProtoMessage) if !ok { - return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + return c.fallback.Marshal(v) } - return proto.Unmarshal(data, vv) + size := m.SizeVT() + if mem.IsBelowBufferPoolingThreshold(size) { + buf := make([]byte, size) + n, err := m.MarshalToSizedBufferVT(buf) + if err != nil { + return nil, err + } + return mem.BufferSlice{mem.SliceBuffer(buf[:n])}, nil + } + pool := mem.DefaultBufferPool() + buf := pool.Get(size) + n, err := m.MarshalToSizedBufferVT(*buf) + if err != nil { + pool.Put(buf) + return nil, err + } + *buf = (*buf)[:n] + return mem.BufferSlice{mem.NewBuffer(buf, pool)}, nil } -func (vtprotoCodec) Name() string { - return Name +func (c *CodecV2) Unmarshal(data mem.BufferSlice, v any) error { + m, ok := v.(proto.VTProtoMessage) + if !ok { + return c.fallback.Unmarshal(data, v) + } + buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) + defer buf.Free() + return m.UnmarshalVT(buf.ReadOnlyData()) } // RegisterCodec registers the vtprotoCodec to encode/decode proto messages with // all gRPC clients and servers. func Register() { - encoding.RegisterCodec(vtprotoCodec{}) + encoding.RegisterCodecV2(&CodecV2{ + // the default codecv2 implemented in @org_golang_google_grpc//encoding/proto. + fallback: encoding.GetCodecV2("proto"), + }) }