Skip to content

Commit 9b04d30

Browse files
authored
fix(scale): Use *int for scale index (#3274)
1 parent bd68814 commit 9b04d30

5 files changed

Lines changed: 69 additions & 36 deletions

File tree

pkg/scale/decode_test.go

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,18 @@ func Test_decodeState_decodeStruct(t *testing.T) {
9292
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
9393
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
9494
}
95-
var diff string
96-
if tt.out != nil {
97-
diff = cmp.Diff(dst, tt.out, cmpopts.IgnoreUnexported(tt.in))
98-
} else {
99-
diff = cmp.Diff(dst, tt.in, cmpopts.IgnoreUnexported(big.Int{}, tt.in, VDTValue2{}, MyStructWithIgnore{}))
100-
}
101-
if diff != "" {
102-
t.Errorf("decodeState.unmarshal() = %s", diff)
95+
96+
// assert response only if we aren't expecting an error
97+
if !tt.wantErr {
98+
var diff string
99+
if tt.out != nil {
100+
diff = cmp.Diff(dst, tt.out, cmpopts.IgnoreUnexported(tt.in))
101+
} else {
102+
diff = cmp.Diff(dst, tt.in, cmpopts.IgnoreUnexported(big.Int{}, tt.in, VDTValue2{}, MyStructWithIgnore{}))
103+
}
104+
if diff != "" {
105+
t.Errorf("decodeState.unmarshal() = %s", diff)
106+
}
103107
}
104108
})
105109
}
@@ -294,20 +298,24 @@ func Test_unmarshal_optionality(t *testing.T) {
294298
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
295299
return
296300
}
297-
var diff string
298-
if tt.out != nil {
299-
diff = cmp.Diff(
300-
reflect.ValueOf(dst).Elem().Interface(),
301-
reflect.ValueOf(tt.out).Interface(),
302-
cmpopts.IgnoreUnexported(tt.in))
303-
} else {
304-
diff = cmp.Diff(
305-
reflect.ValueOf(dst).Elem().Interface(),
306-
reflect.ValueOf(tt.in).Interface(),
307-
cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{}))
308-
}
309-
if diff != "" {
310-
t.Errorf("decodeState.unmarshal() = %s", diff)
301+
302+
// assert response only if we aren't expecting an error
303+
if !tt.wantErr {
304+
var diff string
305+
if tt.out != nil {
306+
diff = cmp.Diff(
307+
reflect.ValueOf(dst).Elem().Interface(),
308+
reflect.ValueOf(tt.out).Interface(),
309+
cmpopts.IgnoreUnexported(tt.in))
310+
} else {
311+
diff = cmp.Diff(
312+
reflect.ValueOf(dst).Elem().Interface(),
313+
reflect.ValueOf(tt.in).Interface(),
314+
cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}, MyStructWithPrivate{}))
315+
}
316+
if diff != "" {
317+
t.Errorf("decodeState.unmarshal() = %s", diff)
318+
}
311319
}
312320
}
313321
})
@@ -325,7 +333,11 @@ func Test_unmarshal_optionality_nil_case(t *testing.T) {
325333
// ignore out, since we are testing nil case
326334
// out: t.out,
327335
}
328-
ptrTest.want = []byte{0x00}
336+
337+
// for error cases, we don't need to modify the input since we need it to fail
338+
if !t.wantErr {
339+
ptrTest.want = []byte{0x00}
340+
}
329341

330342
temp := reflect.New(reflect.TypeOf(t.in))
331343
// create a new pointer to type of temp

pkg/scale/encode_test.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -597,10 +597,19 @@ var (
597597
},
598598
want: []byte{0x04, 0x01, 0x02, 0, 0, 0, 0x01},
599599
},
600+
{
601+
name: "struct_{[]byte,_int32}_with_invalid_tag",
602+
in: &struct {
603+
Foo []byte `scale:"1,invalid"`
604+
}{
605+
Foo: []byte{0x01},
606+
},
607+
wantErr: true,
608+
},
600609
{
601610
name: "struct_{[]byte,_int32,_bool}",
602611
in: struct {
603-
Baz bool `scale:"3,enum"`
612+
Baz bool `scale:"3"`
604613
Bar int32 `scale:"2"`
605614
Foo []byte `scale:"1"`
606615
}{
@@ -1073,8 +1082,12 @@ func Test_encodeState_encodeStruct(t *testing.T) {
10731082
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
10741083
t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr)
10751084
}
1076-
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1077-
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
1085+
1086+
// we don't need this check for error cases
1087+
if !tt.wantErr {
1088+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1089+
t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want)
1090+
}
10781091
}
10791092
})
10801093
}
@@ -1182,8 +1195,12 @@ func Test_marshal_optionality(t *testing.T) {
11821195
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
11831196
t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr)
11841197
}
1185-
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1186-
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
1198+
1199+
// if we expect an error, we do not need to check the result
1200+
if !tt.wantErr {
1201+
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
1202+
t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want)
1203+
}
11871204
}
11881205
})
11891206
}
@@ -1195,9 +1212,6 @@ func Test_marshal_optionality_nil_cases(t *testing.T) {
11951212
t := allTests[i]
11961213
ptrTest := test{
11971214
name: t.name,
1198-
// in: t.in,
1199-
wantErr: t.wantErr,
1200-
want: t.want,
12011215
}
12021216
// create a new pointer to new zero value of t.in
12031217
temp := reflect.New(reflect.TypeOf(t.in))

pkg/scale/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ var (
2020
errBigIntIsNil = errors.New("big int is nil")
2121
ErrVaryingDataTypeNotSet = errors.New("varying data type not set")
2222
ErrUnsupportedCustomPrimitive = errors.New("unsupported type for custom primitive")
23+
ErrInvalidScaleIndex = errors.New("invalid scale index")
2324
)

pkg/scale/scale.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"reflect"
99
"sort"
10+
"strconv"
1011
"strings"
1112
"sync"
1213
)
@@ -19,7 +20,7 @@ var cache = &fieldScaleIndicesCache{
1920
// fieldScaleIndex is used to map field index to scale index
2021
type fieldScaleIndex struct {
2122
fieldIndex int
22-
scaleIndex *string
23+
scaleIndex *int
2324
}
2425
type fieldScaleIndices []fieldScaleIndex
2526

@@ -61,9 +62,14 @@ func (fsic *fieldScaleIndicesCache) fieldScaleIndices(in interface{}) (
6162
// ignore this field
6263
continue
6364
default:
65+
scaleIndex, indexErr := strconv.Atoi(tag)
66+
if indexErr != nil {
67+
err = fmt.Errorf("%w: %v", ErrInvalidScaleIndex, indexErr)
68+
return
69+
}
6470
indices = append(indices, fieldScaleIndex{
6571
fieldIndex: i,
66-
scaleIndex: &tag,
72+
scaleIndex: &scaleIndex,
6773
})
6874
}
6975
}

pkg/scale/scale_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ func Test_fieldScaleIndicesCache_fieldScaleIndices(t *testing.T) {
3535
wantIndices: fieldScaleIndices{
3636
{
3737
fieldIndex: 5,
38-
scaleIndex: newStringPtr("1"),
38+
scaleIndex: newIntPtr(1),
3939
},
4040
{
4141
fieldIndex: 3,
42-
scaleIndex: newStringPtr("2"),
42+
scaleIndex: newIntPtr(2),
4343
},
4444
{
4545
fieldIndex: 1,
46-
scaleIndex: newStringPtr("3"),
46+
scaleIndex: newIntPtr(3),
4747
},
4848
{
4949
fieldIndex: 0,

0 commit comments

Comments
 (0)