Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/internal/pkg/apicompat/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ type AnthropicDelta struct {
type ResponsesRequest struct {
Model string `json:"model"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
Instructions string `json:"instructions,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Expand Down
192 changes: 192 additions & 0 deletions backend/internal/service/openai_compat_prompt_cache_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)

const compatPromptCacheKeyPrefix = "compat_cc_"
const responsesPromptCacheKeyPrefix = "compat_rsp_"

func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch normalizeCodexModel(strings.TrimSpace(model)) {
Expand Down Expand Up @@ -65,6 +66,197 @@
return compatPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
}

func deriveResponsesPromptCacheKey(req *apicompat.ResponsesRequest, mappedModel string) string {
if req == nil {
return ""
}

normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel))

Check failure on line 74 in backend/internal/service/openai_compat_prompt_cache_key.go

View workflow job for this annotation

GitHub Actions / backend-security

undefined: resolveOpenAIUpstreamModel
if normalizedModel == "" {
normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model))

Check failure on line 76 in backend/internal/service/openai_compat_prompt_cache_key.go

View workflow job for this annotation

GitHub Actions / backend-security

undefined: resolveOpenAIUpstreamModel
}
if normalizedModel == "" {
normalizedModel = strings.TrimSpace(req.Model)
}

seedParts := []string{"model=" + normalizedModel}
if req.Reasoning != nil && req.Reasoning.Effort != "" {
seedParts = append(seedParts, "reasoning_effort="+strings.TrimSpace(req.Reasoning.Effort))
}
if len(req.ToolChoice) > 0 {
seedParts = append(seedParts, "tool_choice="+normalizeCompatSeedJSON(req.ToolChoice))
}
if len(req.Tools) > 0 {
if raw, err := json.Marshal(req.Tools); err == nil {
seedParts = append(seedParts, "tools="+normalizeCompatSeedJSON(raw))
}
}
if instructions := normalizeResponsesStringSeed(strings.TrimSpace(extractResponsesInstructionsForSeed(req))); instructions != "" {
seedParts = append(seedParts, "system="+instructions)
}

systemParts, firstUser := extractResponsesInputSeedParts(req.Input)
for _, part := range systemParts {
seedParts = append(seedParts, "system="+part)
}
if firstUser != "" {
seedParts = append(seedParts, "first_user="+firstUser)
}

return responsesPromptCacheKeyPrefix + hashSensitiveValueForLog(strings.Join(seedParts, "|"))
}

// deriveResponsesPromptCacheKeyFromBody decodes a raw Responses API request
// body and derives a stable prompt_cache_key from it. An empty or
// undecodable body yields "".
func deriveResponsesPromptCacheKeyFromBody(body []byte, mappedModel string) string {
	if len(body) == 0 {
		return ""
	}
	req := new(apicompat.ResponsesRequest)
	if json.Unmarshal(body, req) != nil {
		return ""
	}
	return deriveResponsesPromptCacheKey(req, mappedModel)
}

// extractResponsesInstructionsForSeed returns the request's top-level
// instructions string, or "" for a nil request.
func extractResponsesInstructionsForSeed(req *apicompat.ResponsesRequest) string {
	if req != nil {
		return req.Instructions
	}
	return ""
}

// extractResponsesInputSeedParts derives cache-key seed material from a
// Responses API `input` payload, which may be either a plain JSON string or
// an array of input items. It returns the normalized system-message parts
// (in input order) and a normalized representation of the first user turn.
//
// "Implicit" user items are top-level items carrying no explicit role (e.g.
// {"type":"input_text",...}); a leading run of them is treated as the first
// user turn so the derived key stays stable when later non-user items
// (reasoning, function calls, ...) are appended in subsequent turns.
func extractResponsesInputSeedParts(raw json.RawMessage) ([]string, string) {
	if len(raw) == 0 {
		return nil, ""
	}

	// Case 1: input is a single JSON string — the whole value is the first
	// user turn and there are no system parts.
	var inputString string
	if err := json.Unmarshal(raw, &inputString); err == nil {
		return nil, normalizeResponsesStringSeed(strings.TrimSpace(inputString))
	}

	// Case 2: input must be an array of items; anything else contributes nothing.
	var items []json.RawMessage
	if err := json.Unmarshal(raw, &items); err != nil {
		return nil, ""
	}

	systemParts := make([]string, 0)
	implicitUserItems := make([]json.RawMessage, 0)
	firstUser := ""

	for _, itemRaw := range items {
		itemRaw = trimJSONRawMessage(itemRaw)
		if len(itemRaw) == 0 {
			continue
		}

		var item map[string]json.RawMessage
		if err := json.Unmarshal(itemRaw, &item); err != nil {
			// Non-object item: only the very first one — before any user
			// content or other implicit items — is kept as implicit user input.
			if firstUser == "" && len(implicitUserItems) == 0 {
				implicitUserItems = append(implicitUserItems, itemRaw)
			}
			continue
		}

		role := rawJSONStringField(item["role"])
		contentRaw := trimJSONRawMessage(item["content"])

		switch role {
		case "system":
			if normalized := normalizeResponsesSeedRaw(contentRaw, itemRaw); normalized != "" {
				systemParts = append(systemParts, normalized)
			}
			continue
		case "user":
			// An explicit user message supersedes any accumulated implicit
			// items; only the first explicit user turn is recorded.
			if firstUser == "" {
				firstUser = normalizeResponsesSeedRaw(contentRaw, itemRaw)
			}
			implicitUserItems = implicitUserItems[:0]
			continue
		}

		// Once the first user turn is known, the remaining items cannot
		// affect the seed.
		if firstUser != "" {
			continue
		}

		if isImplicitResponsesUserItem(item) {
			implicitUserItems = append(implicitUserItems, itemRaw)
			continue
		}

		// A non-user item terminates a leading implicit-user run: stop
		// scanning so trailing appended items cannot change the key.
		if len(implicitUserItems) > 0 {
			break
		}
	}

	// No explicit user message: fall back to the leading implicit run,
	// serialized as a JSON array.
	if firstUser == "" && len(implicitUserItems) > 0 {
		if rawItems, err := json.Marshal(implicitUserItems); err == nil {
			firstUser = normalizeCompatSeedJSON(rawItems)
		}
	}

	return systemParts, firstUser
}

// normalizeResponsesStringSeed JSON-encodes a plain string and normalizes it
// for use as a cache-key seed part. Blank (whitespace-only) input yields "".
func normalizeResponsesStringSeed(v string) string {
	if strings.TrimSpace(v) == "" {
		return ""
	}
	if encoded, err := json.Marshal(v); err == nil {
		return normalizeCompatSeedJSON(encoded)
	}
	return ""
}

// normalizeResponsesSeedRaw normalizes the primary raw JSON value, falling
// back to the full item payload when the primary value normalizes to "".
func normalizeResponsesSeedRaw(primary json.RawMessage, fallback json.RawMessage) string {
	for _, candidate := range []json.RawMessage{primary, fallback} {
		if normalized := normalizeCompatSeedJSON(trimJSONRawMessage(candidate)); normalized != "" {
			return normalized
		}
	}
	return ""
}

func rawJSONStringField(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return ""
}
return strings.TrimSpace(s)
}

func isImplicitResponsesUserItem(item map[string]json.RawMessage) bool {
typ := rawJSONStringField(item["type"])
switch typ {
case "input_text", "input_image", "text", "message":
return true
case "function_call", "function_call_output", "reasoning", "item_reference":
return false
}

if len(trimJSONRawMessage(item["content"])) > 0 {
return true
}
if len(trimJSONRawMessage(item["text"])) > 0 {
return true
}
if len(trimJSONRawMessage(item["image_url"])) > 0 {
return true
}

return false
}

func trimJSONRawMessage(raw json.RawMessage) json.RawMessage {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
return nil
}
return json.RawMessage(trimmed)
}

func normalizeCompatSeedJSON(v json.RawMessage) string {
if len(v) == 0 {
return ""
Expand Down
66 changes: 66 additions & 0 deletions backend/internal/service/openai_compat_prompt_cache_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,69 @@ func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) {
require.NotEmpty(t, k1)
require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key")
}

// TestDeriveResponsesPromptCacheKey_StableAcrossLaterTurns verifies that
// appending assistant/user turns after the first user message does not change
// the derived key: only the conversation prefix (instructions + first user
// input) feeds the seed.
func TestDeriveResponsesPromptCacheKey_StableAcrossLaterTurns(t *testing.T) {
	// Minimal single-turn conversation.
	base := &apicompat.ResponsesRequest{
		Model:        "gpt-5.4",
		Instructions: "You are helpful.",
		Input: mustRawJSON(t, `[
			{"role":"user","content":[{"type":"input_text","text":"Hello"}]}
		]`),
	}
	// Same conversation with a later assistant reply and a second user turn.
	extended := &apicompat.ResponsesRequest{
		Model:        "gpt-5.4",
		Instructions: "You are helpful.",
		Input: mustRawJSON(t, `[
			{"role":"user","content":[{"type":"input_text","text":"Hello"}]},
			{"role":"assistant","content":[{"type":"output_text","text":"Hi there!"}]},
			{"role":"user","content":[{"type":"input_text","text":"How are you?"}]}
		]`),
	}

	k1 := deriveResponsesPromptCacheKey(base, "gpt-5.4")
	k2 := deriveResponsesPromptCacheKey(extended, "gpt-5.4")
	require.Equal(t, k1, k2, "cache key should be stable across later turns")
	require.NotEmpty(t, k1)
}

// TestDeriveResponsesPromptCacheKey_DiffersAcrossSessions verifies that two
// conversations with different first user inputs derive different keys, so
// unrelated sessions cannot collide on the same prompt cache entry.
func TestDeriveResponsesPromptCacheKey_DiffersAcrossSessions(t *testing.T) {
	req1 := &apicompat.ResponsesRequest{
		Model: "gpt-5.3-codex",
		Input: mustRawJSON(t, `[
			{"type":"input_text","text":"Question A"}
		]`),
	}
	req2 := &apicompat.ResponsesRequest{
		Model: "gpt-5.3-codex",
		Input: mustRawJSON(t, `[
			{"type":"input_text","text":"Question B"}
		]`),
	}

	k1 := deriveResponsesPromptCacheKey(req1, "gpt-5.3-codex")
	k2 := deriveResponsesPromptCacheKey(req2, "gpt-5.3-codex")
	require.NotEqual(t, k1, k2, "different first user inputs should yield different keys")
}

// TestDeriveResponsesPromptCacheKey_LeadingImplicitUserItemsStayStable
// verifies that role-less ("implicit") leading user items form the first-user
// seed and that appending a later non-user item (reasoning) does not change
// the key. It also passes an "openai/"-prefixed mapped model for the extended
// request, asserting it resolves to the same upstream model as the bare name
// (NOTE(review): relies on resolveOpenAIUpstreamModel stripping the prefix —
// confirm against that helper).
func TestDeriveResponsesPromptCacheKey_LeadingImplicitUserItemsStayStable(t *testing.T) {
	base := &apicompat.ResponsesRequest{
		Model: "gpt-5.3-codex",
		Input: mustRawJSON(t, `[
			{"type":"input_text","text":"Hello"},
			{"type":"input_image","image_url":"data:image/png;base64,AAAA"}
		]`),
	}
	extended := &apicompat.ResponsesRequest{
		Model: "gpt-5.3-codex",
		Input: mustRawJSON(t, `[
			{"type":"input_text","text":"Hello"},
			{"type":"input_image","image_url":"data:image/png;base64,AAAA"},
			{"type":"reasoning","encrypted_content":"gAAA"}
		]`),
	}

	k1 := deriveResponsesPromptCacheKey(base, "gpt-5.3-codex")
	k2 := deriveResponsesPromptCacheKey(extended, "openai/gpt-5.3-codex")
	require.NotEmpty(t, k1)
	require.Equal(t, k1, k2, "implicit leading user items should remain stable when later non-user items are appended")
}
20 changes: 20 additions & 0 deletions backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if v, ok := reqBody["stream"].(bool); ok {
reqStream = v
}
if promptCacheKey == "" {
promptCacheKey = s.ExtractSessionID(c, body)
}
if promptCacheKey == "" {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
Expand Down Expand Up @@ -1967,6 +1970,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}

autoPromptCacheKeyInjected := false
if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) {
promptCacheKey = deriveResponsesPromptCacheKeyFromBody(body, upstreamModel)
autoPromptCacheKeyInjected = promptCacheKey != ""
}

if account.Type == AccountTypeOAuth {
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c))
if codexResult.Modified {
Expand All @@ -1978,8 +1987,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
bodyModified = true
markPatchSet("prompt_cache_key", promptCacheKey)
}
}
if autoPromptCacheKeyInjected {
logger.L().Debug("openai responses: stable prompt_cache_key injected",
zap.Int64("account_id", account.ID),
zap.String("upstream_model", upstreamModel),
zap.String("prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey)),
)
}

// Handle max_output_tokens based on platform and account type
if !isCodexCLI {
Expand Down
Loading
Loading