|
1 | 1 | package codec |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "encoding/json" |
4 | 5 | "fmt" |
5 | 6 | "reflect" |
| 7 | + "strings" |
6 | 8 |
|
7 | 9 | "github.com/cosmos/gogoproto/proto" |
8 | 10 | gogotypes "github.com/cosmos/gogoproto/types" |
9 | 11 | "google.golang.org/protobuf/encoding/protojson" |
10 | 12 | protov2 "google.golang.org/protobuf/proto" |
| 13 | + "google.golang.org/protobuf/reflect/protoreflect" |
| 14 | + "google.golang.org/protobuf/types/dynamicpb" |
| 15 | + "google.golang.org/protobuf/types/known/durationpb" |
| 16 | + "google.golang.org/protobuf/types/known/timestamppb" |
11 | 17 |
|
12 | 18 | "cosmossdk.io/collections" |
13 | 19 | collcodec "cosmossdk.io/collections/codec" |
| 20 | + "cosmossdk.io/schema" |
14 | 21 | ) |
15 | 22 |
|
16 | 23 | // BoolValue implements a ValueCodec that saves the bool value |
@@ -51,12 +58,17 @@ type protoMessage[T any] interface { |
51 | 58 | proto.Message |
52 | 59 | } |
53 | 60 |
|
| 61 | +type protoCollValueCodec[T any] interface { |
| 62 | + collcodec.HasSchemaCodec[T] |
| 63 | + collcodec.ValueCodec[T] |
| 64 | +} |
| 65 | + |
54 | 66 | // CollValue inits a collections.ValueCodec for a generic gogo protobuf message. |
55 | 67 | func CollValue[T any, PT protoMessage[T]](cdc interface { |
56 | 68 | Marshal(proto.Message) ([]byte, error) |
57 | 69 | Unmarshal([]byte, proto.Message) error |
58 | 70 | }, |
59 | | -) collcodec.ValueCodec[T] { |
| 71 | +) protoCollValueCodec[T] { |
60 | 72 | return &collValue[T, PT]{cdc.(Codec), proto.MessageName(PT(new(T)))} |
61 | 73 | } |
62 | 74 |
|
@@ -91,6 +103,139 @@ func (c collValue[T, PT]) ValueType() string { |
91 | 103 | return "github.com/cosmos/gogoproto/" + c.messageName |
92 | 104 | } |
93 | 105 |
|
| 106 | +func (c collValue[T, PT]) SchemaCodec() (collcodec.SchemaCodec[T], error) { |
| 107 | + var ( |
| 108 | + t T |
| 109 | + pt PT |
| 110 | + ) |
| 111 | + msgName := proto.MessageName(pt) |
| 112 | + desc, err := proto.HybridResolver.FindDescriptorByName(protoreflect.FullName(msgName)) |
| 113 | + if err != nil { |
| 114 | + return collcodec.SchemaCodec[T]{}, fmt.Errorf("could not find descriptor for %s: %w", msgName, err) |
| 115 | + } |
| 116 | + schemaFields := protoCols(desc.(protoreflect.MessageDescriptor)) |
| 117 | + |
| 118 | + kind := schema.KindForGoValue(t) |
| 119 | + if err := kind.Validate(); err == nil { |
| 120 | + return collcodec.SchemaCodec[T]{ |
| 121 | + Fields: []schema.Field{{ |
| 122 | + // we don't set any name so that this can be set to a good default by the caller |
| 123 | + Name: "", |
| 124 | + Kind: kind, |
| 125 | + }}, |
| 126 | + // these can be nil because T maps directly to a schema value for this kind |
| 127 | + ToSchemaType: nil, |
| 128 | + FromSchemaType: nil, |
| 129 | + }, nil |
| 130 | + } else { |
| 131 | + return collcodec.SchemaCodec[T]{ |
| 132 | + Fields: schemaFields, |
| 133 | + ToSchemaType: func(t T) (any, error) { |
| 134 | + values := []interface{}{} |
| 135 | + msgDesc, ok := desc.(protoreflect.MessageDescriptor) |
| 136 | + if !ok { |
| 137 | + return nil, fmt.Errorf("expected message descriptor, got %T", desc) |
| 138 | + } |
| 139 | + |
| 140 | + nm := dynamicpb.NewMessage(msgDesc) |
| 141 | + bz, err := c.cdc.Marshal(any(&t).(PT)) |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + |
| 146 | + err = c.cdc.Unmarshal(bz, nm) |
| 147 | + if err != nil { |
| 148 | + return nil, err |
| 149 | + } |
| 150 | + |
| 151 | + for _, field := range schemaFields { |
| 152 | + // Find the field descriptor by the Protobuf field name |
| 153 | + fieldDesc := msgDesc.Fields().ByName(protoreflect.Name(field.Name)) |
| 154 | + if fieldDesc == nil { |
| 155 | + return nil, fmt.Errorf("field %q not found in message %s", field.Name, desc.FullName()) |
| 156 | + } |
| 157 | + |
| 158 | + val := nm.ProtoReflect().Get(fieldDesc) |
| 159 | + |
| 160 | + // if the field is a map or list, we need to convert it to a slice of values |
| 161 | + if fieldDesc.IsList() { |
| 162 | + repeatedVals := []interface{}{} |
| 163 | + list := val.List() |
| 164 | + for i := 0; i < list.Len(); i++ { |
| 165 | + repeatedVals = append(repeatedVals, list.Get(i).Interface()) |
| 166 | + } |
| 167 | + values = append(values, repeatedVals) |
| 168 | + continue |
| 169 | + } |
| 170 | + |
| 171 | + switch fieldDesc.Kind() { |
| 172 | + case protoreflect.BoolKind: |
| 173 | + values = append(values, val.Bool()) |
| 174 | + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, |
| 175 | + protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: |
| 176 | + values = append(values, val.Int()) |
| 177 | + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, |
| 178 | + protoreflect.Fixed64Kind: |
| 179 | + values = append(values, val.Uint()) |
| 180 | + case protoreflect.FloatKind, protoreflect.DoubleKind: |
| 181 | + values = append(values, val.Float()) |
| 182 | + case protoreflect.StringKind: |
| 183 | + values = append(values, val.String()) |
| 184 | + case protoreflect.BytesKind: |
| 185 | + values = append(values, val.Bytes()) |
| 186 | + case protoreflect.EnumKind: |
| 187 | + // TODO: postgres uses the enum name, not the number |
| 188 | + values = append(values, string(fieldDesc.Enum().Values().ByNumber(val.Enum()).Name())) |
| 189 | + case protoreflect.MessageKind: |
| 190 | + msg := val.Interface().(*dynamicpb.Message) |
| 191 | + msgbz, err := c.cdc.Marshal(msg) |
| 192 | + if err != nil { |
| 193 | + return nil, err |
| 194 | + } |
| 195 | + |
| 196 | + if field.Kind == schema.TimeKind { |
| 197 | + // make it a time.Time |
| 198 | + ts := ×tamppb.Timestamp{} |
| 199 | + err = c.cdc.Unmarshal(msgbz, ts) |
| 200 | + if err != nil { |
| 201 | + return nil, fmt.Errorf("error unmarshalling timestamp: %w %x %s", err, msgbz, fieldDesc.FullName()) |
| 202 | + } |
| 203 | + values = append(values, ts.AsTime()) |
| 204 | + } else if field.Kind == schema.DurationKind { |
| 205 | + // make it a time.Duration |
| 206 | + dur := &durationpb.Duration{} |
| 207 | + err = c.cdc.Unmarshal(msgbz, dur) |
| 208 | + if err != nil { |
| 209 | + return nil, fmt.Errorf("error unmarshalling duration: %w", err) |
| 210 | + } |
| 211 | + values = append(values, dur.AsDuration()) |
| 212 | + } else { |
| 213 | + // if not a time or duration, just keep it as a JSON object |
| 214 | + // we might want to change this to include the entire object as separate fields |
| 215 | + bz, err := c.cdc.MarshalJSON(msg) |
| 216 | + if err != nil { |
| 217 | + return nil, fmt.Errorf("error marshaling message: %w", err) |
| 218 | + } |
| 219 | + |
| 220 | + values = append(values, json.RawMessage(bz)) |
| 221 | + } |
| 222 | + } |
| 223 | + |
| 224 | + } |
| 225 | + |
| 226 | + // if there's only one value, return it directly |
| 227 | + if len(values) == 1 { |
| 228 | + return values[0], nil |
| 229 | + } |
| 230 | + return values, nil |
| 231 | + }, |
| 232 | + FromSchemaType: func(a any) (T, error) { |
| 233 | + panic("not implemented") |
| 234 | + }, |
| 235 | + }, nil |
| 236 | + } |
| 237 | +} |
| 238 | + |
94 | 239 | type protoMessageV2[T any] interface { |
95 | 240 | *T |
96 | 241 | protov2.Message |
@@ -179,3 +324,101 @@ func (c collInterfaceValue[T]) ValueType() string { |
179 | 324 | var t T |
180 | 325 | return fmt.Sprintf("%T", t) |
181 | 326 | } |
| 327 | + |
| 328 | +// SchemaCodec returns a schema codec, which will always have a single JSON field |
| 329 | +// as there is no way to know in advance the necessary fields for an interface. |
| 330 | +func (c collInterfaceValue[T]) SchemaCodec() (collcodec.SchemaCodec[T], error) { |
| 331 | + var pt T |
| 332 | + |
| 333 | + kind := schema.KindForGoValue(pt) |
| 334 | + if err := kind.Validate(); err == nil { |
| 335 | + return collcodec.SchemaCodec[T]{ |
| 336 | + Fields: []schema.Field{{ |
| 337 | + // we don't set any name so that this can be set to a good default by the caller |
| 338 | + Name: "", |
| 339 | + Kind: kind, |
| 340 | + }}, |
| 341 | + // these can be nil because T maps directly to a schema value for this kind |
| 342 | + ToSchemaType: nil, |
| 343 | + FromSchemaType: nil, |
| 344 | + }, nil |
| 345 | + } else { |
| 346 | + return collcodec.SchemaCodec[T]{ |
| 347 | + Fields: []schema.Field{{ |
| 348 | + Name: "value", |
| 349 | + Kind: schema.JSONKind, |
| 350 | + }}, |
| 351 | + ToSchemaType: func(t T) (any, error) { |
| 352 | + bz, err := c.codec.MarshalInterfaceJSON(t) |
| 353 | + if err != nil { |
| 354 | + return nil, err |
| 355 | + } |
| 356 | + |
| 357 | + return json.RawMessage(bz), nil |
| 358 | + }, |
| 359 | + FromSchemaType: func(a any) (T, error) { |
| 360 | + panic("not implemented") |
| 361 | + }, |
| 362 | + }, nil |
| 363 | + } |
| 364 | +} |
| 365 | + |
| 366 | +func protoCols(desc protoreflect.MessageDescriptor) []schema.Field { |
| 367 | + nFields := desc.Fields() |
| 368 | + cols := make([]schema.Field, 0, nFields.Len()) |
| 369 | + for i := 0; i < nFields.Len(); i++ { |
| 370 | + f := nFields.Get(i) |
| 371 | + cols = append(cols, protoCol(f)) |
| 372 | + } |
| 373 | + return cols |
| 374 | +} |
| 375 | + |
| 376 | +func protoCol(f protoreflect.FieldDescriptor) schema.Field { |
| 377 | + col := schema.Field{Name: string(f.Name())} |
| 378 | + if f.IsMap() || f.IsList() { |
| 379 | + col.Kind = schema.JSONKind |
| 380 | + col.Nullable = true |
| 381 | + } else { |
| 382 | + switch f.Kind() { |
| 383 | + case protoreflect.BoolKind: |
| 384 | + col.Kind = schema.BoolKind |
| 385 | + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: |
| 386 | + col.Kind = schema.Int32Kind |
| 387 | + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: |
| 388 | + col.Kind = schema.Int64Kind |
| 389 | + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: |
| 390 | + col.Kind = schema.Int64Kind |
| 391 | + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: |
| 392 | + col.Kind = schema.Uint64Kind |
| 393 | + case protoreflect.FloatKind: |
| 394 | + col.Kind = schema.Float32Kind |
| 395 | + case protoreflect.DoubleKind: |
| 396 | + col.Kind = schema.Float64Kind |
| 397 | + case protoreflect.StringKind: |
| 398 | + col.Kind = schema.StringKind |
| 399 | + case protoreflect.BytesKind: |
| 400 | + col.Kind = schema.BytesKind |
| 401 | + case protoreflect.EnumKind: |
| 402 | + // TODO: support enums |
| 403 | + col.Kind = schema.EnumKind |
| 404 | + // use the full name to avoid collissions |
| 405 | + col.ReferencedType = string(f.Enum().FullName()) |
| 406 | + col.ReferencedType = strings.ReplaceAll(col.ReferencedType, ".", "_") |
| 407 | + case protoreflect.MessageKind: |
| 408 | + col.Nullable = true |
| 409 | + fullName := f.Message().FullName() |
| 410 | + if fullName == "google.protobuf.Timestamp" { |
| 411 | + col.Kind = schema.TimeKind |
| 412 | + } else if fullName == "google.protobuf.Duration" { |
| 413 | + col.Kind = schema.DurationKind |
| 414 | + } else { |
| 415 | + col.Kind = schema.JSONKind |
| 416 | + } |
| 417 | + } |
| 418 | + if f.HasPresence() { |
| 419 | + col.Nullable = true |
| 420 | + } |
| 421 | + } |
| 422 | + |
| 423 | + return col |
| 424 | +} |
0 commit comments