diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go index 5d5669c23d..a37f4c496a 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/main.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main.go @@ -45,26 +45,55 @@ func init() { const ( RedisKeyPrefix string = "higress-token-ratelimit" - // AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数 - AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d" - // AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值 - AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s" + // AiTokenGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口 + AiTokenGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d" + // AiTokenRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值 + AiTokenRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s" RequestPhaseFixedWindowScript = ` - local ttl = redis.call('ttl', KEYS[1]) - if ttl < 0 then - redis.call('set', KEYS[1], ARGV[1], 'EX', ARGV[2]) - return {ARGV[1], ARGV[1], ARGV[2]} - end - return {ARGV[1], redis.call('get', KEYS[1]), ttl} + local current = redis.call('get', KEYS[1]) + local ttl = redis.call('ttl', KEYS[1]) + local threshold = tonumber(ARGV[1]) + local window = tonumber(ARGV[2]) + + -- 键不存在时,返回初始状态(计数0,窗口时间为过期时间) + if not current then + return {threshold, 0, window} + end + + -- 修复异常过期时间(确保窗口有效) + if ttl < 0 then + ttl = window + end + + -- 返回窗口状态:阈值、当前计数、剩余时间 + return {threshold, tonumber(current), ttl} ` ResponsePhaseFixedWindowScript = ` - local ttl = redis.call('ttl', KEYS[1]) - if ttl < 0 then - redis.call('set', KEYS[1], ARGV[1]-ARGV[3], 'EX', ARGV[2]) - return {ARGV[1], ARGV[1]-ARGV[3], ARGV[2]} - end - return {ARGV[1], redis.call('decrby', KEYS[1], ARGV[3]), ttl} - ` + local key = KEYS[1] + local 
threshold = tonumber(ARGV[1]) + local window = tonumber(ARGV[2]) + local added = tonumber(ARGV[3]) -- number of tokens to add + + local current = tonumber(redis.call('get', key) or "0") + + -- Accumulate only while the counter has not exceeded the threshold + if current <= threshold then + current = redis.call('incrby', key, added) + end + + -- Arm the window whenever the key has no expiry (first write, or TTL lost). + -- Checking the TTL itself, rather than comparing current with added, avoids + -- restarting a live window when the previous count was 0, and also repairs an + -- over-threshold key that would otherwise never expire. + local ttl = redis.call('ttl', key) + if ttl < 0 then + redis.call('expire', key, window) + ttl = window + end + + -- Reply with the window state: threshold, current count, remaining seconds + return {threshold, current, ttl} + ` LimitRedisContextKey = "LimitRedisContext" @@ -107,7 +136,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo if cfg.GlobalThreshold != nil { // 全局限流模式 - limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count) + limitKey = fmt.Sprintf(AiTokenGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow) count = cfg.GlobalThreshold.Count timeWindow = cfg.GlobalThreshold.TimeWindow } else { @@ -118,7 +147,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo return types.ActionContinue } - limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val) + limitKey = fmt.Sprintf(AiTokenRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val) count = configItem.Count timeWindow = configItem.TimeWindow } @@ -139,12 +168,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.AiTokenRateLimitCo proxywasm.ResumeHttpRequest() return } + + // 获取限流结果 + threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer() context := LimitContext{ - count: resultArray[0].Integer(), - remaining: resultArray[1].Integer(), - reset: resultArray[2].Integer(), + count: threshold, + remaining: threshold - 
current, + reset: ttl, } - if context.remaining < 0 { + if current > threshold { // 触发限流 ctx.SetUserAttribute("token_ratelimit_status", "limited") ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) diff --git a/plugins/wasm-go/extensions/ai-token-ratelimit/main_test.go b/plugins/wasm-go/extensions/ai-token-ratelimit/main_test.go index c51b625468..fb43373829 100644 --- a/plugins/wasm-go/extensions/ai-token-ratelimit/main_test.go +++ b/plugins/wasm-go/extensions/ai-token-ratelimit/main_test.go @@ -291,8 +291,8 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(允许请求) - // 返回 [count, remaining, ttl] 格式 - resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60}) + // 返回 [threshold, current, ttl] 格式 + resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60}) host.CallOnRedisCall(0, resp) host.CompleteHttp() @@ -316,7 +316,7 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(允许请求) - resp := test.CreateRedisRespArray([]interface{}{100, 99, 60}) + resp := test.CreateRedisRespArray([]interface{}{100, 1, 60}) host.CallOnRedisCall(0, resp) host.CompleteHttp() @@ -339,7 +339,7 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(允许请求) - resp := test.CreateRedisRespArray([]interface{}{50, 49, 60}) + resp := test.CreateRedisRespArray([]interface{}{50, 1, 60}) host.CallOnRedisCall(0, resp) host.CompleteHttp() @@ -363,7 +363,7 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(允许请求) - resp := test.CreateRedisRespArray([]interface{}{200, 199, 60}) + resp := test.CreateRedisRespArray([]interface{}{200, 1, 60}) host.CallOnRedisCall(0, resp) host.CompleteHttp() @@ -387,7 +387,7 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, 
types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(允许请求) - resp := test.CreateRedisRespArray([]interface{}{75, 74, 60}) + resp := test.CreateRedisRespArray([]interface{}{75, 1, 60}) host.CallOnRedisCall(0, resp) host.CompleteHttp() @@ -410,8 +410,8 @@ func TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(触发限流) - // 返回 [count, remaining, ttl] 格式,remaining < 0 表示触发限流 - resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60}) + // 返回 [threshold, current, ttl] 格式,current > threshold 表示触发限流 + resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60}) host.CallOnRedisCall(0, resp) // 检查是否发送了限流响应 @@ -459,7 +459,7 @@ func TestOnHttpStreamingBody(t *testing.T) { }) // 模拟 Redis 调用响应 - resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60}) + resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60}) host.CallOnRedisCall(0, resp) // 处理流式响应体 @@ -499,7 +499,7 @@ func TestOnHttpStreamingBody(t *testing.T) { }) // 模拟 Redis 调用响应 - resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60}) + resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60}) host.CallOnRedisCall(0, resp) // 处理流式响应体 @@ -537,7 +537,7 @@ func TestCompleteFlow(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 2. 模拟 Redis 调用响应 - resp := test.CreateRedisRespArray([]interface{}{100, 99, 60}) + resp := test.CreateRedisRespArray([]interface{}{100, 1, 60}) host.CallOnRedisCall(0, resp) // 3. 
处理流式响应体 diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go b/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go index 4005907433..28d08e4ca7 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/main.go @@ -46,17 +46,30 @@ func init() { const ( // RedisKeyPrefix 集群限流插件在 Redis 中 key 的统一前缀 RedisKeyPrefix = "higress-cluster-key-rate-limit" - // ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口:窗口内限流数 - ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d:%d" - // ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:窗口内限流数:限流key名称:限流key对应的实际值 - ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%d:%s:%s" + // ClusterGlobalRateLimitFormat 全局限流模式 redis key 为 RedisKeyPrefix:限流规则名称:global_threshold:时间窗口 + ClusterGlobalRateLimitFormat = RedisKeyPrefix + ":%s:global_threshold:%d" + // ClusterRateLimitFormat 规则限流模式 redis key 为 RedisKeyPrefix:限流规则名称:限流类型:时间窗口:限流key名称:限流key对应的实际值 + ClusterRateLimitFormat = RedisKeyPrefix + ":%s:%s:%d:%s:%s" FixedWindowScript = ` - local ttl = redis.call('ttl', KEYS[1]) - if ttl < 0 then - redis.call('set', KEYS[1], ARGV[1] - 1, 'EX', ARGV[2]) - return {ARGV[1], ARGV[1] - 1, ARGV[2]} - end - return {ARGV[1], redis.call('incrby', KEYS[1], -1), ttl} + local key = KEYS[1] + local threshold = tonumber(ARGV[1]) + local window = tonumber(ARGV[2]) + + local current = tonumber(redis.call('get', key) or "0") + + -- Count the request unless the counter is already above the threshold; the caller rejects once current exceeds threshold + if current <= threshold then + current = redis.call('incr', key) + end + + -- Arm the window whenever the key has no expiry (first hit, or TTL lost), + -- so an over-threshold counter can never persist and block the key forever. + if redis.call('ttl', key) < 0 then + redis.call('expire', key, window) + end + + -- Reply with the window state: threshold, current count, remaining seconds + return {threshold, current, redis.call('ttl', key)} ` LimitContextKey = "LimitContext" // 限流上下文信息 @@ -92,7 +105,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi if 
cfg.GlobalThreshold != nil { // 全局限流模式 - limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow, cfg.GlobalThreshold.Count) + limitKey = fmt.Sprintf(ClusterGlobalRateLimitFormat, cfg.RuleName, cfg.GlobalThreshold.TimeWindow) count = cfg.GlobalThreshold.Count timeWindow = cfg.GlobalThreshold.TimeWindow } else { @@ -103,7 +116,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi return types.ActionContinue } - limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, configItem.Count, ruleItem.Key, val) + limitKey = fmt.Sprintf(ClusterRateLimitFormat, cfg.RuleName, ruleItem.LimitType, configItem.TimeWindow, ruleItem.Key, val) count = configItem.Count timeWindow = configItem.TimeWindow } @@ -118,12 +131,15 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, cfg config.ClusterKeyRateLimi proxywasm.ResumeHttpRequest() return } + + // 获取限流结果 + threshold, current, ttl := resultArray[0].Integer(), resultArray[1].Integer(), resultArray[2].Integer() context := LimitContext{ - count: resultArray[0].Integer(), - remaining: resultArray[1].Integer(), - reset: resultArray[2].Integer(), + count: threshold, + remaining: threshold - current, + reset: ttl, } - if context.remaining < 0 { + if current > threshold { // 触发限流 rejected(cfg, context) } else { diff --git a/plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go b/plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go index d76860e69b..d61103575b 100644 --- a/plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go +++ b/plugins/wasm-go/extensions/cluster-key-rate-limit/main_test.go @@ -15,10 +15,11 @@ package main import ( - "cluster-key-rate-limit/config" "encoding/json" "testing" + "cluster-key-rate-limit/config" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" @@ -527,9 +528,16 @@ func 
TestOnHttpRequestHeaders(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 模拟 Redis 调用响应(触发限流) - resp := test.CreateRedisRespArray([]interface{}{1000, -1, 60}) + // 当前请求数(1001)超过阈值(1000),触发限流 + resp := test.CreateRedisRespArray([]interface{}{1000, 1001, 60}) host.CallOnRedisCall(0, resp) + // 检查是否发送了限流响应 + localResponse := host.GetLocalResponse() + require.NotNil(t, localResponse) + require.Equal(t, uint32(429), localResponse.StatusCode) + require.Contains(t, string(localResponse.Data), "Too many requests") + host.CompleteHttp() }) }) @@ -641,7 +649,7 @@ func TestCompleteFlow(t *testing.T) { require.Equal(t, types.HeaderStopAllIterationAndWatermark, action) // 2. 模拟 Redis 调用响应 - resp := test.CreateRedisRespArray([]interface{}{1000, 999, 60}) + resp := test.CreateRedisRespArray([]interface{}{1000, 1, 60}) host.CallOnRedisCall(0, resp) // 3. 处理响应头