Skip to content

Commit 0361ecf

Browse files
committed
feat: adapt llm_token_ratelimit component to datasource module
1 parent 1094f0d commit 0361ecf

File tree

5 files changed

+525
-45
lines changed

5 files changed

+525
-45
lines changed

core/llm_token_ratelimit/config.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package llm_token_ratelimit
1616

1717
import (
18+
"encoding/json"
1819
"errors"
1920
"fmt"
2021
"strings"
@@ -271,12 +272,12 @@ func (c *SafeConfig) GetErrorMsg() string {
271272
return c.config.ErrorMessage
272273
}
273274

274-
func (p *TokenEncoderProvider) UnmarshalYAML(unmarshal func(interface{}) error) error {
275+
func (p *TokenEncoderProvider) UnmarshalJSON(data []byte) error {
275276
if p == nil {
276277
return fmt.Errorf("token encoder provider is nil")
277278
}
278279
var str string
279-
if err := unmarshal(&str); err != nil {
280+
if err := json.Unmarshal(data, &str); err != nil {
280281
return err
281282
}
282283
switch str {
@@ -288,12 +289,12 @@ func (p *TokenEncoderProvider) UnmarshalYAML(unmarshal func(interface{}) error)
288289
return nil
289290
}
290291

291-
func (it *IdentifierType) UnmarshalYAML(unmarshal func(interface{}) error) error {
292+
func (it *IdentifierType) UnmarshalJSON(data []byte) error {
292293
if it == nil {
293294
return fmt.Errorf("identifier type is nil")
294295
}
295296
var str string
296-
if err := unmarshal(&str); err != nil {
297+
if err := json.Unmarshal(data, &str); err != nil {
297298
return err
298299
}
299300
switch str {
@@ -307,12 +308,12 @@ func (it *IdentifierType) UnmarshalYAML(unmarshal func(interface{}) error) error
307308
return nil
308309
}
309310

310-
func (ct *CountStrategy) UnmarshalYAML(unmarshal func(interface{}) error) error {
311+
func (ct *CountStrategy) UnmarshalJSON(data []byte) error {
311312
if ct == nil {
312313
return fmt.Errorf("count strategy is nil")
313314
}
314315
var str string
315-
if err := unmarshal(&str); err != nil {
316+
if err := json.Unmarshal(data, &str); err != nil {
316317
return err
317318
}
318319
switch str {
@@ -328,12 +329,12 @@ func (ct *CountStrategy) UnmarshalYAML(unmarshal func(interface{}) error) error
328329
return nil
329330
}
330331

331-
func (tu *TimeUnit) UnmarshalYAML(unmarshal func(interface{}) error) error {
332+
func (tu *TimeUnit) UnmarshalJSON(data []byte) error {
332333
if tu == nil {
333334
return fmt.Errorf("time unit is nil")
334335
}
335336
var str string
336-
if err := unmarshal(&str); err != nil {
337+
if err := json.Unmarshal(data, &str); err != nil {
337338
return err
338339
}
339340
switch str {
@@ -351,12 +352,12 @@ func (tu *TimeUnit) UnmarshalYAML(unmarshal func(interface{}) error) error {
351352
return nil
352353
}
353354

354-
func (s *Strategy) UnmarshalYAML(unmarshal func(interface{}) error) error {
355+
func (s *Strategy) UnmarshalJSON(data []byte) error {
355356
if s == nil {
356357
return fmt.Errorf("strategy is nil")
357358
}
358359
var str string
359-
if err := unmarshal(&str); err != nil {
360+
if err := json.Unmarshal(data, &str); err != nil {
360361
return err
361362
}
362363
switch str {

core/llm_token_ratelimit/config_test.go

Lines changed: 85 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,72 @@
1515
package llm_token_ratelimit
1616

1717
import (
18+
"encoding/json"
1819
"testing"
1920

2021
"gopkg.in/yaml.v3"
2122
)
2223

23-
func TestIdentifierType_UnmarshalYAML(t *testing.T) {
24+
func TestTokenEncoderProvider_UnmarshalJSON(t *testing.T) {
2425
tests := []struct {
2526
name string
26-
yamlData string
27+
jsonData string
28+
expected TokenEncoderProvider
29+
wantErr bool
30+
}{
31+
{"openai provider", `{"provider": "openai"}`, OpenAIEncoderProvider, false},
32+
{"unknown provider", `{"provider": "unknown"}`, OpenAIEncoderProvider, true},
33+
{"empty provider", `{"provider": ""}`, OpenAIEncoderProvider, true},
34+
{"number as provider", `{"provider": 0}`, OpenAIEncoderProvider, true},
35+
{"invalid number", `{"provider": 999}`, OpenAIEncoderProvider, true},
36+
}
37+
38+
for _, tt := range tests {
39+
t.Run(tt.name, func(t *testing.T) {
40+
var data struct {
41+
Provider TokenEncoderProvider `json:"provider"`
42+
}
43+
44+
err := json.Unmarshal([]byte(tt.jsonData), &data)
45+
if tt.wantErr {
46+
if err == nil {
47+
t.Errorf("Expected error but got none")
48+
}
49+
} else {
50+
if err != nil {
51+
t.Errorf("Unexpected error: %v", err)
52+
}
53+
if data.Provider != tt.expected {
54+
t.Errorf("Expected %v, got %v", tt.expected, data.Provider)
55+
}
56+
}
57+
})
58+
}
59+
}
60+
61+
func TestIdentifierType_UnmarshalJSON(t *testing.T) {
62+
tests := []struct {
63+
name string
64+
jsonData string
2765
expected IdentifierType
2866
wantErr bool
2967
}{
30-
{"all identifier", `type: all`, AllIdentifier, false},
31-
{"header identifier", `type: header`, Header, false},
32-
{"unknown identifier", `type: unknown`, AllIdentifier, true},
33-
{"empty identifier", `type: ""`, AllIdentifier, true},
34-
{"number as identifier", `type: 123`, AllIdentifier, true},
68+
{"all identifier", `{"type": "all"}`, AllIdentifier, false},
69+
{"header identifier", `{"type": "header"}`, Header, false},
70+
{"unknown identifier", `{"type": "unknown"}`, AllIdentifier, true},
71+
{"empty identifier", `{"type": ""}`, AllIdentifier, true},
72+
{"number as identifier - 0", `{"type": 0}`, AllIdentifier, true},
73+
{"number as identifier - 1", `{"type": 1}`, Header, true},
74+
{"invalid number", `{"type": 999}`, AllIdentifier, true},
3575
}
3676

3777
for _, tt := range tests {
3878
t.Run(tt.name, func(t *testing.T) {
3979
var data struct {
40-
Type IdentifierType `yaml:"type"`
80+
Type IdentifierType `json:"type"`
4181
}
4282

43-
err := yaml.Unmarshal([]byte(tt.yamlData), &data)
83+
err := json.Unmarshal([]byte(tt.jsonData), &data)
4484
if tt.wantErr {
4585
if err == nil {
4686
t.Errorf("Expected error but got none")
@@ -57,27 +97,30 @@ func TestIdentifierType_UnmarshalYAML(t *testing.T) {
5797
}
5898
}
5999

60-
func TestCountStrategy_UnmarshalYAML(t *testing.T) {
100+
func TestCountStrategy_UnmarshalJSON(t *testing.T) {
61101
tests := []struct {
62102
name string
63-
yamlData string
103+
jsonData string
64104
expected CountStrategy
65105
wantErr bool
66106
}{
67-
{"total tokens", `strategy: total-tokens`, TotalTokens, false},
68-
{"input tokens", `strategy: input-tokens`, InputTokens, false},
69-
{"output tokens", `strategy: output-tokens`, OutputTokens, false},
70-
{"unknown strategy", `strategy: unknown-tokens`, TotalTokens, true},
71-
{"empty strategy", `strategy: ""`, TotalTokens, true},
107+
{"total tokens", `{"strategy": "total-tokens"}`, TotalTokens, false},
108+
{"input tokens", `{"strategy": "input-tokens"}`, InputTokens, false},
109+
{"output tokens", `{"strategy": "output-tokens"}`, OutputTokens, false},
110+
{"unknown strategy", `{"strategy": "unknown-tokens"}`, TotalTokens, true},
111+
{"empty strategy", `{"strategy": ""}`, TotalTokens, true},
112+
{"number as strategy - 0", `{"strategy": 0}`, TotalTokens, true},
113+
{"number as strategy - 1", `{"strategy": 1}`, InputTokens, true},
114+
{"number as strategy - 2", `{"strategy": 2}`, OutputTokens, true},
72115
}
73116

74117
for _, tt := range tests {
75118
t.Run(tt.name, func(t *testing.T) {
76119
var data struct {
77-
Strategy CountStrategy `yaml:"strategy"`
120+
Strategy CountStrategy `json:"strategy"`
78121
}
79122

80-
err := yaml.Unmarshal([]byte(tt.yamlData), &data)
123+
err := json.Unmarshal([]byte(tt.jsonData), &data)
81124
if tt.wantErr {
82125
if err == nil {
83126
t.Errorf("Expected error but got none")
@@ -94,28 +137,32 @@ func TestCountStrategy_UnmarshalYAML(t *testing.T) {
94137
}
95138
}
96139

97-
func TestTimeUnit_UnmarshalYAML(t *testing.T) {
140+
func TestTimeUnit_UnmarshalJSON(t *testing.T) {
98141
tests := []struct {
99142
name string
100-
yamlData string
143+
jsonData string
101144
expected TimeUnit
102145
wantErr bool
103146
}{
104-
{"second unit", `unit: second`, Second, false},
105-
{"minute unit", `unit: minute`, Minute, false},
106-
{"hour unit", `unit: hour`, Hour, false},
107-
{"day unit", `unit: day`, Day, false},
108-
{"unknown unit", `unit: week`, Second, true},
109-
{"empty unit", `unit: ""`, Second, true},
147+
{"second unit", `{"unit": "second"}`, Second, false},
148+
{"minute unit", `{"unit": "minute"}`, Minute, false},
149+
{"hour unit", `{"unit": "hour"}`, Hour, false},
150+
{"day unit", `{"unit": "day"}`, Day, false},
151+
{"unknown unit", `{"unit": "week"}`, Second, true},
152+
{"empty unit", `{"unit": ""}`, Second, true},
153+
{"number as unit - 0", `{"unit": 0}`, Second, true},
154+
{"number as unit - 1", `{"unit": 1}`, Minute, true},
155+
{"number as unit - 2", `{"unit": 2}`, Hour, true},
156+
{"number as unit - 3", `{"unit": 3}`, Day, true},
110157
}
111158

112159
for _, tt := range tests {
113160
t.Run(tt.name, func(t *testing.T) {
114161
var data struct {
115-
Unit TimeUnit `yaml:"unit"`
162+
Unit TimeUnit `json:"unit"`
116163
}
117164

118-
err := yaml.Unmarshal([]byte(tt.yamlData), &data)
165+
err := json.Unmarshal([]byte(tt.jsonData), &data)
119166
if tt.wantErr {
120167
if err == nil {
121168
t.Errorf("Expected error but got none")
@@ -132,25 +179,28 @@ func TestTimeUnit_UnmarshalYAML(t *testing.T) {
132179
}
133180
}
134181

135-
func TestStrategy_UnmarshalYAML(t *testing.T) {
182+
func TestStrategy_UnmarshalJSON(t *testing.T) {
136183
tests := []struct {
137184
name string
138-
yamlData string
185+
jsonData string
139186
expected Strategy
140187
wantErr bool
141188
}{
142-
{"fixed window", `strategy: fixed-window`, FixedWindow, false},
143-
{"unknown strategy", `strategy: sliding-window`, FixedWindow, true},
144-
{"empty strategy", `strategy: ""`, FixedWindow, true},
189+
{"fixed window", `{"strategy": "fixed-window"}`, FixedWindow, false},
190+
{"peta strategy", `{"strategy": "peta"}`, PETA, false},
191+
{"unknown strategy", `{"strategy": "sliding-window"}`, FixedWindow, true},
192+
{"empty strategy", `{"strategy": ""}`, FixedWindow, true},
193+
{"number as strategy - 0", `{"strategy": 0}`, FixedWindow, true},
194+
{"number as strategy - 1", `{"strategy": 1}`, PETA, true},
145195
}
146196

147197
for _, tt := range tests {
148198
t.Run(tt.name, func(t *testing.T) {
149199
var data struct {
150-
Strategy Strategy `yaml:"strategy"`
200+
Strategy Strategy `json:"strategy"`
151201
}
152202

153-
err := yaml.Unmarshal([]byte(tt.yamlData), &data)
203+
err := json.Unmarshal([]byte(tt.jsonData), &data)
154204
if tt.wantErr {
155205
if err == nil {
156206
t.Errorf("Expected error but got none")

ext/datasource/helper.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/alibaba/sentinel-golang/core/flow"
2323
"github.com/alibaba/sentinel-golang/core/hotspot"
2424
"github.com/alibaba/sentinel-golang/core/isolation"
25+
"github.com/alibaba/sentinel-golang/core/llm_token_ratelimit"
2526
"github.com/alibaba/sentinel-golang/core/system"
2627
)
2728

@@ -277,3 +278,50 @@ func IsolationRulesUpdater(data interface{}) error {
277278
func NewIsolationRulesHandler(converter PropertyConverter) *DefaultPropertyHandler {
278279
return NewDefaultPropertyHandler(converter, IsolationRulesUpdater)
279280
}
281+
282+
// LLMTokenRateLimitRuleJsonArrayParser provide JSON as the default serialization for list of llm_token_ratelimit.Rule
283+
func LLMTokenRateLimitRuleJsonArrayParser(src []byte) (interface{}, error) {
284+
if valid, err := checkSrcComplianceJson(src); !valid {
285+
return nil, err
286+
}
287+
288+
rules := make([]*llm_token_ratelimit.Rule, 0, 8)
289+
if err := json.Unmarshal(src, &rules); err != nil {
290+
desc := fmt.Sprintf("Fail to convert source bytes to []*llm_token_ratelimit.Rule, err: %s", err.Error())
291+
return nil, NewError(ConvertSourceError, desc)
292+
}
293+
return rules, nil
294+
}
295+
296+
// LLMTokenRateLimitRulesUpdater load the newest []llm_token_ratelimit.Rule to downstream system component.
297+
func LLMTokenRateLimitRulesUpdater(data interface{}) error {
298+
if data == nil {
299+
return llm_token_ratelimit.ClearRules()
300+
}
301+
302+
rules := make([]*llm_token_ratelimit.Rule, 0, 8)
303+
if val, ok := data.([]llm_token_ratelimit.Rule); ok {
304+
for _, v := range val {
305+
rules = append(rules, &v)
306+
}
307+
} else if val, ok := data.([]*llm_token_ratelimit.Rule); ok {
308+
rules = val
309+
} else {
310+
return NewError(
311+
UpdatePropertyError,
312+
fmt.Sprintf("Fail to type assert data to []llm_token_ratelimit.Rule or []*llm_token_ratelimit.Rule, in fact, data: %+v", data),
313+
)
314+
}
315+
_, err := llm_token_ratelimit.LoadRules(rules)
316+
if err == nil {
317+
return nil
318+
}
319+
return NewError(
320+
UpdatePropertyError,
321+
fmt.Sprintf("%+v", err),
322+
)
323+
}
324+
325+
func NewLLMTokenRateLimitRulesHandler(converter PropertyConverter) *DefaultPropertyHandler {
326+
return NewDefaultPropertyHandler(converter, LLMTokenRateLimitRulesUpdater)
327+
}

0 commit comments

Comments
 (0)