Skip to content

Commit 0123eaa

Browse files
committed
fix: fix issues with data dependency in token prediction
1 parent 4daede3 commit 0123eaa

File tree

10 files changed

+55
-149
lines changed

10 files changed

+55
-149
lines changed

core/llm_token_ratelimit/constant.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,5 @@ const (
8484

8585
// ================================= OpenAIEncoder ============================
8686
const (
87-
TokenEncoderKeyFormat string = "%s:token-encoder:%s:%s" // redisRatelimitKey, provider, model
87+
TokenEncoderKeyFormat string = "{shard-%s}:token-encoder:%s:%s:%s" // hashTag, provider, model, redisRatelimitKey
8888
)

core/llm_token_ratelimit/ratelimit_checker.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (c *PETAChecker) checkLimitKey(ctx *Context, rule *MatchedRule) bool {
127127
prompts = reqInfos.Prompts
128128
}
129129

130-
estimatedToken, err := c.countTokens(ctx, prompts, rule)
130+
length, err := c.countTokens(ctx, prompts, rule)
131131
if err != nil {
132132
logging.Error(err, "failed to count tokens in llm_token_ratelimit.PETAChecker.checkLimitKey()",
133133
"requestID", ctx.Get(KeyRequestID),
@@ -137,9 +137,10 @@ func (c *PETAChecker) checkLimitKey(ctx *Context, rule *MatchedRule) bool {
137137

138138
slidingWindowKey := fmt.Sprintf(PETASlidingWindowKeyFormat, generateHash(rule.LimitKey), rule.LimitKey)
139139
tokenBucketKey := fmt.Sprintf(PETATokenBucketKeyFormat, generateHash(rule.LimitKey), rule.LimitKey)
140+
tokenEncoderKey := fmt.Sprintf(TokenEncoderKeyFormat, generateHash(rule.LimitKey), rule.Encoding.Provider.String(), rule.Encoding.Model, rule.LimitKey)
140141

141-
keys := []string{slidingWindowKey, tokenBucketKey}
142-
args := []interface{}{estimatedToken, util.CurrentTimeMillis(), rule.TokenSize, rule.TimeWindow * 1000, generateRandomString(PETARandomStringLength)}
142+
keys := []string{slidingWindowKey, tokenBucketKey, tokenEncoderKey}
143+
args := []interface{}{length, util.CurrentTimeMillis(), rule.TokenSize, rule.TimeWindow * 1000, generateRandomString(PETARandomStringLength)}
143144
response, err := globalRedisClient.Eval(globalPETAWithholdScript, keys, args...)
144145
if err != nil {
145146
logging.Error(err, "failed to execute redis script in llm_token_ratelimit.PETAChecker.checkLimitKey()",
@@ -148,14 +149,23 @@ func (c *PETAChecker) checkLimitKey(ctx *Context, rule *MatchedRule) bool {
148149
return true
149150
}
150151
result := parseRedisResponse(ctx, response)
151-
if result == nil || len(result) != 2 {
152+
if result == nil || len(result) != 4 {
152153
logging.Error(errors.New("invalid redis response"),
153154
"invalid redis response in llm_token_ratelimit.PETAChecker.checkLimitKey()",
154155
"response", response,
155156
"requestID", ctx.Get(KeyRequestID),
156157
)
157158
return true
158159
}
160+
logging.Info("[LLMTokenRateLimit] estimated infos",
161+
"limitKey", rule.LimitKey,
162+
"current_capacity", result[0],
163+
"waiting_time(ms)", result[1],
164+
"estimated_token", result[2],
165+
"difference", result[3],
166+
"tokenization_length", length,
167+
"requestID", ctx.Get(KeyRequestID),
168+
)
159169

160170
// TODO: add waiting and timeout callback
161171
waitingTime := result[1]
@@ -174,7 +184,7 @@ func (c *PETAChecker) checkLimitKey(ctx *Context, rule *MatchedRule) bool {
174184
return false
175185
}
176186
ctx.Set(KeyResponseHeaders, responseHeader)
177-
c.cacheEstimatedToken(rule, estimatedToken)
187+
c.cacheEstimatedToken(rule, result[2])
178188
return true
179189
}
180190

@@ -203,7 +213,7 @@ func (c *PETAChecker) countTokens(ctx *Context, prompts []string, rule *MatchedR
203213
return 0, fmt.Errorf("unknown count strategy: %s", rule.CountStrategy.String())
204214
}
205215

206-
func (c *PETAChecker) cacheEstimatedToken(rule *MatchedRule, count int) {
216+
func (c *PETAChecker) cacheEstimatedToken(rule *MatchedRule, count int64) {
207217
if c == nil || rule == nil {
208218
return
209219
}

core/llm_token_ratelimit/rule_manager.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,11 @@ func logRuleUpdate(m map[string][]*Rule) {
197197
logging.Info("[LLMTokenRateLimit] rules were cleared")
198198
} else {
199199
var builder strings.Builder
200-
for _, r := range rs {
200+
for i, r := range rs {
201201
builder.WriteString(r.String())
202+
if i != len(rs)-1 {
203+
builder.WriteString(", ")
204+
}
202205
}
203206
logging.Info("[LLMTokenRateLimit] rules were loaded",
204207
"rules", builder.String(),

core/llm_token_ratelimit/rule_matcher.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type MatchedRule struct {
2828
CountStrategy CountStrategy
2929
// PETA
3030
Encoding TokenEncoding
31-
EstimatedToken int
31+
EstimatedToken int64
3232
}
3333

3434
type MatchedRuleCollector interface {

core/llm_token_ratelimit/script/peta/correct.lua

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
-- limitations under the License.
1414
-- KEYS[1]: Sliding Window Key ("{shard-<hashtag>}:sliding-window:<redisRatelimitKey>")
1515
-- KEYS[2]: Token Bucket Key ("{shard-<hashtag>}:token-bucket:<redisRatelimitKey>")
16+
-- KEYS[3]: Token Encoder Key ("{shard-<hashtag>}:token-encoder:<provider>:<model>:<redisRatelimitKey>")
1617
-- ARGV[1]: Estimated token consumption
1718
-- ARGV[2]: Current timestamp (milliseconds)
1819
-- ARGV[3]: Token bucket capacity
@@ -48,6 +49,7 @@ end
4849

4950
local sliding_window_key = tostring(KEYS[1])
5051
local token_bucket_key = tostring(KEYS[2])
52+
local token_encoder_key = tostring(KEYS[3])
5153

5254
local estimated = tonumber(ARGV[1])
5355
local current_timestamp = tonumber(ARGV[2])
@@ -86,7 +88,14 @@ if released_tokens > 0 then -- Expired tokens exist, attempt to replenish new to
8688
-- Immediately replenish new tokens
8789
redis.call('HSET', token_bucket_key, 'capacity', current_capacity)
8890
end
89-
91+
-- Update the difference from the token encoder
92+
local difference = actual - estimated
93+
local ttl = redis.call('PTTL', token_encoder_key)
94+
if ttl < 0 then
95+
redis.call('SET', token_encoder_key, difference, 'PX', window_size + 5000)
96+
else
97+
redis.call('INCRBY', token_encoder_key, difference)
98+
end
9099
-- Correction result for reservation
91100
local correct_result = 0
92101
if estimated < 0 or actual < 0 then
@@ -130,5 +139,6 @@ end
130139
-- Set expiration time to window size plus 5 seconds buffer
131140
redis.call('PEXPIRE', sliding_window_key, window_size + 5000)
132141
redis.call('PEXPIRE', token_bucket_key, window_size + 5000)
142+
redis.call('PEXPIRE', token_encoder_key, window_size + 5000)
133143

134144
return {correct_result}

core/llm_token_ratelimit/script/peta/withhold.lua

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
-- limitations under the License.
1414
-- KEYS[1]: Sliding Window Key ("{shard-<hashtag>}:sliding-window:<redisRatelimitKey>")
1515
-- KEYS[2]: Token Bucket Key ("{shard-<hashtag>}:token-bucket:<redisRatelimitKey>")
16+
-- KEYS[3]: Token Encoder Key ("{shard-<hashtag>}:token-encoder:<provider>:<model>:<redisRatelimitKey>")
1617
-- ARGV[1]: Estimated token consumption
1718
-- ARGV[2]: Current timestamp (milliseconds)
1819
-- ARGV[3]: Token bucket capacity
@@ -30,6 +31,7 @@ end
3031

3132
local sliding_window_key = tostring(KEYS[1])
3233
local token_bucket_key = tostring(KEYS[2])
34+
local token_encoder_key = tostring(KEYS[3])
3335

3436
local estimated = tonumber(ARGV[1])
3537
local current_timestamp = tonumber(ARGV[2])
@@ -69,6 +71,18 @@ if released_tokens > 0 then -- Expired tokens exist, attempt to replenish new to
6971
-- Immediately replenish new tokens
7072
redis.call('HSET', token_bucket_key, 'capacity', current_capacity)
7173
end
74+
-- Add the difference from the token encoder if it exists
75+
local ttl = redis.call('PTTL', token_encoder_key)
76+
local difference = tonumber(redis.call('GET', token_encoder_key))
77+
if ttl < 0 then
78+
difference = 0
79+
else
80+
if difference + estimated >= 0 then
81+
estimated = estimated + difference
82+
else
83+
redis.call('SET', token_encoder_key, 0, 'PX', window_size + 5000)
84+
end
85+
end
7286
-- Check if the request can be satisfied
7387
if max_capacity < estimated or estimated < 0 then -- If max capacity is less than estimated consumption or estimated is less than 0, return -1 indicating rejection
7488
waiting_time = -1
@@ -91,5 +105,6 @@ end
91105
-- Set expiration time to window size plus 5 seconds buffer
92106
redis.call('PEXPIRE', sliding_window_key, window_size + 5000)
93107
redis.call('PEXPIRE', token_bucket_key, window_size + 5000)
108+
redis.call('PEXPIRE', token_encoder_key, window_size + 5000)
94109

95-
return {current_capacity, waiting_time}
110+
return {current_capacity, waiting_time, estimated, difference}

core/llm_token_ratelimit/script/token_encoder/query.lua

Lines changed: 0 additions & 32 deletions
This file was deleted.

core/llm_token_ratelimit/script/token_encoder/update.lua

Lines changed: 0 additions & 27 deletions
This file was deleted.

core/llm_token_ratelimit/token_encoder.go

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ var (
3030
tokenEncoderMapRWMux = &sync.RWMutex{}
3131
)
3232

33-
//go:embed script/token_encoder/update.lua
34-
var globalTokenEncoderUpdateScript string
35-
3633
type TokenEncoder interface {
3734
CountTokens(ctx *Context, prompts []string, rule *MatchedRule) (int, error)
3835
}
@@ -62,9 +59,6 @@ func LookupTokenEncoder(ctx *Context, encoding TokenEncoding) TokenEncoder {
6259
}
6360

6461
// ================================= OpenAIEncoder ====================================
65-
//
66-
//go:embed script/token_encoder/query.lua
67-
var globalTokenEncoderQueryScript string
6862

6963
type OpenAIEncoder struct {
7064
Model string
@@ -107,39 +101,5 @@ func (e *OpenAIEncoder) CountTokens(ctx *Context, prompts []string, rule *Matche
107101
builder.WriteString(prompt)
108102
}
109103
token := e.Encoder.Encode(builder.String(), nil, nil)
110-
if len(token) > 0 {
111-
estimatedToken, err := e.countTokens(ctx, rule, len(token))
112-
if err != nil {
113-
return 0, err
114-
}
115-
return estimatedToken, nil
116-
}
117-
return 0, nil
118-
}
119-
120-
func (e *OpenAIEncoder) countTokens(ctx *Context, rule *MatchedRule, tokenization int) (int, error) {
121-
if e == nil {
122-
return 0, fmt.Errorf("OpenAIEncoder is nil")
123-
}
124-
key := fmt.Sprintf(TokenEncoderKeyFormat, rule.LimitKey, OpenAIEncoderProvider.String(), e.Model)
125-
126-
keys := []string{key}
127-
args := []interface{}{tokenization, rule.TimeWindow * 1000}
128-
129-
response, err := globalRedisClient.Eval(globalTokenEncoderQueryScript, keys, args...)
130-
if err != nil {
131-
return 0, err
132-
}
133-
result := parseRedisResponse(ctx, response)
134-
if result == nil || len(result) != 2 {
135-
return 0, fmt.Errorf("unexpected redis response: %v", response)
136-
}
137-
138-
logging.Info("[LLMTokenRateLimit] estimated token",
139-
"limitKey", rule.LimitKey,
140-
"estimatedToken", result[0],
141-
"difference", result[1],
142-
"requestID", ctx.Get(KeyRequestID),
143-
)
144-
return int(result[0]), nil
104+
return len(token), nil
145105
}

core/llm_token_ratelimit/token_updater.go

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,18 @@ func (u *PETAUpdater) updateLimitKey(ctx *Context, rule *MatchedRule, infos *Use
112112
return
113113
}
114114
actualToken := calculator.Calculate(ctx, infos)
115-
logging.Info("[LLMTokenRateLimit] actual token",
115+
logging.Info("[LLMTokenRateLimit] actual infos",
116116
"limitKey", rule.LimitKey,
117-
"actualToken", actualToken,
117+
"estimated_token", rule.EstimatedToken,
118+
"actual_token", actualToken,
118119
"requestID", ctx.Get(KeyRequestID),
119120
)
120121

121122
slidingWindowKey := fmt.Sprintf(PETASlidingWindowKeyFormat, generateHash(rule.LimitKey), rule.LimitKey)
122123
tokenBucketKey := fmt.Sprintf(PETATokenBucketKeyFormat, generateHash(rule.LimitKey), rule.LimitKey)
124+
tokenEncoderKey := fmt.Sprintf(TokenEncoderKeyFormat, generateHash(rule.LimitKey), rule.Encoding.Provider.String(), rule.Encoding.Model, rule.LimitKey)
123125

124-
keys := []string{slidingWindowKey, tokenBucketKey}
126+
keys := []string{slidingWindowKey, tokenBucketKey, tokenEncoderKey}
125127
args := []interface{}{rule.EstimatedToken, util.CurrentTimeMillis(), rule.TokenSize, rule.TimeWindow * 1000, actualToken, generateRandomString(PETARandomStringLength)}
126128
response, err := globalRedisClient.Eval(globalPETACorrectScript, keys, args...)
127129
if err != nil {
@@ -149,39 +151,4 @@ func (u *PETAUpdater) updateLimitKey(ctx *Context, rule *MatchedRule, infos *Use
149151
)
150152
return
151153
}
152-
u.updateDifference(ctx, rule, actualToken-rule.EstimatedToken)
153-
}
154-
155-
func (u *PETAUpdater) updateDifference(ctx *Context, rule *MatchedRule, difference int) {
156-
if u == nil {
157-
return
158-
}
159-
key := fmt.Sprintf(TokenEncoderKeyFormat, rule.LimitKey, rule.Encoding.Provider.String(), rule.Encoding.Model)
160-
161-
keys := []string{key}
162-
args := []interface{}{difference, rule.TimeWindow * 1000}
163-
164-
response, err := globalRedisClient.Eval(globalTokenEncoderUpdateScript, keys, args...)
165-
if err != nil {
166-
logging.Error(err, "failed to update the difference in llm_token_ratelimit.PETAUpdater.updateDifference()",
167-
"key", key,
168-
"difference", difference,
169-
"requestID", ctx.Get(KeyRequestID),
170-
)
171-
return
172-
}
173-
result := parseRedisResponse(ctx, response)
174-
if result == nil || len(result) != 1 {
175-
logging.Error(errors.New("invalid redis response"),
176-
"invalid redis response in llm_token_ratelimit.PETAUpdater.updateDifference()",
177-
"response", response,
178-
"requestID", ctx.Get(KeyRequestID),
179-
)
180-
return
181-
}
182-
logging.Info("[LLMTokenRateLimit] successfully update the difference in llm_token_ratelimit.PETAUpdater.updateDifference()",
183-
"key", key,
184-
"difference", result[0],
185-
"requestID", ctx.Get(KeyRequestID),
186-
)
187154
}

0 commit comments

Comments
 (0)