Skip to content

Commit 1f301be

Browse files
authored
fix: Optimization of Rate Limiting Logic for Cluster, AI Token and WASM Plugin (#2997)
1 parent b026455 commit 1f301be

File tree

4 files changed

+109
-53
lines changed

4 files changed

+109
-53
lines changed

plugins/wasm-go/extensions/ai-token-ratelimit/main.go

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,55 @@ func init() {
4545

4646
const (
4747
RedisKeyPrefix string = "higress-token-ratelimit"
48-
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数
49-
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d"
50-
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值
51-
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s"
48+
// AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口
49+
AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d"
50+
// AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值
51+
AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s"
5252
RequestPhaseFixedWindowScript = `
53-
local ttl = redis.call('ttl', KEYS[1])
54-
if ttl < 0 then
55-
redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2])
56-
return {ARGV[1], ARGV[1], ARGV[2]}
57-
end
58-
return {ARGV[1], redis.call('get', KEYS[1]), ttl}
53+
local current = redis.call('get', KEYS[1])
54+
local ttl = redis.call('ttl', KEYS[1])
55+
local threshold = tonumber(ARGV[1])
56+
local window = tonumber(ARGV[2])
57+
58+
-- 键不存在时,返回初始状态(计数0,窗口时间为过期时间)
59+
if not current then
60+
return {threshold, 0, window}
61+
end
62+
63+
-- 修复异常过期时间(确保窗口有效)
64+
if ttl < 0 then
65+
ttl = window
66+
end
67+
68+
-- 返回窗口状态:阈值、当前计数、剩余时间
69+
return {threshold, tonumber(current), ttl}
5970
`
6071
ResponsePhaseFixedWindowScript = `
61-
local ttl = redis.call('ttl', KEYS[1])
62-
if ttl < 0 then
63-
redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2])
64-
return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]}
65-
end
66-
return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl}
67-
`
72+
local key = KEYS[1]
73+
local threshold = tonumber(ARGV[1])
74+
local window = tonumber(ARGV[2])
75+
local added = tonumber(ARGV[3]) -- 需要累加的token数量
76+
77+
local current = tonumber(redis.call('get', key) or "0")
78+
79+
-- 只有当前计数未超过阈值时才执行累加
80+
if current <= threshold then
81+
current = redis.call('incrby', key, added)
82+
-- 第一次设置值时初始化过期时间
83+
if current == added then
84+
redis.call('expire', key, window)
85+
else
86+
-- 非首次设置时检查过期时间,确保窗口有效性
87+
local ttl = redis.call('ttl', key)
88+
if ttl < 0 then
89+
redis.call('expire', key, window)
90+
end
91+
end
92+
end
93+
94+
-- 返回当前窗口状态:阈值、当前计数、剩余时间
95+
return {threshold, current, redis.call('ttl', key)}
96+
`
6897

6998
LimitRedisContextKey = "LimitRedisContext"
7099

@@ -107,7 +136,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
107136

108137
if cfg.GlobalThreshold != nil {
109138
// 全局限流模式
110-
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count)
139+
limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow)
111140
count = cfg.GlobalThreshold.Count
112141
timeWindow = cfg.GlobalThreshold.TimeWindow
113142
} else {
@@ -118,7 +147,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
118147
return types.ActionContinue
119148
}
120149

121-
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val)
150+
limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val)
122151
count = configItem.Count
123152
timeWindow = configItem.TimeWindow
124153
}
@@ -139,12 +168,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo
139168
proxywasm.ResumeHttpRequest()
140169
return
141170
}
171+
172+
// 获取限流结果
173+
threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer()
142174
context := LimitContext{
143-
count: resultArray[0].Integer(),
144-
remaining: resultArray[1].Integer(),
145-
reset: resultArray[2].Integer(),
175+
count: threshold,
176+
remaining: threshold - current,
177+
reset: ttl,
146178
}
147-
if context.remaining < 0 {
179+
if current > threshold {
148180
// 触发限流
149181
ctx.SetUserAttribute("token_ratelimit_status", "limited")
150182
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)

plugins/wasm-go/extensions/ai-token-ratelimit/main_test.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
291291
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
292292

293293
// 模拟 Redis 调用响应(允许请求)
294-
// 返回 [count, remaining, ttl] 格式
295-
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
294+
// 返回 [threshold, current, ttl] 格式
295+
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
296296
host.CallOnRedisCall(0, resp)
297297

298298
host.CompleteHttp()
@@ -316,7 +316,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
316316
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
317317

318318
// 模拟 Redis 调用响应(允许请求)
319-
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
319+
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
320320
host.CallOnRedisCall(0, resp)
321321

322322
host.CompleteHttp()
@@ -339,7 +339,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
339339
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
340340

341341
// 模拟 Redis 调用响应(允许请求)
342-
resp := test.CreateRedisRespArray([]interface{}{50, 49, 60})
342+
resp := test.CreateRedisRespArray([]interface{}{50, 1, 60})
343343
host.CallOnRedisCall(0, resp)
344344

345345
host.CompleteHttp()
@@ -363,7 +363,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
363363
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
364364

365365
// 模拟 Redis 调用响应(允许请求)
366-
resp := test.CreateRedisRespArray([]interface{}{200, 199, 60})
366+
resp := test.CreateRedisRespArray([]interface{}{200, 1, 60})
367367
host.CallOnRedisCall(0, resp)
368368

369369
host.CompleteHttp()
@@ -387,7 +387,7 @@ func TestOnHttpRequestHeaders(t *testing.T) {
387387
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
388388

389389
// 模拟 Redis 调用响应(允许请求)
390-
resp := test.CreateRedisRespArray([]interface{}{75, 74, 60})
390+
resp := test.CreateRedisRespArray([]interface{}{75, 1, 60})
391391
host.CallOnRedisCall(0, resp)
392392

393393
host.CompleteHttp()
@@ -410,8 +410,8 @@ func TestOnHttpRequestHeaders(t *testing.T) {
410410
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
411411

412412
// 模拟 Redis 调用响应(触发限流)
413-
// 返回 [count, remaining, ttl] 格式,remaining < 0 表示触发限流
414-
resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
413+
// 返回 [threshold, current, ttl] 格式,current > threshold 表示触发限流
414+
resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60})
415415
host.CallOnRedisCall(0, resp)
416416

417417
// 检查是否发送了限流响应
@@ -459,7 +459,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
459459
})
460460

461461
// 模拟 Redis 调用响应
462-
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
462+
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
463463
host.CallOnRedisCall(0, resp)
464464

465465
// 处理流式响应体
@@ -499,7 +499,7 @@ func TestOnHttpStreamingBody(t *testing.T) {
499499
})
500500

501501
// 模拟 Redis 调用响应
502-
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
502+
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
503503
host.CallOnRedisCall(0, resp)
504504

505505
// 处理流式响应体
@@ -537,7 +537,7 @@ func TestCompleteFlow(t *testing.T) {
537537
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
538538

539539
// 2. 模拟 Redis 调用响应
540-
resp := test.CreateRedisRespArray([]interface{}{100, 99, 60})
540+
resp := test.CreateRedisRespArray([]interface{}{100, 1, 60})
541541
host.CallOnRedisCall(0, resp)
542542

543543
// 3. 处理流式响应体

plugins/wasm-go/extensions/cluster-key-rate-limit/main.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,30 @@ func init() {
4646
const (
4747
// RedisKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀
4848
RedisKeyPrefix = "higress-cluster-key-rate-limit"
49-
// ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数
50-
ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d"
51-
// ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值
52-
ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s"
49+
// ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口
50+
ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d"
51+
// ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值
52+
ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s"
5353
FixedWindowScript = `
54-
local ttl = redis.call('ttl', KEYS[1])
55-
if ttl < 0 then
56-
redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2])
57-
return {ARGV[1], ARGV[1] - 1, ARGV[2]}
58-
end
59-
return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl}
54+
local key = KEYS[1]
55+
local threshold = tonumber(ARGV[1])
56+
local window = tonumber(ARGV[2])
57+
58+
local current = tonumber(redis.call('get', key) or "0")
59+
60+
-- 只有超过阈值时才停止累加,达到阈值时仍允许(此时是最后一次允许)
61+
if current > threshold then
62+
return {threshold, current, redis.call('ttl', key)}
63+
end
64+
65+
-- 计数未超过阈值,执行累加
66+
current = redis.call('incr', key)
67+
-- 第一次累加时设置过期时间
68+
if current == 1 then
69+
redis.call('expire', key, window)
70+
end
71+
72+
return {threshold, current, redis.call('ttl', key)}
6073
`
6174

6275
LimitContextKey = "LimitContext" // 限流上下文信息
@@ -92,7 +105,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
92105

93106
if cfg.GlobalThreshold != nil {
94107
// 全局限流模式
95-
limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count)
108+
limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow)
96109
count = cfg.GlobalThreshold.Count
97110
timeWindow = cfg.GlobalThreshold.TimeWindow
98111
} else {
@@ -103,7 +116,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
103116
return types.ActionContinue
104117
}
105118

106-
limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val)
119+
limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val)
107120
count = configItem.Count
108121
timeWindow = configItem.TimeWindow
109122
}
@@ -118,12 +131,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi
118131
proxywasm.ResumeHttpRequest()
119132
return
120133
}
134+
135+
// 获取限流结果
136+
threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer()
121137
context := LimitContext{
122-
count: resultArray[0].Integer(),
123-
remaining: resultArray[1].Integer(),
124-
reset: resultArray[2].Integer(),
138+
count: threshold,
139+
remaining: threshold - current,
140+
reset: ttl,
125141
}
126-
if context.remaining < 0 {
142+
if current > threshold {
127143
// 触发限流
128144
rejected(cfg, context)
129145
} else {

plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
package main
1616

1717
import (
18-
"cluster-key-rate-limit/config"
1918
"encoding/json"
2019
"testing"
2120

21+
"cluster-key-rate-limit/config"
22+
2223
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
2324
"github.com/higress-group/wasm-go/pkg/test"
2425
"github.com/stretchr/testify/require"
@@ -527,9 +528,16 @@ func TestOnHttpRequestHeaders(t *testing.T) {
527528
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
528529

529530
// 模拟 Redis 调用响应(触发限流)
530-
resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60})
531+
// 当前请求数(1001)超过阈值(1000),触发限流
532+
resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60})
531533
host.CallOnRedisCall(0, resp)
532534

535+
// 检查是否发送了限流响应
536+
localResponse := host.GetLocalResponse()
537+
require.NotNil(t, localResponse)
538+
require.Equal(t, uint32(429), localResponse.StatusCode)
539+
require.Contains(t, string(localResponse.Data), "Too many requests")
540+
533541
host.CompleteHttp()
534542
})
535543
})
@@ -641,7 +649,7 @@ func TestCompleteFlow(t *testing.T) {
641649
require.Equal(t, types.HeaderStopAllIterationAndWatermark, action)
642650

643651
// 2. 模拟 Redis 调用响应
644-
resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60})
652+
resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60})
645653
host.CallOnRedisCall(0, resp)
646654

647655
// 3. 处理响应头

0 commit comments

Comments
 (0)