|
1 | 1 | package rpc |
2 | 2 |
|
3 | 3 | import ( |
4 | | - gogoproto "github.com/gogo/protobuf/proto" |
5 | | - "github.com/golang/protobuf/proto" //nolint:staticcheck |
6 | 4 | "google.golang.org/grpc/encoding" |
| 5 | + "google.golang.org/grpc/mem" |
7 | 6 | ) |
8 | 7 |
|
9 | 8 | const name = "proto" |
10 | 9 |
|
11 | | -type codec struct{} |
| 10 | +type gogoprotoMessage interface { |
| 11 | + MarshalToSizedBuffer([]byte) (int, error) |
| 12 | + Unmarshal([]byte) error |
| 13 | + ProtoSize() int |
| 14 | +} |
| 15 | + |
| 16 | +var pool = mem.DefaultBufferPool() |
| 17 | + |
| 18 | +type codec struct { |
| 19 | + fallback encoding.CodecV2 |
| 20 | +} |
12 | 21 |
|
13 | | -var _ encoding.Codec = codec{} |
| 22 | +var _ encoding.CodecV2 = &codec{} |
14 | 23 |
|
15 | 24 | func init() { |
16 | | - encoding.RegisterCodec(codec{}) |
| 25 | + encoding.RegisterCodecV2(&codec{ |
| 26 | + fallback: encoding.GetCodecV2(name), |
| 27 | + }) |
17 | 28 | } |
18 | 29 |
|
19 | | -func (codec) Marshal(v interface{}) ([]byte, error) { |
20 | | - if m, ok := v.(gogoproto.Marshaler); ok { |
21 | | - return m.Marshal() |
| 30 | +func (c *codec) Marshal(v any) (mem.BufferSlice, error) { |
| 31 | + if m, ok := v.(gogoprotoMessage); ok { |
| 32 | + size := m.ProtoSize() |
| 33 | + if mem.IsBelowBufferPoolingThreshold(size) { |
| 34 | + buf := make([]byte, size) |
| 35 | + if _, err := m.MarshalToSizedBuffer(buf[:size]); err != nil { |
| 36 | + return nil, err |
| 37 | + } |
| 38 | + return mem.BufferSlice{mem.SliceBuffer(buf)}, nil |
| 39 | + } |
| 40 | + |
| 41 | + buf := pool.Get(size) |
| 42 | + if _, err := m.MarshalToSizedBuffer((*buf)[:size]); err != nil { |
| 43 | + pool.Put(buf) |
| 44 | + return nil, err |
| 45 | + } |
| 46 | + return mem.BufferSlice{mem.NewBuffer(buf, pool)}, nil |
22 | 47 | } |
23 | | - return proto.Marshal(v.(proto.Message)) |
| 48 | + return c.fallback.Marshal(v) |
24 | 49 | } |
25 | 50 |
|
26 | | -func (codec) Unmarshal(data []byte, v interface{}) error { |
27 | | - if m, ok := v.(gogoproto.Unmarshaler); ok { |
28 | | - return m.Unmarshal(data) |
| 51 | +func (c *codec) Unmarshal(data mem.BufferSlice, v any) error { |
| 52 | + if m, ok := v.(gogoprotoMessage); ok { |
| 53 | + buf := data.MaterializeToBuffer(pool) |
| 54 | + defer buf.Free() |
| 55 | + return m.Unmarshal(buf.ReadOnlyData()) |
29 | 56 | } |
30 | | - return proto.Unmarshal(data, v.(proto.Message)) |
| 57 | + return c.fallback.Unmarshal(data, v) |
31 | 58 | } |
32 | 59 |
|
33 | | -func (codec) Name() string { |
| 60 | +func (*codec) Name() string { |
34 | 61 | return name |
35 | 62 | } |
0 commit comments