diff --git a/plugins/wasm-go/extensions/ai-security-guard/config/config.go b/plugins/wasm-go/extensions/ai-security-guard/config/config.go new file mode 100644 index 0000000000..cdbd427044 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/config/config.go @@ -0,0 +1,583 @@ +package config + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + MaxRisk = "max" + HighRisk = "high" + MediumRisk = "medium" + LowRisk = "low" + NoRisk = "none" + + S4Sensitive = "s4" + S3Sensitive = "s3" + S2Sensitive = "s2" + S1Sensitive = "s1" + NoSensitive = "s0" + + ContentModerationType = "contentModeration" + PromptAttackType = "promptAttack" + SensitiveDataType = "sensitiveData" + MaliciousUrlDataType = "maliciousUrl" + ModelHallucinationDataType = "modelHallucination" + + // Default configurations + OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` + + DefaultDenyCode = 200 + DefaultDenyMessage = "很抱歉,我无法回答您的问题" + DefaultTimeout = 2000 + + AliyunUserAgent = "CIPFrom/AIGateway" + LengthLimit = 1800 + + DefaultRequestCheckService = "llm_query_moderation" + DefaultResponseCheckService = "llm_response_moderation" + DefaultRequestJsonPath = "messages.@reverse.0.content" + DefaultResponseJsonPath = "choices.0.message.content" + DefaultStreamingResponseJsonPath = "choices.0.delta.content" + + // Actions + MultiModalGuard = "MultiModalGuard" + MultiModalGuardForBase64 = "MultiModalGuardForBase64" + TextModerationPlus = "TextModerationPlus" + + // Services + DefaultMultiModalGuardTextInputCheckService = "query_security_check" + DefaultMultiModalGuardTextOutputCheckService = "response_security_check" + DefaultMultiModalGuardImageInputCheckService = "img_query_security_check" + + DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation" + DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation" +) + +// api types + +const ( + ApiTextGeneration = "text_generation" + ApiImageGeneration = "image_generation" +) + +// provider types +const ( + ProviderOpenAI = "openai" + ProviderQwen = "qwen" + ProviderComfyUI = "comfyui" +) + +type Response struct { + Code int `json:"Code"` + Message string `json:"Message"` + RequestId string `json:"RequestId"` + Data Data `json:"Data"` +} + +type Data struct { + RiskLevel string `json:"RiskLevel,omitempty"` + AttackLevel string `json:"AttackLevel,omitempty"` + Result []Result `json:"Result,omitempty"` + Advice []Advice `json:"Advice,omitempty"` + Detail []Detail `json:"Detail,omitempty"` +} + +type Result struct { + RiskWords string `json:"RiskWords,omitempty"` + Description string `json:"Description,omitempty"` + Confidence float64 `json:"Confidence,omitempty"` + Label string `json:"Label,omitempty"` +} + +type Advice struct { + Answer string `json:"Answer,omitempty"` + HitLabel string `json:"HitLabel,omitempty"` + HitLibName string `json:"HitLibName,omitempty"` +} + +type Detail struct { + Suggestion string `json:"Suggestion,omitempty"` + Type string `json:"Type,omitempty"` + Level string `json:"Level,omitempty"` +} + +type Matcher struct { + Exact string + Prefix string + Re *regexp.Regexp +} + +func (m *Matcher) match(consumer string) bool { + if m.Exact != "" { + return consumer == m.Exact + } else if m.Prefix != "" { + return strings.HasPrefix(consumer, m.Prefix) + } else if m.Re != nil { + return m.Re.MatchString(consumer) + } else { + return false + } +} + +type AISecurityConfig struct { + Client wrapper.HttpClient + Host string + AK string + SK string + Token string + Action string + CheckRequest bool + RequestCheckService string + RequestImageCheckService string + RequestContentJsonPath string + CheckResponse bool + ResponseCheckService string + ResponseImageCheckService string + ResponseContentJsonPath string + ResponseStreamContentJsonPath string + DenyCode int64 + DenyMessage string + ProtocolOriginal bool + RiskLevelBar string + ContentModerationLevelBar string + PromptAttackLevelBar string + SensitiveDataLevelBar string + MaliciousUrlLevelBar string + ModelHallucinationLevelBar string + Timeout uint32 + BufferLimit int + Metrics map[string]proxywasm.MetricCounter + ConsumerRequestCheckService []map[string]interface{} + ConsumerResponseCheckService []map[string]interface{} + ConsumerRiskLevel []map[string]interface{} + // text_generation, image_generation, etc. + ApiType string + // openai, qwen, comfyui, etc. + ProviderType string +} + +func (config *AISecurityConfig) Parse(json gjson.Result) error { + serviceName := json.Get("serviceName").String() + servicePort := json.Get("servicePort").Int() + serviceHost := json.Get("serviceHost").String() + config.Host = serviceHost + if serviceName == "" || servicePort == 0 || serviceHost == "" { + return errors.New("invalid service config") + } + config.AK = json.Get("accessKey").String() + config.SK = json.Get("secretKey").String() + if config.AK == "" || config.SK == "" { + return errors.New("invalid AK/SK config") + } + config.Token = json.Get("securityToken").String() + // set action + if obj := json.Get("action"); obj.Exists() { + config.Action = json.Get("action").String() + } else { + config.Action = TextModerationPlus + } + // set default values + config.SetDefaultValues() + // set values + if obj := json.Get("riskLevelBar"); obj.Exists() { + config.RiskLevelBar = obj.String() + } + if obj := json.Get("requestCheckService"); obj.Exists() { + config.RequestCheckService = obj.String() + } + if obj := json.Get("requestImageCheckService"); obj.Exists() { + config.RequestImageCheckService = obj.String() + } + if obj := json.Get("responseCheckService"); obj.Exists() { + config.ResponseCheckService = obj.String() + } + if obj := json.Get("responseImageCheckService"); obj.Exists() { + config.ResponseImageCheckService = obj.String() + } + config.CheckRequest = json.Get("checkRequest").Bool() + config.CheckResponse = json.Get("checkResponse").Bool() + config.ProtocolOriginal = json.Get("protocol").String() == "original" + config.DenyMessage = json.Get("denyMessage").String() + if obj := json.Get("denyCode"); obj.Exists() { + config.DenyCode = obj.Int() + } + if obj := json.Get("requestContentJsonPath"); obj.Exists() { + config.RequestContentJsonPath = obj.String() + } + if obj := json.Get("responseContentJsonPath"); obj.Exists() { + config.ResponseContentJsonPath = obj.String() + } + if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() { + config.ResponseStreamContentJsonPath = obj.String() + } + if obj := json.Get("contentModerationLevelBar"); obj.Exists() { + config.ContentModerationLevelBar = obj.String() + if LevelToInt(config.ContentModerationLevelBar) <= 0 { + return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]") + } + } + if obj := json.Get("promptAttackLevelBar"); obj.Exists() { + config.PromptAttackLevelBar = obj.String() + if LevelToInt(config.PromptAttackLevelBar) <= 0 { + return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]") + } + } + if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() { + config.SensitiveDataLevelBar = obj.String() + if LevelToInt(config.SensitiveDataLevelBar) <= 0 { + return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]") + } + } + if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() { + config.ModelHallucinationLevelBar = obj.String() + if LevelToInt(config.ModelHallucinationLevelBar) <= 0 { + return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]") + } + } + if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() { + config.MaliciousUrlLevelBar = obj.String() + if LevelToInt(config.MaliciousUrlLevelBar) <= 0 { + return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]") + } + } + if obj := json.Get("timeout"); obj.Exists() { + config.Timeout = uint32(obj.Int()) + } + if obj := json.Get("bufferLimit"); obj.Exists() { + config.BufferLimit = int(obj.Int()) + } + if obj := json.Get("consumerRequestCheckService"); obj.Exists() { + for _, item := range json.Get("consumerRequestCheckService").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, m) + } + } + if obj := json.Get("consumerResponseCheckService"); obj.Exists() { + for _, item := range json.Get("consumerResponseCheckService").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, m) + } + } + if obj := json.Get("consumerRiskLevel"); obj.Exists() { + for _, item := range json.Get("consumerRiskLevel").Array() { + m := make(map[string]interface{}) + for k, v := range item.Map() { + m[k] = v.Value() + } + consumerName, ok1 := m["name"] + matchType, ok2 := m["matchType"] + if !ok1 || !ok2 { + continue + } + switch fmt.Sprint(matchType) { + case "exact": + m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} + case "prefix": + m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} + case "regexp": + m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} + } + config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m) + } + } + if obj := json.Get("apiType"); obj.Exists() { + config.ApiType = obj.String() + } + if obj := json.Get("providerType"); obj.Exists() { + config.ProviderType = obj.String() + } + config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: serviceName, + Port: servicePort, + Host: serviceHost, + }) + config.Metrics = make(map[string]proxywasm.MetricCounter) + return nil +} + +func (config *AISecurityConfig) SetDefaultValues() { + switch config.Action { + case TextModerationPlus: + config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService + config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService + case MultiModalGuard: + config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService + config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService + config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService + } + config.RiskLevelBar = HighRisk + config.DenyCode = DefaultDenyCode + config.RequestContentJsonPath = DefaultRequestJsonPath + config.ResponseContentJsonPath = DefaultResponseJsonPath + config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath + config.ContentModerationLevelBar = MaxRisk + config.PromptAttackLevelBar = MaxRisk + config.SensitiveDataLevelBar = S4Sensitive + config.ModelHallucinationLevelBar = MaxRisk + config.MaliciousUrlLevelBar = MaxRisk + config.Timeout = DefaultTimeout + config.BufferLimit = 1000 + config.ApiType = ApiTextGeneration + config.ProviderType = ProviderOpenAI +} + +func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) { + counter, ok := config.Metrics[metricName] + if !ok { + counter = proxywasm.DefineCounterMetric(metricName) + config.Metrics[metricName] = counter + } + counter.Increment(inc) +} + +func (config *AISecurityConfig) GetRequestCheckService(consumer string) string { + result := config.RequestCheckService + for _, obj := range config.ConsumerRequestCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if requestCheckService, ok := obj["requestCheckService"]; ok { + result, _ = requestCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string { + result := config.RequestImageCheckService + for _, obj := range config.ConsumerRequestCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if requestCheckService, ok := obj["requestImageCheckService"]; ok { + result, _ = requestCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetResponseCheckService(consumer string) string { + result := config.ResponseCheckService + for _, obj := range config.ConsumerResponseCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if responseCheckService, ok := obj["responseCheckService"]; ok { + result, _ = responseCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string { + result := config.ResponseImageCheckService + for _, obj := range config.ConsumerResponseCheckService { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if responseCheckService, ok := obj["responseImageCheckService"]; ok { + result, _ = responseCheckService.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string { + result := config.RiskLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if riskLevelBar, ok := obj["riskLevelBar"]; ok { + result, _ = riskLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string { + result := config.ContentModerationLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok { + result, _ = contentModerationLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string { + result := config.PromptAttackLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok { + result, _ = promptAttackLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string { + result := config.SensitiveDataLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok { + result, _ = sensitiveDataLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string { + result := config.MaliciousUrlLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok { + result, _ = maliciousUrlLevelBar.(string) + } + break + } + } + } + return result +} + +func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string { + result := config.ModelHallucinationLevelBar + for _, obj := range config.ConsumerRiskLevel { + if matcher, ok := obj["matcher"].(Matcher); ok { + if matcher.match(consumer) { + if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok { + result, _ = modelHallucinationLevelBar.(string) + } + break + } + } + } + return result +} + +func LevelToInt(riskLevel string) int { + // First check against our defined constants + switch strings.ToLower(riskLevel) { + case MaxRisk, S4Sensitive: + return 4 + case HighRisk, S3Sensitive: + return 3 + case MediumRisk, S2Sensitive: + return 2 + case LowRisk, S1Sensitive: + return 1 + case NoRisk, NoSensitive: + return 0 + default: + return -1 + } +} + +func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool { + if action == MultiModalGuard || action == MultiModalGuardForBase64 { + // Check top-level risk levels for MultiModalGuard + if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { + return false + } + // Also check AttackLevel for prompt attack detection + if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { + return false + } + + // Check detailed results for backward compatibility + for _, detail := range data.Detail { + switch detail.Type { + case ContentModerationType: + if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) { + return false + } + case PromptAttackType: + if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) { + return false + } + case SensitiveDataType: + if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) { + return false + } + case MaliciousUrlDataType: + if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) { + return false + } + case ModelHallucinationDataType: + if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) { + return false + } + } + } + return true + } else { + return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer)) + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod index d62eafb99b..9e5e068b7e 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.mod +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -20,5 +20,6 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/resp v0.1.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.sum b/plugins/wasm-go/extensions/ai-security-guard/go.sum index b055378c0c..67a9e45af3 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.sum +++ b/plugins/wasm-go/extensions/ai-security-guard/go.sum @@ -24,6 +24,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/request_builder.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/request_builder.go new file mode 100644 index 0000000000..9bf466dbdd --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/request_builder.go @@ -0,0 +1,249 @@ +package common + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "sort" + + "golang.org/x/exp/maps" + + "fmt" + "net/url" + "strings" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/google/uuid" +) + +const ( + ALGORITHM = "ACS3-HMAC-SHA256" +) + +type Request struct { + httpMethod string + canonicalUri string + host string + xAcsAction string + xAcsVersion string + headers map[string]string + body []byte + queryParam map[string]interface{} +} + +func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request { + req := &Request{ + httpMethod: httpMethod, + canonicalUri: canonicalUri, + host: host, + xAcsAction: xAcsAction, + xAcsVersion: xAcsVersion, + headers: make(map[string]string), + queryParam: make(map[string]interface{}), + } + req.headers["host"] = host + req.headers["x-acs-action"] = xAcsAction + req.headers["x-acs-version"] = xAcsVersion + req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339) + req.headers["x-acs-signature-nonce"] = uuid.New().String() + return req +} + +func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) { + newQueryParams := make(map[string]interface{}) + processObject(newQueryParams, "", req.queryParam) + req.queryParam = newQueryParams + canonicalQueryString := "" + keys := maps.Keys(req.queryParam) + sort.Strings(keys) + for _, k := range keys { + v := req.queryParam[k] + canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&" + } + canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&") + + var bodyContent []byte + if req.body == nil { + bodyContent = []byte("") + } else { + bodyContent = req.body + } + hashedRequestPayload := sha256Hex(bodyContent) + req.headers["x-acs-content-sha256"] = hashedRequestPayload + + if SecurityToken != "" { + req.headers["x-acs-security-token"] = SecurityToken + } + + canonicalHeaders := "" + signedHeaders := "" + HeadersKeys := maps.Keys(req.headers) + sort.Strings(HeadersKeys) + for _, k := range HeadersKeys { + lowerKey := strings.ToLower(k) + if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" { + canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n" + signedHeaders += lowerKey + ";" + } + } + signedHeaders = strings.TrimSuffix(signedHeaders, ";") + + canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload + + hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest)) + stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest + + byteData, err := hmac256([]byte(AccessKeySecret), stringToSign) + if err != nil { + fmt.Println(err) + panic(err) + } + signature := strings.ToLower(hex.EncodeToString(byteData)) + + authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature + req.headers["Authorization"] = authorization +} + +func hmac256(key []byte, toSignString string) ([]byte, error) { + h := hmac.New(sha256.New, key) + _, err := h.Write([]byte(toSignString)) + if err != nil { + return nil, err + } + return h.Sum(nil), nil +} + +func sha256Hex(byteArray []byte) string { + hash := sha256.New() + _, _ = hash.Write(byteArray) + hexString := hex.EncodeToString(hash.Sum(nil)) + return hexString +} + +func percentCode(str string) string { + str = strings.ReplaceAll(str, "+", "%20") + str = strings.ReplaceAll(str, "*", "%2A") + str = strings.ReplaceAll(str, "%7E", "~") + return str +} + +func formDataToString(formData map[string]interface{}) *string { + tmp := make(map[string]interface{}) + processObject(tmp, "", formData) + res := "" + urlEncoder := url.Values{} + for key, value := range tmp { + v := fmt.Sprintf("%v", value) + urlEncoder.Add(key, v) + } + res = urlEncoder.Encode() + return &res +} + +// processObject 递归处理对象,将复杂对象(如Map和List)展开为平面的键值对 +func processObject(mapResult map[string]interface{}, key string, value interface{}) { + if value == nil { + return + } + + switch v := value.(type) { + case []interface{}: + for i, item := range v { + processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item) + } + case map[string]interface{}: + for subKey, subValue := range v { + processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue) + } + default: + if strings.HasPrefix(key, ".") { + key = key[1:] + } + if b, ok := v.([]byte); ok { + mapResult[key] = string(b) + } else { + mapResult[key] = fmt.Sprintf("%v", v) + } + } +} + +func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) { + httpMethod := "POST" + canonicalUri := "/" + xAcsVersion := "2022-03-02" + req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion) + + req.queryParam["Service"] = checkService + + body := make(map[string]interface{}) + serviceParameters := make(map[string]interface{}) + serviceParameters["content"] = text + serviceParameters["sessionId"] = sessionID + serviceParameters["requestFrom"] = cfg.AliyunUserAgent + serviceParametersJSON, _ := json.Marshal(serviceParameters) + body["ServiceParameters"] = serviceParametersJSON + str := formDataToString(body) + req.body = []byte(*str) + req.headers["content-type"] = "application/x-www-form-urlencoded" + req.headers["User-Agent"] = cfg.AliyunUserAgent + + getAuthorization(req, config.AK, config.SK, config.Token) + + q := url.Values{} + keys := maps.Keys(req.queryParam) + sort.Strings(keys) + for _, k := range keys { + v := req.queryParam[k] + q.Set(k, fmt.Sprintf("%v", v)) + } + for k, v := range req.headers { + if k != "host" { + headers = append(headers, [2]string{k, v}) + } + } + return "?" + q.Encode(), headers, req.body +} + +func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) { + httpMethod := "POST" + canonicalUri := "/" + xAcsVersion := "2022-03-02" + req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion) + + req.queryParam["Service"] = checkService + + body := make(map[string]interface{}) + serviceParameters := make(map[string]interface{}) + if imgUrl != "" { + serviceParameters["imageUrls"] = []string{imgUrl} + } + serviceParametersJSON, _ := json.Marshal(serviceParameters) + serviceParameters["requestFrom"] = cfg.AliyunUserAgent + body["ServiceParameters"] = serviceParametersJSON + if imgBase64 != "" { + body["ImageBase64Str"] = imgBase64 + } + str := formDataToString(body) + req.body = []byte(*str) + req.headers["content-type"] = "application/x-www-form-urlencoded" + req.headers["User-Agent"] = cfg.AliyunUserAgent + + getAuthorization(req, config.AK, config.SK, config.Token) + + q := url.Values{} + keys := maps.Keys(req.queryParam) + sort.Strings(keys) + for _, k := range keys { + v := req.queryParam[k] + q.Set(k, fmt.Sprintf("%v", v)) + } + for k, v := range req.headers { + // host will be added by envoy automatically + if k != "host" { + headers = append(headers, [2]string{k, v}) + } + } + return "?" + q.Encode(), headers, req.body +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go new file mode 100644 index 0000000000..ac5f06fb41 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text/openai.go @@ -0,0 +1,228 @@ +package text + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + ctx.SetContext("end_of_stream_received", false) + ctx.SetContext("during_call", false) + ctx.SetContext("risk_detected", false) + sessionID, _ := utils.GenerateHexID(20) + ctx.SetContext("sessionID", sessionID) + if strings.Contains(contentType, "text/event-stream") { + ctx.NeedPauseStreamingResponse() + return types.ActionContinue + } else { + ctx.BufferResponseBody() + return types.HeaderStopIteration + } +} + +func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { + consumer, _ := ctx.GetContext("consumer").(string) + var sessionID string + if ctx.GetContext("sessionID") == nil { + sessionID, _ = utils.GenerateHexID(20) + ctx.SetContext("sessionID", sessionID) + } else { + sessionID, _ = ctx.GetContext("sessionID").(string) + } + var bufferQueue [][]byte + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + if ctx.GetContext("end_of_stream_received").(bool) { + proxywasm.ResumeHttpResponse() + } + ctx.SetContext("during_call", false) + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at response phase") + if ctx.GetContext("end_of_stream_received").(bool) { + proxywasm.ResumeHttpResponse() + } + ctx.SetContext("during_call", false) + return + } + if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + denyMessage := cfg.DefaultDenyMessage + if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = "\n" + response.Data.Advice[0].Answer + } else if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + proxywasm.InjectEncodedDataToFilterChain(jsonData, true) + return + } + endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0 + proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream) + bufferQueue = [][]byte{} + if !endStream { + ctx.SetContext("during_call", false) + singleCall() + } + } + singleCall = func() { + if ctx.GetContext("during_call").(bool) { + return + } + if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) { + var buffer string + for ctx.BufferQueueSize() > 0 { + front := ctx.PopBuffer() + bufferQueue = append(bufferQueue, front) + msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String() + buffer += msg + if len([]rune(buffer)) >= config.BufferLimit { + break + } + } + // if streaming body has reasoning_content, buffer maybe empty + log.Debugf("current content piece: %s", buffer) + if len(buffer) == 0 { + return + } + ctx.SetContext("during_call", true) + log.Debugf("current content piece: %s", buffer) + checkService := config.GetResponseCheckService(consumer) + path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + if ctx.GetContext("end_of_stream_received").(bool) { + proxywasm.ResumeHttpResponse() + } + } + } + } + if !ctx.GetContext("risk_detected").(bool) { + for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) { + ctx.PushBuffer([]byte(string(chunk) + "\n\n")) + } + ctx.SetContext("end_of_stream_received", endOfStream) + if !ctx.GetContext("during_call").(bool) { + singleCall() + } + } else if endOfStream { + proxywasm.ResumeHttpResponse() + } + return []byte{} +} + +func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + log.Debugf("checking response body...") + startTime := time.Now().UnixMilli() + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + isStreamingResponse := strings.Contains(contentType, "event-stream") + var content string + if isStreamingResponse { + content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath) + } else { + content = gjson.GetBytes(body, config.ResponseContentJsonPath).String() + } + log.Debugf("Raw response content is: %s", content) + if len(content) == 0 { + log.Info("response content is empty. skip") + return types.ActionContinue + } + contentIndex := 0 + sessionID, _ := utils.GenerateHexID(20) + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpResponse() + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at response phase") + proxywasm.ResumeHttpResponse() + return + } + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpResponse() + } else { + singleCall() + } + return + } + denyMessage := cfg.DefaultDenyMessage + if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + if config.ProtocolOriginal { + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if isStreamingResponse { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + config.IncrementCounter("ai_sec_response_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "response deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + singleCall = func() { + var nextContentIndex int + if contentIndex+cfg.LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + cfg.LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + checkService := config.GetResponseCheckService(consumer) + path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpResponse() + } + } + singleCall() + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go new file mode 100644 index 0000000000..7496e7e7d6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/handler.go @@ -0,0 +1,67 @@ +package multi_modal_guard + +import ( + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + return types.ActionContinue +} + +func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + return text.HandleTextGenerationRequestBody(ctx, config, body) +} + +func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationResponseHeader(ctx, config) + case cfg.ApiImageGeneration: + switch config.ProviderType { + case cfg.ProviderOpenAI, cfg.ProviderQwen: + return image.HandleImageGenerationResponseHeader(ctx, config) + default: + log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType) + return types.ActionContinue + } + default: + log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType) + return types.ActionContinue + } +} + +func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream) + default: + log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType) + return data + } +} + +func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationResponseBody(ctx, config, body) + case cfg.ApiImageGeneration: + switch config.ProviderType { + case cfg.ProviderOpenAI: + return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body) + case cfg.ProviderQwen: + return image.HandleQwenImageGenerationResponseBody(ctx, config, body) + default: + log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType) + return types.ActionContinue + } + default: + log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType) + return types.ActionContinue + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/common.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/common.go new file mode 100644 index 0000000000..fd43336317 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/common.go @@ -0,0 +1,22 @@ +package image + +import ( + "strings" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + contentType, _ := proxywasm.GetHttpResponseHeader("content-type") + ctx.SetContext("risk_detected", false) + if strings.Contains(contentType, "text/event-stream") { + ctx.DontReadResponseBody() + return types.ActionContinue + } else { + ctx.BufferResponseBody() + return types.HeaderStopIteration + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go new file mode 100644 index 0000000000..ad2a7d4d63 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/openai.go @@ -0,0 +1,111 @@ +package image + +import ( + "encoding/json" + "net/http" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type ImageItemForOpenAI struct { + Content string + Type string // URL or BASE64 +} + +func getOpenAIImageResults(body []byte) []ImageItemForOpenAI { + // qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126 + result := []ImageItemForOpenAI{} + for _, part := range gjson.GetBytes(body, "data").Array() { + if url := part.Get("url").String(); url != "" { + result = append(result, ImageItemForOpenAI{ + Content: url, + Type: "URL", + }) + } + if b64 := part.Get("b64_json").String(); b64 != "" { + result = append(result, ImageItemForOpenAI{ + Content: b64, + Type: "BASE64", + }) + } + } + return result +} + +func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + log.Debugf("checking response body...") + checkImageService := config.GetResponseImageCheckService(consumer) + startTime := time.Now().UnixMilli() + imgResults := getOpenAIImageResults(body) + if len(imgResults) == 0 { + return types.ActionContinue + } + imageIndex := 0 + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + imageIndex += 1 + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + if imageIndex < len(imgResults) { + singleCall() + } else { + proxywasm.ResumeHttpResponse() + } + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("%+v", err) + if imageIndex < len(imgResults) { + singleCall() + } else { + proxywasm.ResumeHttpResponse() + } + return + } + endTime := time.Now().UnixMilli() + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if imageIndex >= len(imgResults) { + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpResponse() + } else { + singleCall() + } + return + } + proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1) + config.IncrementCounter("ai_sec_request_deny", 1) + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + singleCall = func() { + img := imgResults[imageIndex] + imgUrl := "" + imgBase64 := "" + if img.Type == "BASE64" { + imgBase64 = img.Content + } else { + imgUrl = img.Content + } + path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpResponse() + } + } + singleCall() + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go new file mode 100644 index 0000000000..e2e5df07c5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image/qwen.go @@ -0,0 +1,134 @@ +package image + +import ( + "encoding/json" + "net/http" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func getQwenImageUrls(body []byte) []string { + // qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126 + result := []string{} + // 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘 + // 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API + for _, part := range gjson.GetBytes(body, "output.results").Array() { + if url := part.Get("url").String(); url != "" { + result = append(result, url) + } + } + // 图像编辑 + for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() { + if url := part.Get("image").String(); url != "" { + result = append(result, url) + } + } + // 图像翻译/AI试衣OutfitAnyone + if url := gjson.GetBytes(body, "output.image_url").String(); url != "" { + result = append(result, url) + } + // 图像画面扩展/(part of)人物实例分割/图像擦除补全 + if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" { + result = append(result, url) + } + // 鞋靴模特 + if url := gjson.GetBytes(body, "output.result_url").String(); url != "" { + result = append(result, url) + } + // 创意海报生成 + for _, part := range gjson.GetBytes(body, "output.render_urls").Array() { + if url := part.String(); url != "" { + result = append(result, url) + } + } + for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() { + if url := part.String(); url != "" { + result = append(result, url) + } + } + // 人物实例分割 + if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" { + result = append(result, url) + } + // 文字变形API + for _, part := range gjson.GetBytes(body, "output.results").Array() { + if url := part.Get("png_url").String(); url != "" { + result = append(result, url) + } + if url := part.Get("svg_url").String(); url != "" { + result = append(result, url) + } + } + return result +} + +func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + log.Debugf("checking response body...") + checkImageService := config.GetResponseImageCheckService(consumer) + startTime := time.Now().UnixMilli() + imgUrls := getQwenImageUrls(body) + if len(imgUrls) == 0 { + return types.ActionContinue + } + imageIndex := 0 + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + imageIndex += 1 + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + if imageIndex < len(imgUrls) { + singleCall() + } else { + proxywasm.ResumeHttpResponse() + } + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("%+v", err) + if imageIndex < len(imgUrls) { + singleCall() + } else { + proxywasm.ResumeHttpResponse() + } + return + } + endTime := time.Now().UnixMilli() + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if imageIndex >= len(imgUrls) { + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpResponse() + } else { + singleCall() + } + return + } + proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1) + config.IncrementCounter("ai_sec_request_deny", 1) + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + singleCall = func() { + imgUrl := imgUrls[imageIndex] + path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "") + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpResponse() + } + } + singleCall() + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go new file mode 100644 index 0000000000..2f2a388741 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text/openai.go @@ -0,0 +1,191 @@ +package text + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func parseContent(json gjson.Result) (text, imgUrl, imgBase64 string) { + if json.IsArray() { + for _, item := range json.Array() { + switch item.Get("type").String() { + case "text": + text += item.Get("text").String() + case "image_url": + imgContent := item.Get("image_url.url").String() + if strings.HasPrefix(imgContent, "data:image") { + imgBase64 = imgContent + } else { + imgUrl = imgContent + } + } + } + } else { + text = json.String() + } + return text, imgUrl, imgBase64 +} + +func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + checkService := config.GetRequestCheckService(consumer) + checkImageService := config.GetRequestImageCheckService(consumer) + startTime := time.Now().UnixMilli() + // content := gjson.GetBytes(body, config.RequestContentJsonPath).String() + content, imgUrl, imgBase64 := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath)) + log.Debugf("Raw request content is: %s", content) + if len(content) == 0 { + log.Info("request content is empty. skip") + return types.ActionContinue + } + contentIndex := 0 + sessionID, _ := utils.GenerateHexID(20) + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpRequest() + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("%+v", err) + proxywasm.ResumeHttpRequest() + return + } + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpRequest() + } else { + singleCall() + } + return + } + denyMessage := cfg.DefaultDenyMessage + if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + if config.ProtocolOriginal { + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + ctx.DontReadResponseBody() + config.IncrementCounter("ai_sec_request_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + singleCall = func() { + var nextContentIndex int + if contentIndex+cfg.LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + cfg.LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + log.Debugf("current content piece: %s", contentPiece) + path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpRequest() + } + } + // check image + if imgUrl != "" || imgBase64 != "" { + path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64) + err := config.Client.Post(path, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + // start checking text + singleCall() + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("%+v", err) + // start checking text + singleCall() + return + } + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + // start checking text + singleCall() + return + } + denyMessage := cfg.DefaultDenyMessage + if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + if config.ProtocolOriginal { + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + ctx.DontReadResponseBody() + config.IncrementCounter("ai_sec_request_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + }, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpRequest() + } + } else { + singleCall() + } + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/handler.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/handler.go new file mode 100644 index 0000000000..987a41a45b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/handler.go @@ -0,0 +1,48 @@ +package text_moderation_plus + +import ( + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" +) + +func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + return types.ActionContinue +} + +func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + return text.HandleTextGenerationRequestBody(ctx, config, body) +} + +func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationResponseHeader(ctx, config) + default: + log.Errorf("text_moderation_plus don't support api: %s", config.ApiType) + return types.ActionContinue + } +} + +func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream) + default: + log.Errorf("text_moderation_plus don't support api: %s", config.ApiType) + return data + } +} + +func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + switch config.ApiType { + case cfg.ApiTextGeneration: + return common_text.HandleTextGenerationResponseBody(ctx, config, body) + default: + log.Errorf("text_moderation_plus don't support api: %s", config.ApiType) + return types.ActionContinue + } +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go new file mode 100644 index 0000000000..316578c031 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text/openai.go @@ -0,0 +1,104 @@ +package text + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/higress-group/wasm-go/pkg/log" + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { + consumer, _ := ctx.GetContext("consumer").(string) + startTime := time.Now().UnixMilli() + content := gjson.GetBytes(body, config.RequestContentJsonPath).String() + log.Debugf("Raw request content is: %s", content) + if len(content) == 0 { + log.Info("request content is empty. skip") + return types.ActionContinue + } + contentIndex := 0 + sessionID, _ := utils.GenerateHexID(20) + var singleCall func() + callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Info(string(responseBody)) + if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { + proxywasm.ResumeHttpRequest() + return + } + var response cfg.Response + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Error("failed to unmarshal aliyun content security response at request phase") + proxywasm.ResumeHttpRequest() + return + } + if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) { + if contentIndex >= len(content) { + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "request pass") + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + proxywasm.ResumeHttpRequest() + } else { + singleCall() + } + return + } + denyMessage := cfg.DefaultDenyMessage + if config.DenyMessage != "" { + denyMessage = config.DenyMessage + } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { + denyMessage = response.Data.Advice[0].Answer + } + marshalledDenyMessage := wrapper.MarshalStr(denyMessage) + if config.ProtocolOriginal { + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) + } else if gjson.GetBytes(body, "stream").Bool() { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + randomID := utils.GenerateRandomChatID() + jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage)) + proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } + ctx.DontReadResponseBody() + config.IncrementCounter("ai_sec_request_deny", 1) + endTime := time.Now().UnixMilli() + ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) + ctx.SetUserAttribute("safecheck_status", "reqeust deny") + if response.Data.Advice != nil { + ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) + ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) + } + ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) + } + singleCall = func() { + var nextContentIndex int + if contentIndex+cfg.LengthLimit >= len(content) { + nextContentIndex = len(content) + } else { + nextContentIndex = contentIndex + cfg.LengthLimit + } + contentPiece := content[contentIndex:nextContentIndex] + contentIndex = nextContentIndex + checkService := config.GetRequestCheckService(consumer) + path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID) + err := config.Client.Post(path, headers, body, callback, config.Timeout) + if err != nil { + log.Errorf("failed call the safe check service: %v", err) + proxywasm.ResumeHttpRequest() + } + } + singleCall() + return types.ActionPause +} diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index 1392db9118..1e2dd1864f 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -1,23 +1,9 @@ package main import ( - "bytes" - "crypto/hmac" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - mrand "math/rand" - "net/http" - "net/url" - "regexp" - "sort" - "strings" - "time" - + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" @@ -39,704 +25,36 @@ func init() { ) } -const ( - MaxRisk = "max" - HighRisk = "high" - MediumRisk = "medium" - LowRisk = "low" - NoRisk = "none" - - S4Sensitive = "S4" - S3Sensitive = "S3" - S2Sensitive = "S2" - S1Sensitive = "S1" - NoSensitive = "S0" - - ContentModerationType = "contentModeration" - PromptAttackType = "promptAttack" - SensitiveDataType = "sensitiveData" - MaliciousUrlDataType = "maliciousUrl" - ModelHallucinationDataType = "modelHallucination" - - OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` - OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` - - DefaultRequestCheckService = "llm_query_moderation" - DefaultResponseCheckService = "llm_response_moderation" - DefaultRequestJsonPath = "messages.@reverse.0.content" - DefaultResponseJsonPath = "choices.0.message.content" - DefaultStreamingResponseJsonPath = "choices.0.delta.content" - DefaultDenyCode = 200 - DefaultDenyMessage = "很抱歉,我无法回答您的问题" - DefaultTimeout = 2000 - - AliyunUserAgent = "CIPFrom/AIGateway" - LengthLimit = 1800 -) - -type Response struct { - Code int `json:"Code"` - Message string `json:"Message"` - RequestId string `json:"RequestId"` - Data Data `json:"Data"` -} - -type Data struct { - RiskLevel string `json:"RiskLevel"` - AttackLevel string `json:"AttackLevel,omitempty"` - Result []Result `json:"Result,omitempty"` - Advice []Advice `json:"Advice,omitempty"` - Detail []Detail `json:"Detail,omitempty"` -} - -type Result struct { - RiskWords string `json:"RiskWords,omitempty"` - Description string `json:"Description,omitempty"` - Confidence float64 `json:"Confidence,omitempty"` - Label string `json:"Label,omitempty"` -} - -type Advice struct { - Answer string `json:"Answer,omitempty"` - HitLabel string `json:"HitLabel,omitempty"` - HitLibName string `json:"HitLibName,omitempty"` -} - -type Detail struct { - Suggestion string `json:"Suggestion,omitempty"` - Type string `json:"Type,omitempty"` - Level string `json:"Level,omitempty"` -} - -type AISecurityConfig struct { - client wrapper.HttpClient - ak string - sk string - token string - action string - checkRequest bool - requestCheckService string - requestContentJsonPath string - checkResponse bool - responseCheckService string - responseContentJsonPath string - responseStreamContentJsonPath string - denyCode int64 - denyMessage string - protocolOriginal bool - riskLevelBar string - contentModerationLevelBar string - promptAttackLevelBar string - sensitiveDataLevelBar string - maliciousUrlLevelBar string - modelHallucinationLevelBar string - timeout uint32 - bufferLimit int - metrics map[string]proxywasm.MetricCounter - consumerRequestCheckService []map[string]interface{} - consumerResponseCheckService []map[string]interface{} - consumerRiskLevel []map[string]interface{} -} - -type Matcher struct { - Exact string - Prefix string - Re *regexp.Regexp -} - -func (m *Matcher) match(consumer string) bool { - if m.Exact != "" { - return consumer == m.Exact - } else if m.Prefix != "" { - return strings.HasPrefix(consumer, m.Prefix) - } else if m.Re != nil { - return m.Re.MatchString(consumer) - } else { - return false - } -} - -func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) { - counter, ok := config.metrics[metricName] - if !ok { - counter = proxywasm.DefineCounterMetric(metricName) - config.metrics[metricName] = counter - } - counter.Increment(inc) -} - -func (config *AISecurityConfig) getRequestCheckService(consumer string) string { - result := config.requestCheckService - for _, obj := range config.consumerRequestCheckService { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if requestCheckService, ok := obj["requestCheckService"]; ok { - result, _ = requestCheckService.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getResponseCheckService(consumer string) string { - result := config.responseCheckService - for _, obj := range config.consumerResponseCheckService { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if responseCheckService, ok := obj["responseCheckService"]; ok { - result, _ = responseCheckService.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getRiskLevelBar(consumer string) string { - result := config.riskLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if riskLevelBar, ok := obj["riskLevelBar"]; ok { - result, _ = riskLevelBar.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getContentModerationLevelBar(consumer string) string { - result := config.contentModerationLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok { - result, _ = contentModerationLevelBar.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getPromptAttackLevelBar(consumer string) string { - result := config.promptAttackLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok { - result, _ = promptAttackLevelBar.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getSensitiveDataLevelBar(consumer string) string { - result := config.sensitiveDataLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok { - result, _ = sensitiveDataLevelBar.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getMaliciousUrlLevelBar(consumer string) string { - result := config.maliciousUrlLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok { - result, _ = maliciousUrlLevelBar.(string) - } - break - } - } - } - return result -} - -func (config *AISecurityConfig) getModelHallucinationLevelBar(consumer string) string { - result := config.modelHallucinationLevelBar - for _, obj := range config.consumerRiskLevel { - if matcher, ok := obj["matcher"].(Matcher); ok { - if matcher.match(consumer) { - if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok { - result, _ = modelHallucinationLevelBar.(string) - } - break - } - } - } - return result -} - -func levelToInt(riskLevel string) int { - // First check against our defined constants - switch riskLevel { - case MaxRisk: - return 4 - case HighRisk: - return 3 - case MediumRisk: - return 2 - case LowRisk: - return 1 - case NoRisk: - return 0 - case S4Sensitive: - return 4 - case S3Sensitive: - return 3 - case S2Sensitive: - return 2 - case S1Sensitive: - return 1 - case NoSensitive: - return 0 - } - - // Then check against raw string values - switch riskLevel { - case "max", "MAX": - return 4 - case "high", "HIGH": - return 3 - case "medium", "MEDIUM": - return 2 - case "low", "LOW": - return 1 - case "none", "NONE": - return 0 - case "S4", "s4": - return 4 - case "S3", "s3": - return 3 - case "S2", "s2": - return 2 - case "S1", "s1": - return 1 - case "S0", "s0": - return 0 - default: - return -1 - } -} - -func isRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool { - if action == "MultiModalGuard" { - // Check top-level risk levels for MultiModalGuard - if levelToInt(data.RiskLevel) >= levelToInt(config.getContentModerationLevelBar(consumer)) { - return false - } - // Also check AttackLevel for prompt attack detection - if levelToInt(data.AttackLevel) >= levelToInt(config.getPromptAttackLevelBar(consumer)) { - return false - } - - // Check detailed results for backward compatibility - for _, detail := range data.Detail { - switch detail.Type { - case ContentModerationType: - if levelToInt(detail.Level) >= levelToInt(config.getContentModerationLevelBar(consumer)) { - return false - } - case PromptAttackType: - if levelToInt(detail.Level) >= levelToInt(config.getPromptAttackLevelBar(consumer)) { - return false - } - case SensitiveDataType: - if levelToInt(detail.Level) >= levelToInt(config.getSensitiveDataLevelBar(consumer)) { - return false - } - case MaliciousUrlDataType: - if levelToInt(detail.Level) >= levelToInt(config.getMaliciousUrlLevelBar(consumer)) { - return false - } - case ModelHallucinationDataType: - if levelToInt(detail.Level) >= levelToInt(config.getModelHallucinationLevelBar(consumer)) { - return false - } - } - } - return true - } else { - return levelToInt(data.RiskLevel) < levelToInt(config.getRiskLevelBar(consumer)) - } -} - -func urlEncoding(rawStr string) string { - encodedStr := url.PathEscape(rawStr) - encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B") - encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A") - encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D") - encodedStr = strings.ReplaceAll(encodedStr, "&", "%26") - encodedStr = strings.ReplaceAll(encodedStr, "$", "%24") - encodedStr = strings.ReplaceAll(encodedStr, "@", "%40") - return encodedStr -} - -func hmacSha1(message, secret string) string { - key := []byte(secret) - h := hmac.New(sha1.New, key) - h.Write([]byte(message)) - hash := h.Sum(nil) - return base64.StdEncoding.EncodeToString(hash) -} - -func getSign(params map[string]string, secret string) string { - paramArray := []string{} - for k, v := range params { - paramArray = append(paramArray, urlEncoding(k)+"="+urlEncoding(v)) - } - sort.Slice(paramArray, func(i, j int) bool { - return paramArray[i] <= paramArray[j] - }) - canonicalStr := strings.Join(paramArray, "&") - signStr := "POST&%2F&" + urlEncoding(canonicalStr) - proxywasm.LogDebugf("String to sign is: %s", signStr) - return hmacSha1(signStr, secret) -} - -func generateHexID(length int) (string, error) { - bytes := make([]byte, length/2) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return hex.EncodeToString(bytes), nil +func parseConfig(json gjson.Result, config *cfg.AISecurityConfig) error { + return config.Parse(json) } -func parseConfig(json gjson.Result, config *AISecurityConfig) error { - serviceName := json.Get("serviceName").String() - servicePort := json.Get("servicePort").Int() - serviceHost := json.Get("serviceHost").String() - if serviceName == "" || servicePort == 0 || serviceHost == "" { - return errors.New("invalid service config") - } - config.ak = json.Get("accessKey").String() - config.sk = json.Get("secretKey").String() - if config.ak == "" || config.sk == "" { - return errors.New("invalid AK/SK config") - } - if obj := json.Get("riskLevelBar"); obj.Exists() { - config.riskLevelBar = obj.String() - } else { - config.riskLevelBar = HighRisk - } - config.token = json.Get("securityToken").String() - if obj := json.Get("action"); obj.Exists() { - config.action = json.Get("action").String() - } else { - config.action = "TextModerationPlus" - } - config.checkRequest = json.Get("checkRequest").Bool() - config.checkResponse = json.Get("checkResponse").Bool() - config.protocolOriginal = json.Get("protocol").String() == "original" - config.denyMessage = json.Get("denyMessage").String() - if obj := json.Get("denyCode"); obj.Exists() { - config.denyCode = obj.Int() - } else { - config.denyCode = DefaultDenyCode - } - if obj := json.Get("requestCheckService"); obj.Exists() { - config.requestCheckService = obj.String() - } else { - config.requestCheckService = DefaultRequestCheckService - } - if obj := json.Get("responseCheckService"); obj.Exists() { - config.responseCheckService = obj.String() - } else { - config.responseCheckService = DefaultResponseCheckService - } - if obj := json.Get("requestContentJsonPath"); obj.Exists() { - config.requestContentJsonPath = obj.String() - } else { - config.requestContentJsonPath = DefaultRequestJsonPath - } - if obj := json.Get("responseContentJsonPath"); obj.Exists() { - config.responseContentJsonPath = obj.String() - } else { - config.responseContentJsonPath = DefaultResponseJsonPath - } - if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() { - config.responseStreamContentJsonPath = obj.String() - } else { - config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath - } - if obj := json.Get("contentModerationLevelBar"); obj.Exists() { - config.contentModerationLevelBar = obj.String() - if levelToInt(config.contentModerationLevelBar) <= 0 { - return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]") - } - } else { - config.contentModerationLevelBar = MaxRisk - } - if obj := json.Get("promptAttackLevelBar"); obj.Exists() { - config.promptAttackLevelBar = obj.String() - if levelToInt(config.promptAttackLevelBar) <= 0 { - return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]") - } - } else { - config.promptAttackLevelBar = MaxRisk - } - if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() { - config.sensitiveDataLevelBar = obj.String() - if levelToInt(config.sensitiveDataLevelBar) <= 0 { - return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]") - } - } else { - config.sensitiveDataLevelBar = S4Sensitive - } - if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() { - config.modelHallucinationLevelBar = obj.String() - if levelToInt(config.modelHallucinationLevelBar) <= 0 { - return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]") - } - } else { - config.modelHallucinationLevelBar = MaxRisk - } - if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() { - config.maliciousUrlLevelBar = obj.String() - if levelToInt(config.maliciousUrlLevelBar) <= 0 { - return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]") - } - } else { - config.maliciousUrlLevelBar = MaxRisk - } - if obj := json.Get("timeout"); obj.Exists() { - config.timeout = uint32(obj.Int()) - } else { - config.timeout = DefaultTimeout - } - if obj := json.Get("bufferLimit"); obj.Exists() { - config.bufferLimit = int(obj.Int()) - } else { - config.bufferLimit = 1000 - } - if obj := json.Get("consumerRequestCheckService"); obj.Exists() { - for _, item := range json.Get("consumerRequestCheckService").Array() { - m := make(map[string]interface{}) - for k, v := range item.Map() { - m[k] = v.Value() - } - consumerName, ok1 := m["name"] - matchType, ok2 := m["matchType"] - if !ok1 || !ok2 { - continue - } - switch fmt.Sprint(matchType) { - case "exact": - m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} - case "prefix": - m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} - case "regexp": - m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} - } - config.consumerRequestCheckService = append(config.consumerRequestCheckService, m) - } - } - if obj := json.Get("consumerResponseCheckService"); obj.Exists() { - for _, item := range json.Get("consumerResponseCheckService").Array() { - m := make(map[string]interface{}) - for k, v := range item.Map() { - m[k] = v.Value() - } - consumerName, ok1 := m["name"] - matchType, ok2 := m["matchType"] - if !ok1 || !ok2 { - continue - } - switch fmt.Sprint(matchType) { - case "exact": - m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} - case "prefix": - m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} - case "regexp": - m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} - } - config.consumerResponseCheckService = append(config.consumerResponseCheckService, m) - } - } - if obj := json.Get("consumerRiskLevel"); obj.Exists() { - for _, item := range json.Get("consumerRiskLevel").Array() { - m := make(map[string]interface{}) - for k, v := range item.Map() { - m[k] = v.Value() - } - consumerName, ok1 := m["name"] - matchType, ok2 := m["matchType"] - if !ok1 || !ok2 { - continue - } - switch fmt.Sprint(matchType) { - case "exact": - m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)} - case "prefix": - m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)} - case "regexp": - m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))} - } - config.consumerRiskLevel = append(config.consumerRiskLevel, m) - } - } - config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ - FQDN: serviceName, - Port: servicePort, - Host: serviceHost, - }) - config.metrics = make(map[string]proxywasm.MetricCounter) - return nil -} - -func generateRandomID() string { - const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, 29) - for i := range b { - b[i] = charset[mrand.Intn(len(charset))] - } - return "chatcmpl-" + string(b) -} - -func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { consumer, _ := proxywasm.GetHttpRequestHeader("x-mse-consumer") - ctx.SetContext("consumer", consumer) + ctx.SetContext("x-mse-consumer", consumer) ctx.DisableReroute() - if !config.checkRequest { + if !config.CheckRequest { log.Debugf("request checking is disabled") ctx.DontReadRequestBody() } return types.ActionContinue } -func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action { - consumer, _ := ctx.GetContext("consumer").(string) +func onHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { log.Debugf("checking request body...") - startTime := time.Now().UnixMilli() - content := gjson.GetBytes(body, config.requestContentJsonPath).String() - log.Debugf("Raw request content is: %s", content) - if len(content) == 0 { - log.Info("request content is empty. skip") + switch config.Action { + case cfg.MultiModalGuard: + return multi_modal_guard.OnHttpRequestBody(ctx, config, body) + case cfg.TextModerationPlus: + return text_moderation_plus.OnHttpRequestBody(ctx, config, body) + default: + log.Warnf("Unknown action %s", config.Action) return types.ActionContinue } - contentIndex := 0 - sessionID, _ := generateHexID(20) - var singleCall func() - callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Info(string(responseBody)) - if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { - proxywasm.ResumeHttpRequest() - return - } - var response Response - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Error("failed to unmarshal aliyun content security response at request phase") - proxywasm.ResumeHttpRequest() - return - } - if isRiskLevelAcceptable(config.action, response.Data, config, consumer) { - if contentIndex >= len(content) { - endTime := time.Now().UnixMilli() - ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "request pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) - proxywasm.ResumeHttpRequest() - } else { - singleCall() - } - return - } - denyMessage := DefaultDenyMessage - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { - denyMessage = response.Data.Advice[0].Answer - } - marshalledDenyMessage := wrapper.MarshalStr(denyMessage) - if config.protocolOriginal { - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) - } else if gjson.GetBytes(body, "stream").Bool() { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } - ctx.DontReadResponseBody() - config.incrementCounter("ai_sec_request_deny", 1) - endTime := time.Now().UnixMilli() - ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "reqeust deny") - if response.Data.Advice != nil { - ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) - ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) - } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) - } - singleCall = func() { - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - var nextContentIndex int - if contentIndex+LengthLimit >= len(content) { - nextContentIndex = len(content) - } else { - nextContentIndex = contentIndex + LengthLimit - } - contentPiece := content[contentIndex:nextContentIndex] - contentIndex = nextContentIndex - log.Debugf("current content piece: %s", contentPiece) - checkService := config.getRequestCheckService(consumer) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": config.action, - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent), - } - if config.token != "" { - params["SecurityToken"] = config.token - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - proxywasm.ResumeHttpRequest() - } - } - singleCall() - return types.ActionPause } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) types.Action { - if !config.checkResponse { +func onHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action { + if !config.CheckResponse { log.Debugf("response checking is disabled") ctx.DontReadResponseBody() return types.ActionContinue @@ -747,257 +65,39 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig) typ ctx.DontReadResponseBody() return types.ActionContinue } - contentType, _ := proxywasm.GetHttpResponseHeader("content-type") - ctx.SetContext("end_of_stream_received", false) - ctx.SetContext("during_call", false) - ctx.SetContext("risk_detected", false) - sessionID, _ := generateHexID(20) - ctx.SetContext("sessionID", sessionID) - if strings.Contains(contentType, "text/event-stream") { - ctx.NeedPauseStreamingResponse() + switch config.Action { + case cfg.MultiModalGuard: + return multi_modal_guard.OnHttpResponseHeaders(ctx, config) + case cfg.TextModerationPlus: + return text_moderation_plus.OnHttpResponseHeaders(ctx, config) + default: + log.Warnf("Unknown action %s", config.Action) return types.ActionContinue - } else { - ctx.BufferResponseBody() - return types.HeaderStopIteration } } -func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, data []byte, endOfStream bool) []byte { - consumer, _ := ctx.GetContext("consumer").(string) - var bufferQueue [][]byte - var singleCall func() - callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Info(string(responseBody)) - if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { - if ctx.GetContext("end_of_stream_received").(bool) { - proxywasm.ResumeHttpResponse() - } - ctx.SetContext("during_call", false) - return - } - var response Response - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Error("failed to unmarshal aliyun content security response at response phase") - if ctx.GetContext("end_of_stream_received").(bool) { - proxywasm.ResumeHttpResponse() - } - ctx.SetContext("during_call", false) - return - } - if !isRiskLevelAcceptable(config.action, response.Data, config, consumer) { - denyMessage := DefaultDenyMessage - if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { - denyMessage = "\n" + response.Data.Advice[0].Answer - } else if config.denyMessage != "" { - denyMessage = config.denyMessage - } - marshalledDenyMessage := wrapper.MarshalStr(denyMessage) - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.InjectEncodedDataToFilterChain(jsonData, true) - return - } - endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0 - proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream) - bufferQueue = [][]byte{} - if !endStream { - ctx.SetContext("during_call", false) - singleCall() - } - } - singleCall = func() { - if ctx.GetContext("during_call").(bool) { - return - } - if ctx.BufferQueueSize() >= config.bufferLimit || ctx.GetContext("end_of_stream_received").(bool) { - var buffer string - for ctx.BufferQueueSize() > 0 { - front := ctx.PopBuffer() - bufferQueue = append(bufferQueue, front) - msg := gjson.GetBytes(front, config.responseStreamContentJsonPath).String() - buffer += msg - if len([]rune(buffer)) >= config.bufferLimit { - break - } - } - // if streaming body has reasoning_content, buffer maybe empty - log.Debugf("current content piece: %s", buffer) - if len(buffer) == 0 { - return - } - ctx.SetContext("during_call", true) - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - log.Debugf("current content piece: %s", buffer) - checkService := config.getResponseCheckService(consumer) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": config.action, - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, ctx.GetContext("sessionID").(string), wrapper.MarshalStr(buffer), AliyunUserAgent), - } - if config.token != "" { - params["SecurityToken"] = config.token - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - if ctx.GetContext("end_of_stream_received").(bool) { - proxywasm.ResumeHttpResponse() - } - } - } - } - if !ctx.GetContext("risk_detected").(bool) { - for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) { - ctx.PushBuffer([]byte(string(chunk) + "\n\n")) - } - ctx.SetContext("end_of_stream_received", endOfStream) - if !ctx.GetContext("during_call").(bool) { - singleCall() - } - } else if endOfStream { - proxywasm.ResumeHttpResponse() +func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte { + log.Debugf("checking streaming response body...") + switch config.Action { + case cfg.MultiModalGuard: + return multi_modal_guard.OnHttpStreamingResponseBody(ctx, config, data, endOfStream) + case cfg.TextModerationPlus: + return text_moderation_plus.OnHttpStreamingResponseBody(ctx, config, data, endOfStream) + default: + log.Warnf("Unknown action %s", config.Action) + return data } - return []byte{} } -func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte) types.Action { - consumer, _ := ctx.GetContext("consumer").(string) +func onHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action { log.Debugf("checking response body...") - startTime := time.Now().UnixMilli() - contentType, _ := proxywasm.GetHttpResponseHeader("content-type") - isStreamingResponse := strings.Contains(contentType, "event-stream") - var content string - if isStreamingResponse { - content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) - } else { - content = gjson.GetBytes(body, config.responseContentJsonPath).String() - } - log.Debugf("Raw response content is: %s", content) - if len(content) == 0 { - log.Info("response content is empty. skip") + switch config.Action { + case cfg.MultiModalGuard: + return multi_modal_guard.OnHttpResponseBody(ctx, config, body) + case cfg.TextModerationPlus: + return text_moderation_plus.OnHttpResponseBody(ctx, config, body) + default: + log.Warnf("Unknown action %s", config.Action) return types.ActionContinue } - contentIndex := 0 - sessionID, _ := generateHexID(20) - var singleCall func() - callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Info(string(responseBody)) - if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 { - proxywasm.ResumeHttpResponse() - return - } - var response Response - err := json.Unmarshal(responseBody, &response) - if err != nil { - log.Error("failed to unmarshal aliyun content security response at response phase") - proxywasm.ResumeHttpResponse() - return - } - if isRiskLevelAcceptable(config.action, response.Data, config, consumer) { - if contentIndex >= len(content) { - endTime := time.Now().UnixMilli() - ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "response pass") - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) - proxywasm.ResumeHttpResponse() - } else { - singleCall() - } - return - } - denyMessage := DefaultDenyMessage - if config.denyMessage != "" { - denyMessage = config.denyMessage - } else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" { - denyMessage = response.Data.Advice[0].Answer - } - marshalledDenyMessage := wrapper.MarshalStr(denyMessage) - if config.protocolOriginal { - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) - } else if isStreamingResponse { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) - } else { - randomID := generateRandomID() - jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage)) - proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } - config.incrementCounter("ai_sec_response_deny", 1) - endTime := time.Now().UnixMilli() - ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime) - ctx.SetUserAttribute("safecheck_status", "response deny") - if response.Data.Advice != nil { - ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label) - ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords) - } - ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey) - } - singleCall = func() { - timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") - randomID, _ := generateHexID(16) - var nextContentIndex int - if contentIndex+LengthLimit >= len(content) { - nextContentIndex = len(content) - } else { - nextContentIndex = contentIndex + LengthLimit - } - contentPiece := content[contentIndex:nextContentIndex] - contentIndex = nextContentIndex - log.Debugf("current content piece: %s", contentPiece) - checkService := config.getResponseCheckService(consumer) - params := map[string]string{ - "Format": "JSON", - "Version": "2022-03-02", - "SignatureMethod": "Hmac-SHA1", - "SignatureNonce": randomID, - "SignatureVersion": "1.0", - "Action": config.action, - "AccessKeyId": config.ak, - "Timestamp": timestamp, - "Service": checkService, - "ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s","requestFrom": "%s"}`, sessionID, wrapper.MarshalStr(contentPiece), AliyunUserAgent), - } - if config.token != "" { - params["SecurityToken"] = config.token - } - signature := getSign(params, config.sk+"&") - reqParams := url.Values{} - for k, v := range params { - reqParams.Add(k, v) - } - reqParams.Add("Signature", signature) - err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout) - if err != nil { - log.Errorf("failed call the safe check service: %v", err) - proxywasm.ResumeHttpResponse() - } - } - singleCall() - return types.ActionPause -} - -func extractMessageFromStreamingBody(data []byte, jsonPath string) string { - chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) - strChunks := []string{} - for _, chunk := range chunks { - // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] - strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String()) - } - return strings.Join(strChunks, "") } diff --git a/plugins/wasm-go/extensions/ai-security-guard/main_test.go b/plugins/wasm-go/extensions/ai-security-guard/main_test.go index 18cdece649..351d2e377b 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main_test.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main_test.go @@ -18,6 +18,8 @@ import ( "encoding/json" "testing" + cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/test" "github.com/stretchr/testify/require" @@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, config) - securityConfig := config.(*AISecurityConfig) - require.Equal(t, "test-ak", securityConfig.ak) - require.Equal(t, "test-sk", securityConfig.sk) - require.Equal(t, true, securityConfig.checkRequest) - require.Equal(t, true, securityConfig.checkResponse) - require.Equal(t, "high", securityConfig.contentModerationLevelBar) - require.Equal(t, "high", securityConfig.promptAttackLevelBar) - require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar) - require.Equal(t, uint32(2000), securityConfig.timeout) - require.Equal(t, 1000, securityConfig.bufferLimit) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "test-ak", securityConfig.AK) + require.Equal(t, "test-sk", securityConfig.SK) + require.Equal(t, true, securityConfig.CheckRequest) + require.Equal(t, true, securityConfig.CheckResponse) + require.Equal(t, "high", securityConfig.ContentModerationLevelBar) + require.Equal(t, "high", securityConfig.PromptAttackLevelBar) + require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) + require.Equal(t, uint32(2000), securityConfig.Timeout) + require.Equal(t, 1000, securityConfig.BufferLimit) }) // 测试仅检查请求的配置 @@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, config) - securityConfig := config.(*AISecurityConfig) - require.Equal(t, true, securityConfig.checkRequest) - require.Equal(t, false, securityConfig.checkResponse) - require.Equal(t, "high", securityConfig.contentModerationLevelBar) - require.Equal(t, "high", securityConfig.promptAttackLevelBar) - require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, true, securityConfig.CheckRequest) + require.Equal(t, false, securityConfig.CheckResponse) + require.Equal(t, "high", securityConfig.ContentModerationLevelBar) + require.Equal(t, "high", securityConfig.PromptAttackLevelBar) + require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar) }) // 测试缺少必需字段的配置 @@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) { require.NoError(t, err) require.NotNil(t, config) - securityConfig := config.(*AISecurityConfig) - require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa")) - require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa")) - require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb")) - require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test")) - require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc")) - require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test")) + securityConfig := config.(*cfg.AISecurityConfig) + require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa")) + require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa")) + require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb")) + require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test")) + require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc")) + require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test")) }) }) } @@ -385,62 +387,27 @@ func TestOnHttpResponseHeaders(t *testing.T) { func TestRiskLevelFunctions(t *testing.T) { // 测试风险等级转换函数 t.Run("risk level conversion", func(t *testing.T) { - require.Equal(t, 4, levelToInt(MaxRisk)) - require.Equal(t, 3, levelToInt(HighRisk)) - require.Equal(t, 2, levelToInt(MediumRisk)) - require.Equal(t, 1, levelToInt(LowRisk)) - require.Equal(t, 0, levelToInt(NoRisk)) - require.Equal(t, -1, levelToInt("invalid")) + require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk)) + require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk)) + require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk)) + require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk)) + require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk)) + require.Equal(t, -1, cfg.LevelToInt("invalid")) }) // 测试风险等级比较 t.Run("risk level comparison", func(t *testing.T) { - require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk)) - require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk)) - require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk)) - require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk)) + require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk)) + require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk)) + require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk)) + require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk)) }) } func TestUtilityFunctions(t *testing.T) { - // 测试URL编码函数 - t.Run("url encoding", func(t *testing.T) { - original := "test+string:with=special&chars@$" - encoded := urlEncoding(original) - require.NotEqual(t, original, encoded) - require.Contains(t, encoded, "%2B") // + 应该被编码 - require.Contains(t, encoded, "%3A") // : 应该被编码 - require.Contains(t, encoded, "%3D") // = 应该被编码 - require.Contains(t, encoded, "%26") // & 应该被编码 - }) - - // 测试HMAC-SHA1签名函数 - t.Run("hmac sha1", func(t *testing.T) { - message := "test message" - secret := "test secret" - signature := hmacSha1(message, secret) - require.NotEmpty(t, signature) - require.NotEqual(t, message, signature) - }) - - // 测试签名生成函数 - t.Run("signature generation", func(t *testing.T) { - host, status := test.NewTestHost(basicConfig) - defer host.Reset() - require.Equal(t, types.OnPluginStartStatusOK, status) - - params := map[string]string{ - "key1": "value1", - "key2": "value2", - } - secret := "test-secret" - signature := getSign(params, secret) - require.NotEmpty(t, signature) - }) - // 测试十六进制ID生成函数 t.Run("hex id generation", func(t *testing.T) { - id, err := generateHexID(16) + id, err := utils.GenerateHexID(16) require.NoError(t, err) require.Len(t, id, 16) require.Regexp(t, "^[0-9a-f]+$", id) @@ -448,7 +415,7 @@ func TestUtilityFunctions(t *testing.T) { // 测试随机ID生成函数 t.Run("random id generation", func(t *testing.T) { - id := generateRandomID() + id := utils.GenerateRandomChatID() require.NotEmpty(t, id) require.Contains(t, id, "chatcmpl-") require.Len(t, id, 38) // "chatcmpl-" + 29 random chars diff --git a/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go new file mode 100644 index 0000000000..16d92f1935 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/utils/utils.go @@ -0,0 +1,43 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + mrand "math/rand" + "strings" + + "github.com/higress-group/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func GenerateHexID(length int) (string, error) { + bytes := make([]byte, length/2) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +func GenerateRandomChatID() string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, 29) + for i := range b { + b[i] = charset[mrand.Intn(len(charset))] + } + return "chatcmpl-" + string(b) +} + +func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string { + chunks := bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) + strChunks := []string{} + for _, chunk := range chunks { + // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] + strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String()) + } + return strings.Join(strChunks, "") +} + +func GetConsumer(ctx wrapper.HttpContext) string { + return ctx.GetStringContext("x-mse-consumer", "") +}