diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 513b7996db..2e68e7f242 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -8,6 +8,11 @@ package main import ( "context" + "log" + "net/http" + "sync" + "time" + "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -17,14 +22,9 @@ import ( "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" - "log" - "net/http" - "sync" - "time" -) -import ( _ "embed" + _ "github.com/Wei-Shaw/sub2api/ent/runtime" ) @@ -181,7 +181,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) - geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, 
geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) + geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, concurrencyService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) soraS3Storage := service.NewSoraS3Storage(settingService) diff --git a/backend/ent/group.go b/backend/ent/group.go index fc691a9b4c..54098defd0 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -86,6 +86,8 @@ type Group struct { RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` + // 是否启用基于 proxy_id 分桶的负载均衡调度 + ProxyBucketLoadBalanceEnabled bool `json:"proxy_bucket_load_balance_enabled,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. 
Edges GroupEdges `json:"edges"` @@ -194,7 +196,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet, group.FieldProxyBucketLoadBalanceEnabled: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) @@ -447,6 +449,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.DefaultMappedModel = value.String } + case group.FieldProxyBucketLoadBalanceEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field proxy_bucket_load_balance_enabled", values[i]) + } else if value.Valid { + _m.ProxyBucketLoadBalanceEnabled = value.Bool + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -652,6 +660,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) + builder.WriteString(", ") + builder.WriteString("proxy_bucket_load_balance_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.ProxyBucketLoadBalanceEnabled)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go 
index 352221275b..65adf4a880 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -83,6 +83,8 @@ const ( FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" + // FieldProxyBucketLoadBalanceEnabled holds the string denoting the proxy_bucket_load_balance_enabled field in the database. + FieldProxyBucketLoadBalanceEnabled = "proxy_bucket_load_balance_enabled" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -192,6 +194,7 @@ var Columns = []string{ FieldRequireOauthOnly, FieldRequirePrivacySet, FieldDefaultMappedModel, + FieldProxyBucketLoadBalanceEnabled, } var ( @@ -269,6 +272,8 @@ var ( DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. DefaultMappedModelValidator func(string) error + // DefaultProxyBucketLoadBalanceEnabled holds the default value on creation for the "proxy_bucket_load_balance_enabled" field. + DefaultProxyBucketLoadBalanceEnabled bool ) // OrderOption defines the ordering options for the Group queries. @@ -439,6 +444,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() } +// ByProxyBucketLoadBalanceEnabled orders the results by the proxy_bucket_load_balance_enabled field. +func ByProxyBucketLoadBalanceEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProxyBucketLoadBalanceEnabled, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. 
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 41bd575a53..89bdeb8d7d 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -215,6 +215,11 @@ func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) } +// ProxyBucketLoadBalanceEnabled applies equality check predicate on the "proxy_bucket_load_balance_enabled" field. It's identical to ProxyBucketLoadBalanceEnabledEQ. +func ProxyBucketLoadBalanceEnabled(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldProxyBucketLoadBalanceEnabled, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1585,6 +1590,16 @@ func DefaultMappedModelContainsFold(v string) predicate.Group { return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) } +// ProxyBucketLoadBalanceEnabledEQ applies the EQ predicate on the "proxy_bucket_load_balance_enabled" field. +func ProxyBucketLoadBalanceEnabledEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldProxyBucketLoadBalanceEnabled, v)) +} + +// ProxyBucketLoadBalanceEnabledNEQ applies the NEQ predicate on the "proxy_bucket_load_balance_enabled" field. +func ProxyBucketLoadBalanceEnabledNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldProxyBucketLoadBalanceEnabled, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index a635dfd999..3cfd666a9c 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -480,6 +480,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { return _c } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (_c *GroupCreate) SetProxyBucketLoadBalanceEnabled(v bool) *GroupCreate { + _c.mutation.SetProxyBucketLoadBalanceEnabled(v) + return _c +} + +// SetNillableProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field if the given value is not nil. +func (_c *GroupCreate) SetNillableProxyBucketLoadBalanceEnabled(v *bool) *GroupCreate { + if v != nil { + _c.SetProxyBucketLoadBalanceEnabled(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -685,6 +699,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) } + if _, ok := _c.mutation.ProxyBucketLoadBalanceEnabled(); !ok { + v := group.DefaultProxyBucketLoadBalanceEnabled + _c.mutation.SetProxyBucketLoadBalanceEnabled(v) + } return nil } @@ -772,6 +790,9 @@ func (_c *GroupCreate) check() error { return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} } } + if _, ok := _c.mutation.ProxyBucketLoadBalanceEnabled(); !ok { + return &ValidationError{Name: "proxy_bucket_load_balance_enabled", err: errors.New(`ent: missing required field "Group.proxy_bucket_load_balance_enabled"`)} + } return nil } @@ -935,6 +956,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value } + if value, ok := _c.mutation.ProxyBucketLoadBalanceEnabled(); ok { + _spec.SetField(group.FieldProxyBucketLoadBalanceEnabled, field.TypeBool, value) + _node.ProxyBucketLoadBalanceEnabled = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1673,6 +1698,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { return u } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (u *GroupUpsert) SetProxyBucketLoadBalanceEnabled(v bool) *GroupUpsert { + u.Set(group.FieldProxyBucketLoadBalanceEnabled, v) + return u +} + +// UpdateProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field to the value that was provided on create. +func (u *GroupUpsert) UpdateProxyBucketLoadBalanceEnabled() *GroupUpsert { + u.SetExcluded(group.FieldProxyBucketLoadBalanceEnabled) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. 
// Using this option is equivalent to using: // @@ -2397,6 +2434,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { }) } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (u *GroupUpsertOne) SetProxyBucketLoadBalanceEnabled(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetProxyBucketLoadBalanceEnabled(v) + }) +} + +// UpdateProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateProxyBucketLoadBalanceEnabled() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateProxyBucketLoadBalanceEnabled() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -3287,6 +3338,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { }) } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (u *GroupUpsertBulk) SetProxyBucketLoadBalanceEnabled(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetProxyBucketLoadBalanceEnabled(v) + }) +} + +// UpdateProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateProxyBucketLoadBalanceEnabled() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateProxyBucketLoadBalanceEnabled() + }) +} + // Exec executes the query. 
func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index a9a4b9da80..16de0e7b9f 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -681,6 +681,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { return _u } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (_u *GroupUpdate) SetProxyBucketLoadBalanceEnabled(v bool) *GroupUpdate { + _u.mutation.SetProxyBucketLoadBalanceEnabled(v) + return _u +} + +// SetNillableProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableProxyBucketLoadBalanceEnabled(v *bool) *GroupUpdate { + if v != nil { + _u.SetProxyBucketLoadBalanceEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1183,6 +1197,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.ProxyBucketLoadBalanceEnabled(); ok { + _spec.SetField(group.FieldProxyBucketLoadBalanceEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -2143,6 +2160,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO return _u } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. 
+func (_u *GroupUpdateOne) SetProxyBucketLoadBalanceEnabled(v bool) *GroupUpdateOne { + _u.mutation.SetProxyBucketLoadBalanceEnabled(v) + return _u +} + +// SetNillableProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableProxyBucketLoadBalanceEnabled(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetProxyBucketLoadBalanceEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2675,6 +2706,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.ProxyBucketLoadBalanceEnabled(); ok { + _spec.SetField(group.FieldProxyBucketLoadBalanceEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index bdbb9fdddd..0c057f1f9b 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -412,6 +412,7 @@ var ( {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, + {Name: "proxy_bucket_load_balance_enabled", Type: field.TypeBool, Default: false}, } // GroupsTable holds the schema information for the "groups" table. 
GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 28d9a0ef22..23b9cd9202 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8256,6 +8256,7 @@ type GroupMutation struct { require_oauth_only *bool require_privacy_set *bool default_mapped_model *string + proxy_bucket_load_balance_enabled *bool clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -10144,6 +10145,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() { m.default_mapped_model = nil } +// SetProxyBucketLoadBalanceEnabled sets the "proxy_bucket_load_balance_enabled" field. +func (m *GroupMutation) SetProxyBucketLoadBalanceEnabled(b bool) { + m.proxy_bucket_load_balance_enabled = &b +} + +// ProxyBucketLoadBalanceEnabled returns the value of the "proxy_bucket_load_balance_enabled" field in the mutation. +func (m *GroupMutation) ProxyBucketLoadBalanceEnabled() (r bool, exists bool) { + v := m.proxy_bucket_load_balance_enabled + if v == nil { + return + } + return *v, true +} + +// OldProxyBucketLoadBalanceEnabled returns the old "proxy_bucket_load_balance_enabled" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldProxyBucketLoadBalanceEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProxyBucketLoadBalanceEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProxyBucketLoadBalanceEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProxyBucketLoadBalanceEnabled: %w", err) + } + return oldValue.ProxyBucketLoadBalanceEnabled, nil +} + +// ResetProxyBucketLoadBalanceEnabled resets all changes to the "proxy_bucket_load_balance_enabled" field. +func (m *GroupMutation) ResetProxyBucketLoadBalanceEnabled() { + m.proxy_bucket_load_balance_enabled = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -10605,6 +10642,9 @@ func (m *GroupMutation) Fields() []string { if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } + if m.proxy_bucket_load_balance_enabled != nil { + fields = append(fields, group.FieldProxyBucketLoadBalanceEnabled) + } return fields } @@ -10681,6 +10721,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() + case group.FieldProxyBucketLoadBalanceEnabled: + return m.ProxyBucketLoadBalanceEnabled() } return nil, false } @@ -10758,6 +10800,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) + case group.FieldProxyBucketLoadBalanceEnabled: + return m.OldProxyBucketLoadBalanceEnabled(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -11005,6 +11049,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) 
error { } m.SetDefaultMappedModel(v) return nil + case group.FieldProxyBucketLoadBalanceEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProxyBucketLoadBalanceEnabled(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -11444,6 +11495,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil + case group.FieldProxyBucketLoadBalanceEnabled: + m.ResetProxyBucketLoadBalanceEnabled() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 336b1f8243..23f02f7ee2 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -472,6 +472,10 @@ func init() { group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) + // groupDescProxyBucketLoadBalanceEnabled is the schema descriptor for proxy_bucket_load_balance_enabled field. + groupDescProxyBucketLoadBalanceEnabled := groupFields[29].Descriptor() + // group.DefaultProxyBucketLoadBalanceEnabled holds the default value on creation for the proxy_bucket_load_balance_enabled field. + group.DefaultProxyBucketLoadBalanceEnabled = groupDescProxyBucketLoadBalanceEnabled.Default.(bool) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index fd83bf26ad..04090a0975 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -163,6 +163,9 @@ func (Group) Fields() []ent.Field { MaxLen(100). Default(""). 
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), + field.Bool("proxy_bucket_load_balance_enabled"). + Default(false). + Comment("是否启用基于 proxy_id 分桶的负载均衡调度"), } } diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 12139b5165..4e19db72e7 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -377,7 +377,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "") + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "", nil) if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 681da5e8fe..fe77d0f640 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -246,7 +246,17 @@ func (h *AccountHandler) List(c *gin.Context) { } } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) + var proxyID *int64 + if proxyIDStr := strings.TrimSpace(c.Query("proxy_id")); proxyIDStr != "" { + parsedProxyID, parseErr := strconv.ParseInt(proxyIDStr, 10, 64) + if parseErr != nil || parsedProxyID <= 0 { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_PROXY_FILTER", "invalid proxy filter")) + return + } + proxyID = &parsedProxyID + } + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, proxyID) if err != nil { response.ErrorFrom(c, err) return @@ -368,7 +378,7 @@ func (h *AccountHandler) List(c *gin.Context) { result[i] = item } - etag := 
buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite) + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, groupID, privacyMode, proxyID, lite) if etag != "" { c.Header("ETag", etag) c.Header("Vary", "If-None-Match") @@ -386,6 +396,9 @@ func buildAccountsListETag( total int64, page, pageSize int, platform, accountType, status, search string, + groupID int64, + privacyMode string, + proxyID *int64, lite bool, ) string { payload := struct { @@ -396,6 +409,9 @@ func buildAccountsListETag( AccountType string `json:"type"` Status string `json:"status"` Search string `json:"search"` + GroupID int64 `json:"group_id"` + PrivacyMode string `json:"privacy_mode"` + ProxyID *int64 `json:"proxy_id"` Lite bool `json:"lite"` Items []AccountWithConcurrency `json:"items"` }{ @@ -406,6 +422,9 @@ func buildAccountsListETag( AccountType: accountType, Status: status, Search: search, + GroupID: groupID, + PrivacyMode: privacyMode, + ProxyID: proxyID, Lite: lite, Items: items, } @@ -2035,7 +2054,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", nil) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 9759cef5c0..c404be4b8e 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, 
search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index caa27bc38c..227fea1e02 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -111,10 +111,11 @@ type CreateGroupRequest struct { // Sora 存储配额 SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` - RequireOAuthOnly bool `json:"require_oauth_only"` - RequirePrivacySet bool `json:"require_privacy_set"` - DefaultMappedModel string `json:"default_mapped_model"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + DefaultMappedModel string `json:"default_mapped_model"` + ProxyBucketLoadBalanceEnabled bool `json:"proxy_bucket_load_balance_enabled"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -151,10 +152,11 @@ type UpdateGroupRequest struct { // Sora 存储配额 SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` - RequireOAuthOnly *bool `json:"require_oauth_only"` - RequirePrivacySet *bool `json:"require_privacy_set"` - DefaultMappedModel *string `json:"default_mapped_model"` + AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` + DefaultMappedModel 
*string `json:"default_mapped_model"` + ProxyBucketLoadBalanceEnabled *bool `json:"proxy_bucket_load_balance_enabled"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -274,6 +276,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: req.ProxyBucketLoadBalanceEnabled, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -329,6 +332,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: req.ProxyBucketLoadBalanceEnabled, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d9d657836d..1d3c421007 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -135,16 +135,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - ActiveAccountCount: g.ActiveAccountCount, - RateLimitedAccountCount: g.RateLimitedAccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: g.ProxyBucketLoadBalanceEnabled, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: 
g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -183,6 +184,7 @@ func groupFromServiceBase(g *service.Group) Group { AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, RequirePrivacySet: g.RequirePrivacySet, + ProxyBucketLoadBalanceEnabled: g.ProxyBucketLoadBalanceEnabled, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 56b67c8c4d..631b4f76d6 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -100,7 +100,8 @@ type Group struct { SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + ProxyBucketLoadBalanceEnabled bool `json:"proxy_bucket_load_balance_enabled"` // 账号过滤控制(仅 OpenAI/Antigravity 平台有效) RequireOAuthOnly bool `json:"require_oauth_only"` @@ -123,7 +124,8 @@ type AdminGroup struct { MCPXMLInject bool `json:"mcp_xml_inject"` // OpenAI Messages 调度配置(仅 openai 平台使用) - DefaultMappedModel string `json:"default_mapped_model"` + DefaultMappedModel string `json:"default_mapped_model"` + ProxyBucketLoadBalanceEnabled bool `json:"proxy_bucket_load_balance_enabled"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index e053b668d3..aaf2c19bf9 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error func (r *stubAccountRepo) List(ctx 
context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a1297..8dda40d59e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -454,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "", 0, "") + return r.ListWithFilters(ctx, params, "", "", "", "", 0, "", nil) } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -504,6 +504,9 @@ func (r *accountRepository) ListWithFilters(ctx 
context.Context, params paginati } })) } + if proxyID != nil { + q = q.Where(dbaccount.ProxyIDEQ(*proxyID)) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ade0d46486..41d2fbf82c 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -168,6 +168,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldSupportedModelScopes, group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, + group.FieldProxyBucketLoadBalanceEnabled, ) }). Only(ctx) @@ -665,6 +666,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequireOAuthOnly: g.RequireOauthOnly, RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: g.ProxyBucketLoadBalanceEnabled, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 3cfd649bce..c44d1275c0 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). - SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetProxyBucketLoadBalanceEnabled(groupIn.ProxyBucketLoadBalanceEnabled) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -134,7 +135,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). 
- SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetProxyBucketLoadBalanceEnabled(groupIn.ProxyBucketLoadBalanceEnabled) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 450c312265..1d0ac33b5f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -994,7 +994,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 328790a87f..8703483f9b 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -38,7 +38,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, 
proxyID *int64) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 81169a029b..bc2397d3c8 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index b6d7d634df..1a5eb14337 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -55,7 +55,7 @@ type AdminService interface { ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) 
(*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -161,10 +161,11 @@ type CreateGroupInput struct { // Sora 存储配额 SoraStorageQuotaBytes int64 // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - DefaultMappedModel string - RequireOAuthOnly bool - RequirePrivacySet bool + AllowMessagesDispatch bool + DefaultMappedModel string + ProxyBucketLoadBalanceEnabled bool + RequireOAuthOnly bool + RequirePrivacySet bool // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -202,10 +203,11 @@ type UpdateGroupInput struct { // Sora 存储配额 SoraStorageQuotaBytes *int64 // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool - DefaultMappedModel *string - RequireOAuthOnly *bool - RequirePrivacySet *bool + AllowMessagesDispatch *bool + DefaultMappedModel *string + ProxyBucketLoadBalanceEnabled *bool + RequireOAuthOnly *bool + RequirePrivacySet *bool // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -949,6 +951,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequireOAuthOnly: input.RequireOAuthOnly, RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: input.ProxyBucketLoadBalanceEnabled, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -1191,6 +1194,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } + if input.ProxyBucketLoadBalanceEnabled != nil { + group.ProxyBucketLoadBalanceEnabled = *input.ProxyBucketLoadBalanceEnabled + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err @@ -1512,9 +1518,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou } // Account management implementations 
-func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode, proxyID) if err != nil { return nil, 0, err } diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index eb213e6af6..4039b7bfe2 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -20,12 +20,13 @@ type accountRepoStubForAdminList struct { listWithFiltersStatus string listWithFiltersSearch string listWithFiltersPrivacy string + listWithFiltersProxyID *int64 listWithFiltersAccounts []Account listWithFiltersResult *pagination.PaginationResult listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform @@ -33,6 +34,7 @@ func (s *accountRepoStubForAdminList) 
ListWithFilters(_ context.Context, params s.listWithFiltersStatus = status s.listWithFiltersSearch = search s.listWithFiltersPrivacy = privacyMode + s.listWithFiltersProxyID = proxyID if s.listWithFiltersErr != nil { return nil, nil, s.listWithFiltersErr @@ -170,7 +172,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", nil) require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) @@ -192,7 +194,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, nil) require.NoError(t, err) require.Equal(t, int64(1), total) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index e8ad5c9c32..fcba0f17cd 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -67,8 +67,9 @@ type APIKeyAuthGroupSnapshot struct { SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` - DefaultMappedModel string `json:"default_mapped_model,omitempty"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string 
`json:"default_mapped_model,omitempty"` + ProxyBucketLoadBalanceEnabled bool `json:"proxy_bucket_load_balance_enabled"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f727ab10f3..01ad64292b 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -247,6 +247,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { SupportedModelScopes: apiKey.Group.SupportedModelScopes, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: apiKey.Group.ProxyBucketLoadBalanceEnabled, } } return snapshot @@ -306,6 +307,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, + ProxyBucketLoadBalanceEnabled: snapshot.Group.ProxyBucketLoadBalanceEnabled, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go index 0a82fade7a..935b28fbd1 100644 --- a/backend/internal/service/gateway_account_selection_test.go +++ b/backend/internal/service/gateway_account_selection_test.go @@ -189,18 +189,57 @@ func TestSelectByLRU_EarliestTimeWins(t *testing.T) { require.Equal(t, int64(3), result.account.ID) } -func TestSelectByLRU_TiePreferOAuth(t *testing.T) { - now := time.Now() - // 账号 1/2 LastUsedAt 相同,且同为最小值。 - accounts := []accountWithLoad{ - makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), - makeAccWithLoad(2, 1, 10, testTimePtr(now), AccountTypeOAuth), - makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(1*time.Hour)), AccountTypeAPIKey), +func 
TestIsProxyBucketScoreBetter_PrefersMoreAvailableAccountsOnEqualHealth(t *testing.T) { + left := proxyBucketScore{proxyID: 101, minLoadRate: 0, minWaiting: 0, minPriority: 1, availableCount: 3} + right := proxyBucketScore{proxyID: 202, minLoadRate: 0, minWaiting: 0, minPriority: 1, availableCount: 1} + + require.True(t, isProxyBucketScoreBetter(left, right)) + require.False(t, isProxyBucketScoreBetter(right, left)) +} + +func TestChooseBestProxyBucket_PrefersMoreAvailableAccountsOnEqualHealth(t *testing.T) { + proxy101 := int64(101) + proxy202 := int64(202) + candidates := []proxyBucketLoadCandidate{ + {account: &Account{ID: 1, Priority: 1, ProxyID: &proxy101}, loadInfo: &AccountLoadInfo{AccountID: 1, LoadRate: 0, WaitingCount: 0}}, + {account: &Account{ID: 2, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 2, LoadRate: 0, WaitingCount: 0}}, + {account: &Account{ID: 3, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 3, LoadRate: 0, WaitingCount: 0}}, } - for i := 0; i < 50; i++ { - result := selectByLRU(accounts, true) - require.NotNil(t, result) - require.Equal(t, AccountTypeOAuth, result.account.Type) - require.Equal(t, int64(2), result.account.ID) + + selectedProxyID, ok := chooseBestProxyBucket(candidates, "session-capacity") + require.True(t, ok) + require.Equal(t, proxy202, selectedProxyID) +} + +func TestChooseBestProxyBucket_UsesSpreadKeyForEqualBuckets(t *testing.T) { + proxy101 := int64(101) + proxy202 := int64(202) + candidates := []proxyBucketLoadCandidate{ + {account: &Account{ID: 1, Priority: 1, ProxyID: &proxy101}, loadInfo: &AccountLoadInfo{AccountID: 1, LoadRate: 0, WaitingCount: 0}}, + {account: &Account{ID: 2, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 2, LoadRate: 0, WaitingCount: 0}}, } + + selectedA, ok := chooseBestProxyBucket(candidates, "session-a") + require.True(t, ok) + selectedB, ok := chooseBestProxyBucket(candidates, "session-b") + require.True(t, ok) + 
require.Contains(t, []int64{proxy101, proxy202}, selectedA) + require.Contains(t, []int64{proxy101, proxy202}, selectedB) + require.NotEqual(t, selectedA, selectedB) } + +func TestChooseBestProxyBucket_StillPrefersLowerLoadOverCapacity(t *testing.T) { + proxy101 := int64(101) + proxy202 := int64(202) + candidates := []proxyBucketLoadCandidate{ + {account: &Account{ID: 1, Priority: 1, ProxyID: &proxy101}, loadInfo: &AccountLoadInfo{AccountID: 1, LoadRate: 0, WaitingCount: 0}}, + {account: &Account{ID: 2, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 2, LoadRate: 20, WaitingCount: 0}}, + {account: &Account{ID: 3, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 3, LoadRate: 20, WaitingCount: 0}}, + {account: &Account{ID: 4, Priority: 1, ProxyID: &proxy202}, loadInfo: &AccountLoadInfo{AccountID: 4, LoadRate: 20, WaitingCount: 0}}, + } + + selectedProxyID, ok := chooseBestProxyBucket(candidates, "session-low-load") + require.True(t, ok) + require.Equal(t, proxy101, selectedProxyID) +} + diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 728328373c..c47f333007 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search 
string, groupID int64, privacyMode string, proxyID *int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -2269,31 +2269,103 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { require.Equal(t, int64(2), updatedID, "粘性会话应绑定到新账号") }) - t.Run("无可用账号-返回错误", func(t *testing.T) { + t.Run("proxy bucket优先选择低负载代理桶", func(t *testing.T) { + groupID := int64(2) + proxy101 := int64(101) + proxy202 := int64(202) + testCtx := context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + repo := &mockAccountRepoForPlatform{ - accounts: []Account{}, + accounts: []Account{ + {ID: 11, Platform: PlatformAnthropic, Priority: 5, Status: StatusActive, Schedulable: true, Concurrency: 5, ProxyID: &proxy101}, + {ID: 12, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, ProxyID: &proxy202}, + {ID: 13, Platform: PlatformAnthropic, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, accountsByID: map[int64]*Account{}, } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } cache := &mockGatewayCacheForPlatform{} - cfg := testConfig() - cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.Scheduling.LoadBatchEnabled = true + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 11: {AccountID: 11, LoadRate: 80}, + 12: {AccountID: 12, LoadRate: 10}, + 13: {AccountID: 13, LoadRate: 0}, + }, + } svc := &GatewayService{ accountRepo: repo, cache: cache, cfg: cfg, - concurrencyService: nil, + concurrencyService: NewConcurrencyService(concurrencyCache), } + result, err := svc.SelectAccountWithLoadAwareness(testCtx, &groupID, "", "claude-3-5-sonnet-20241022", nil, 
"", int64(0)) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(12), result.Account.ID) + require.Equal(t, 1, concurrencyCache.loadBatchCalls) - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) - require.Error(t, err) - require.Nil(t, result) - require.ErrorIs(t, err, ErrNoAvailableAccounts) + }) + + t.Run("proxy bucket无代理账号时回退原逻辑", func(t *testing.T) { + groupID := int64(3) + testCtx := context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 21, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5}, + {ID: 22, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + concurrencyCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 21: {AccountID: 21, LoadRate: 5}, + 22: {AccountID: 22, LoadRate: 30}, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + result, err := svc.SelectAccountWithLoadAwareness(testCtx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(22), result.Account.ID) + require.Equal(t, 1, concurrencyCache.loadBatchCalls) }) - t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) { + t.Run("无可用账号-返回错误", func(t *testing.T) { now := time.Now() resetAt 
:= now.Add(10 * time.Minute) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a95b62b133..48d97a69d4 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -77,6 +77,19 @@ type accountWithLoad struct { loadInfo *AccountLoadInfo } +type proxyBucketLoadCandidate struct { + account *Account + loadInfo *AccountLoadInfo +} + +type proxyBucketScore struct { + proxyID int64 + minLoadRate int + minWaiting int + minPriority int + availableCount int +} + var ForceCacheBillingContextKey = forceCacheBillingKeyType{} var ( @@ -97,6 +110,14 @@ var ( modelsListCacheStoreTotal atomic.Int64 ) +func currentGroupFromContext(ctx context.Context) *Group { + group, _ := ctx.Value(ctxkey.Group).(*Group) + if !IsGroupContextValid(group) { + return nil + } + return group +} + func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { return windowCostPrefetchCacheHitTotal.Load(), windowCostPrefetchCacheMissTotal.Load(), @@ -117,6 +138,188 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +func shouldApplyProxyBucketLoadBalance(group *Group) bool { + return group != nil && group.ProxyBucketLoadBalanceEnabled +} + +func chooseBestProxyBucket(candidates []proxyBucketLoadCandidate, spreadKey string) (int64, bool) { + if len(candidates) == 0 { + return 0, false + } + + scores := make(map[int64]proxyBucketScore, len(candidates)) + for _, candidate := range candidates { + if candidate.account == nil || candidate.account.ProxyID == nil { + continue + } + proxyID := *candidate.account.ProxyID + loadInfo := candidate.loadInfo + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: candidate.account.ID} + } + score := proxyBucketScore{ + proxyID: proxyID, + minLoadRate: loadInfo.LoadRate, + minWaiting: 
loadInfo.WaitingCount, + minPriority: candidate.account.Priority, + availableCount: 1, + } + if existing, ok := scores[proxyID]; ok { + score.availableCount = existing.availableCount + 1 + if isProxyBucketScoreBetter(score, existing) { + scores[proxyID] = score + } else { + existing.availableCount = score.availableCount + scores[proxyID] = existing + } + continue + } + scores[proxyID] = score + } + if len(scores) == 0 { + return 0, false + } + + bestScores := make([]proxyBucketScore, 0, len(scores)) + for _, score := range scores { + if len(bestScores) == 0 { + bestScores = append(bestScores, score) + continue + } + bestScore := bestScores[0] + if isProxyBucketScoreBetter(score, bestScore) { + bestScores = bestScores[:0] + bestScores = append(bestScores, score) + continue + } + if isProxyBucketScoreEqual(score, bestScore) { + bestScores = append(bestScores, score) + } + } + if len(bestScores) == 0 { + return 0, false + } + if len(bestScores) == 1 { + return bestScores[0].proxyID, true + } + + proxyIDs := make([]int64, 0, len(bestScores)) + for _, score := range bestScores { + proxyIDs = append(proxyIDs, score.proxyID) + } + return chooseProxyBucketTieWinner(proxyIDs, spreadKey) +} + +func isProxyBucketScoreBetter(left, right proxyBucketScore) bool { + if left.minLoadRate != right.minLoadRate { + return left.minLoadRate < right.minLoadRate + } + if left.minWaiting != right.minWaiting { + return left.minWaiting < right.minWaiting + } + if left.minPriority != right.minPriority { + return left.minPriority < right.minPriority + } + if left.availableCount != right.availableCount { + return left.availableCount > right.availableCount + } + return false +} + +func isProxyBucketScoreEqual(left, right proxyBucketScore) bool { + return left.minLoadRate == right.minLoadRate && + left.minWaiting == right.minWaiting && + left.minPriority == right.minPriority && + left.availableCount == right.availableCount +} + +func chooseProxyBucketTieWinner(proxyIDs []int64, spreadKey string) 
(int64, bool) { + if len(proxyIDs) == 0 { + return 0, false + } + if len(proxyIDs) == 1 { + return proxyIDs[0], true + } + + sorted := append([]int64(nil), proxyIDs...) + sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) + + if spreadKey != "" { + idx := int(xxhash.Sum64String(spreadKey) % uint64(len(sorted))) + return sorted[idx], true + } + return sorted[mathrand.Intn(len(sorted))], true +} + +func filterCandidatesByProxyBucket(candidates []proxyBucketLoadCandidate, proxyID int64) []proxyBucketLoadCandidate { + filtered := make([]proxyBucketLoadCandidate, 0, len(candidates)) + for _, candidate := range candidates { + if candidate.account == nil || candidate.account.ProxyID == nil { + continue + } + if *candidate.account.ProxyID == proxyID { + filtered = append(filtered, candidate) + } + } + return filtered +} + +func selectProxyBucketCandidates(group *Group, candidates []proxyBucketLoadCandidate, spreadKey string) []proxyBucketLoadCandidate { + if !shouldApplyProxyBucketLoadBalance(group) || len(candidates) == 0 { + return candidates + } + + proxyCandidates := make([]proxyBucketLoadCandidate, 0, len(candidates)) + for _, candidate := range candidates { + if candidate.account != nil && candidate.account.ProxyID != nil { + proxyCandidates = append(proxyCandidates, candidate) + } + } + if len(proxyCandidates) == 0 { + return candidates + } + + selectedProxyID, ok := chooseBestProxyBucket(proxyCandidates, spreadKey) + if !ok { + return candidates + } + selected := filterCandidatesByProxyBucket(proxyCandidates, selectedProxyID) + if len(selected) == 0 { + return candidates + } + return selected +} + +func selectProxyBucketAccounts(group *Group, accounts []Account, loadMap map[int64]*AccountLoadInfo, spreadKey string) []Account { + if !shouldApplyProxyBucketLoadBalance(group) || len(accounts) == 0 { + return accounts + } + + candidates := make([]proxyBucketLoadCandidate, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + 
candidates = append(candidates, proxyBucketLoadCandidate{ + account: acc, + loadInfo: loadMap[acc.ID], + }) + } + selected := selectProxyBucketCandidates(group, candidates, spreadKey) + if len(selected) == len(candidates) { + return accounts + } + + filtered := make([]Account, 0, len(selected)) + for _, candidate := range selected { + if candidate.account != nil { + filtered = append(filtered, *candidate.account) + } + } + if len(filtered) == 0 { + return accounts + } + return filtered +} + func openAIStreamEventIsTerminal(data string) bool { trimmed := strings.TrimSpace(data) if trimmed == "" { @@ -1166,6 +1369,29 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess } // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) +func (s *GatewayService) selectProxyBucketForCurrentGroup(ctx context.Context, accounts []Account, spreadKey string) []Account { + group := currentGroupFromContext(ctx) + if !shouldApplyProxyBucketLoadBalance(group) || len(accounts) == 0 { + return accounts + } + + loadMap := map[int64]*AccountLoadInfo{} + if s != nil && s.concurrencyService != nil { + accountLoads := make([]AccountWithConcurrency, 0, len(accounts)) + for _, acc := range accounts { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.EffectiveLoadFactor(), + }) + } + if batchLoad, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads); err == nil { + loadMap = batchLoad + } + } + + return selectProxyBucketAccounts(group, accounts, loadMap, spreadKey) +} + func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) } @@ -1723,6 +1949,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err == nil { + 
proxyCandidates := make([]proxyBucketLoadCandidate, 0, len(candidates)) + for _, acc := range candidates { + proxyCandidates = append(proxyCandidates, proxyBucketLoadCandidate{ + account: acc, + loadInfo: loadMap[acc.ID], + }) + } + selectedCandidates := selectProxyBucketCandidates(group, proxyCandidates, sessionHash) + if len(selectedCandidates) == 0 { + return nil, ErrNoAvailableAccounts + } + candidates = candidates[:0] + for _, candidate := range selectedCandidates { + if candidate.account != nil { + candidates = append(candidates, candidate.account) + } + } + if len(candidates) == 0 { + return nil, ErrNoAvailableAccounts + } + } + if err != nil { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil @@ -2847,9 +3096,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true - // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) + accounts = s.selectProxyBucketForCurrentGroup(ctx, accounts, sessionHash) routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { @@ -2972,6 +3219,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) + accounts = s.selectProxyBucketForCurrentGroup(ctx, accounts, sessionHash) var selected *Account for i := range accounts { acc := &accounts[i] @@ -3054,6 +3302,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { preferOAuth := nativePlatform == PlatformGemini 
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) // require_privacy_set: 获取分组信息 var schedGroup *Group @@ -3103,9 +3352,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true - // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) + accounts = s.selectProxyBucketForCurrentGroup(ctx, accounts, sessionHash) routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { @@ -3227,8 +3474,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 - needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index b35ebce5c9..44d5db9d14 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -52,6 +52,7 @@ type GeminiMessagesCompatService struct { rateLimitService *RateLimitService httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService + concurrencyService *ConcurrencyService cfg *config.Config responseHeaderFilter *responseheaders.CompiledHeaderFilter } @@ -65,6 +66,7 @@ func NewGeminiMessagesCompatService( rateLimitService *RateLimitService, httpUpstream HTTPUpstream, antigravityGatewayService *AntigravityGatewayService, + concurrencyService *ConcurrencyService, cfg *config.Config, ) *GeminiMessagesCompatService { return &GeminiMessagesCompatService{ @@ -76,14 +78,56 @@ func NewGeminiMessagesCompatService( 
rateLimitService: rateLimitService, httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, + concurrencyService: concurrencyService, cfg: cfg, responseHeaderFilter: compileResponseHeaderFilter(cfg), } } -// GetTokenProvider returns the token provider for OAuth accounts -func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { - return s.tokenProvider +func (s *GeminiMessagesCompatService) selectProxyBucketForGroup(ctx context.Context, group *Group, accounts []Account, spreadKey string) []Account { + if !shouldApplyProxyBucketLoadBalance(group) || len(accounts) == 0 { + return accounts + } + + loadMap := map[int64]*AccountLoadInfo{} + if s != nil && s.concurrencyService != nil { + accountLoads := make([]AccountWithConcurrency, 0, len(accounts)) + for _, acc := range accounts { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.EffectiveLoadFactor(), + }) + } + if batchLoad, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads); err == nil { + loadMap = batchLoad + } + } + + return selectProxyBucketAccounts(group, accounts, loadMap, spreadKey) +} + +func (s *GeminiMessagesCompatService) selectProxyBucketForCurrentGroup(ctx context.Context, groupID *int64, accounts []Account, spreadKey string) []Account { + if len(accounts) == 0 { + return accounts + } + + group := currentGroupFromContext(ctx) + if groupID != nil { + if group == nil || group.ID != *groupID { + if s == nil || s.groupRepo == nil { + return accounts + } + resolvedGroup, err := s.groupRepo.GetByIDLite(ctx, *groupID) + if err != nil { + return accounts + } + group = resolvedGroup + } + } else if group == nil { + return accounts + } + + return s.selectProxyBucketForGroup(ctx, group, accounts, spreadKey) } func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { @@ -93,10 +137,13 @@ func (s 
*GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 1. 确定目标平台和调度模式 // Determine target platform and scheduling mode - platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID) + platform, useMixedScheduling, hasForcePlatform, group, err := s.resolvePlatformAndSchedulingMode(ctx, groupID) if err != nil { return nil, err } + if group != nil { + ctx = context.WithValue(ctx, ctxkey.Group, group) + } cacheKey := "gemini:" + sessionHash @@ -120,7 +167,11 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } - // 4. 按优先级 + LRU 选择最佳账号 + // 4. 在 sticky miss 后先做 proxy bucket 缩圈,再按原有优先级 + LRU 逻辑选账号 + // Apply proxy-bucket narrowing after sticky miss, before legacy account-level selection. + accounts = s.selectProxyBucketForGroup(ctx, group, accounts, sessionHash) + + // 5. 按优先级 + LRU 选择最佳账号 // Select best account by priority + LRU selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling) @@ -131,7 +182,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co return nil, errors.New("no available Gemini accounts") } - // 5. 设置粘性会话绑定 + // 6. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) @@ -145,30 +196,29 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co // // resolvePlatformAndSchedulingMode resolves target platform and scheduling mode. // Returns: platform name, whether to use mixed scheduling, whether force platform, error. 
-func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) { +func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, group *Group, err error) { // 优先检查 context 中的强制平台(/antigravity 路由) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { - return forcePlatform, false, true, nil + return forcePlatform, false, true, nil, nil } if groupID != nil { // 根据分组 platform 决定查询哪种账号 - var group *Group if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID { group = ctxGroup } else { group, err = s.groupRepo.GetByIDLite(ctx, *groupID) if err != nil { - return "", false, false, fmt.Errorf("get group failed: %w", err) + return "", false, false, nil, fmt.Errorf("get group failed: %w", err) } } // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - return group.Platform, group.Platform == PlatformGemini, false, nil + return group.Platform, group.Platform == PlatformGemini, false, group, nil } // 无分组时只使用原生 gemini 平台 - return PlatformGemini, true, false, nil + return PlatformGemini, true, false, nil, nil } // tryStickySessionHit 尝试从粘性会话获取账号。 @@ -482,6 +532,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont return nil, errors.New("no available Gemini accounts") } + accounts = s.selectProxyBucketForCurrentGroup(ctx, groupID, accounts, "gemini-ai-studio-endpoints") + rank := func(a *Account) int { if a == nil { return 999 diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 5e09b95af2..4821c4b7fb 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ 
b/backend/internal/service/gemini_multiplatform_test.go @@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { @@ -289,6 +289,205 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, } // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ProxyBucketLoadBalance(t *testing.T) { + ctx := context.Background() + groupID := int64(7) + proxy101 := int64(101) + proxy202 := int64(202) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 3, Status: StatusActive, Schedulable: true, ProxyID: &proxy101, Concurrency: 10}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy202, Concurrency: 10}, + {ID: 3, Platform: PlatformGemini, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 10}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + group := 
&Group{ID: groupID, Platform: PlatformGemini, ProxyBucketLoadBalanceEnabled: true, Hydrated: true} + ctx = context.WithValue(ctx, ctxkey.Group, group) + + loadCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 8, LoadRate: 80}, + 2: {AccountID: 2, CurrentConcurrency: 1, LoadRate: 10}, + 3: {AccountID: 3, CurrentConcurrency: 0, LoadRate: 0}, + }, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: &mockGroupRepoForGemini{groups: map[int64]*Group{groupID: group}}, + cache: cache, + concurrencyService: NewConcurrencyService(loadCache), + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ProxyBucketFallsBackWithoutProxyAccounts(t *testing.T) { + ctx := context.Background() + groupID := int64(8) + group := &Group{ID: groupID, Platform: PlatformGemini, ProxyBucketLoadBalanceEnabled: true, Hydrated: true} + ctx = context.WithValue(ctx, ctxkey.Group, group) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true}, + {ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: &mockGroupRepoForGemini{groups: map[int64]*Group{groupID: group}}, + cache: &mockGatewayCacheForGemini{}, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func 
TestGeminiMessagesCompatService_SelectAccountForAIStudioEndpoints_ProxyBucketLoadBalance(t *testing.T) { + ctx := context.Background() + groupID := int64(9) + proxy101 := int64(101) + proxy202 := int64(202) + group := &Group{ID: groupID, Platform: PlatformGemini, ProxyBucketLoadBalanceEnabled: true, Hydrated: true} + ctx = context.WithValue(ctx, ctxkey.Group, group) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey, ProxyID: &proxy101, Concurrency: 10, Credentials: map[string]any{"api_key": "key-1"}}, + {ID: 2, Platform: PlatformGemini, Priority: 5, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey, ProxyID: &proxy202, Concurrency: 10, Credentials: map[string]any{"api_key": "key-2"}}, + {ID: 3, Platform: PlatformGemini, Priority: 0, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey, Concurrency: 10, Credentials: map[string]any{"api_key": "key-3"}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + loadCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 7, LoadRate: 70}, + 2: {AccountID: 2, CurrentConcurrency: 1, LoadRate: 10}, + 3: {AccountID: 3, CurrentConcurrency: 0, LoadRate: 0}, + }, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: &mockGroupRepoForGemini{groups: map[int64]*Group{groupID: group}}, + cache: &mockGatewayCacheForGemini{}, + concurrencyService: NewConcurrencyService(loadCache), + } + + acc, err := svc.SelectAccountForAIStudioEndpoints(ctx, &groupID) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ProxyBucketPrefersLargerEqualBucket(t *testing.T) { + ctx := context.Background() + groupID := 
int64(10) + proxy101 := int64(101) + proxy202 := int64(202) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 11, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy101, Concurrency: 10}, + {ID: 12, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy202, Concurrency: 10}, + {ID: 13, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy202, Concurrency: 10}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + group := &Group{ID: groupID, Platform: PlatformGemini, ProxyBucketLoadBalanceEnabled: true, Hydrated: true} + ctx = context.WithValue(ctx, ctxkey.Group, group) + + loadCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 11: {AccountID: 11, CurrentConcurrency: 0, LoadRate: 0}, + 12: {AccountID: 12, CurrentConcurrency: 0, LoadRate: 0}, + 13: {AccountID: 13, CurrentConcurrency: 0, LoadRate: 0}, + }, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: &mockGroupRepoForGemini{groups: map[int64]*Group{groupID: group}}, + cache: cache, + concurrencyService: NewConcurrencyService(loadCache), + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "capacity-spread", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Contains(t, []int64{12, 13}, acc.ID) +} + +func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ProxyBucketEqualBucketsSpreadBySession(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + proxy101 := int64(101) + proxy202 := int64(202) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 21, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy101, Concurrency: 10}, + {ID: 22, 
Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, ProxyID: &proxy202, Concurrency: 10}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + group := &Group{ID: groupID, Platform: PlatformGemini, ProxyBucketLoadBalanceEnabled: true, Hydrated: true} + ctx = context.WithValue(ctx, ctxkey.Group, group) + + loadCache := &mockConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 21: {AccountID: 21, CurrentConcurrency: 0, LoadRate: 0}, + 22: {AccountID: 22, CurrentConcurrency: 0, LoadRate: 0}, + }, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: &mockGroupRepoForGemini{groups: map[int64]*Group{groupID: group}}, + cache: &mockGatewayCacheForGemini{}, + concurrencyService: NewConcurrencyService(loadCache), + } + + accA, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "session-a", "gemini-2.5-flash", nil) + require.NoError(t, err) + accB, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "session-b", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, accA) + require.NotNil(t, accB) + require.NotEqual(t, accA.ID, accB.ID) +} + func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() @@ -389,7 +588,7 @@ func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) require.NoError(t, err) require.NotNil(t, acc) require.Equal(t, 0, groupRepo.getByIDCalls) - require.Equal(t, 1, groupRepo.getByIDLiteCalls) + require.Equal(t, 2, groupRepo.getByIDLiteCalls) } // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组 diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e0f81a39a4..ae1e1b4327 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -58,10 +58,11 @@ type 
Group struct { SortOrder int // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) - RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) - DefaultMappedModel string + AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) + DefaultMappedModel string + ProxyBucketLoadBalanceEnabled bool CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354a1..4ce032265c 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -212,6 +212,14 @@ type defaultOpenAIAccountScheduler struct { stats *openAIAccountRuntimeStats } +func (s *defaultOpenAIAccountScheduler) proxyBucketGroup(ctx context.Context, groupID *int64) *Group { + group := currentGroupFromContext(ctx) + if groupID == nil || group == nil || group.ID != *groupID { + return nil + } + return group +} + func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { if stats == nil { stats = newOpenAIAccountRuntimeStats() @@ -623,6 +631,31 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } } + proxyGroup := s.proxyBucketGroup(ctx, req.GroupID) + proxyCandidates := make([]proxyBucketLoadCandidate, 0, len(filtered)) + for _, account := range filtered { + proxyCandidates = append(proxyCandidates, proxyBucketLoadCandidate{ + account: account, + loadInfo: loadMap[account.ID], + }) + if loadMap[account.ID] == nil { + loadMap[account.ID] = &AccountLoadInfo{AccountID: account.ID} + } + } + selectedProxyCandidates := selectProxyBucketCandidates(proxyGroup, proxyCandidates, req.SessionHash) + if 
len(selectedProxyCandidates) == 0 { + return nil, 0, 0, 0, ErrNoAvailableAccounts + } + filtered = filtered[:0] + for _, candidate := range selectedProxyCandidates { + if candidate.account != nil { + filtered = append(filtered, candidate.account) + } + } + if len(filtered) == 0 { + return nil, 0, 0, 0, ErrNoAvailableAccounts + } + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority maxWaiting := 1 loadRateSum := 0.0 @@ -730,7 +763,6 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } cfg := s.service.schedulingConfig() - // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed40..257cadd8dc 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/stretchr/testify/require" ) @@ -548,6 +549,188 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } } +func TestOpenAIGatewayService_SelectAccountWithScheduler_ProxyBucketPrefersLowerLoad(t *testing.T) { + ctx := context.Background() + groupID := int64(13) + proxy101 := int64(101) + proxy202 := int64(202) + ctx = context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformOpenAI, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + + accounts := []Account{ + {ID: 5001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, 
Priority: 5, ProxyID: &proxy101}, + {ID: 5002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy202}, + {ID: 5003, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0}, + } + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5001: {AccountID: 5001, LoadRate: 80}, + 5002: {AccountID: 5002, LoadRate: 10}, + 5003: {AccountID: 5003, LoadRate: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(5002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 1, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_ProxyBucketPrefersLargerEqualBucket(t *testing.T) { + ctx := context.Background() + groupID := int64(15) + proxy101 := int64(101) + proxy202 := int64(202) + ctx = context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformOpenAI, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + + accounts := []Account{ + {ID: 5201, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy101}, + {ID: 5202, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy202}, + {ID: 5203, Platform: PlatformOpenAI, 
Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy202}, + } + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5201: {AccountID: 5201, LoadRate: 0}, + 5202: {AccountID: 5202, LoadRate: 0}, + 5203: {AccountID: 5203, LoadRate: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "capacity-spread", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Contains(t, []int64{5202, 5203}, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 2, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_ProxyBucketEqualBucketsSpreadBySession(t *testing.T) { + ctx := context.Background() + groupID := int64(16) + proxy101 := int64(101) + proxy202 := int64(202) + ctx = context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformOpenAI, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + + accounts := []Account{ + {ID: 5301, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy101}, + {ID: 5302, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy202}, + } + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5301: {AccountID: 5301, LoadRate: 0}, + 5302: {AccountID: 5302, LoadRate: 0}, + }, + } + + svc := &OpenAIGatewayService{ + 
accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selectionA, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session-a", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + selectionB, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session-b", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selectionA) + require.NotNil(t, selectionB) + require.NotNil(t, selectionA.Account) + require.NotNil(t, selectionB.Account) + require.NotEqual(t, selectionA.Account.ID, selectionB.Account.ID) + if selectionA.ReleaseFunc != nil { + selectionA.ReleaseFunc() + } + if selectionB.ReleaseFunc != nil { + selectionB.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_ProxyBucketStickyStillWins(t *testing.T) { + ctx := context.Background() + groupID := int64(14) + proxy101 := int64(101) + proxy202 := int64(202) + ctx = context.WithValue(ctx, ctxkey.Group, &Group{ + ID: groupID, + Platform: PlatformOpenAI, + Status: StatusActive, + Hydrated: true, + ProxyBucketLoadBalanceEnabled: true, + }) + + accounts := []Account{ + {ID: 5101, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5, ProxyID: &proxy101}, + {ID: 5102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, ProxyID: &proxy202}, + } + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:proxy-sticky": 5101}} + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 80}, + 5102: {AccountID: 5102, LoadRate: 10}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: 
NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "proxy-sticky", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(5101), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { ctx := context.Background() groupID := int64(12) diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index dd416269f4..db49c404f6 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -103,6 +103,14 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } } + responsesBody, _, err = applyOpenAIRequestOverridesToBody(responsesBody, account) + if err != nil { + return nil, fmt.Errorf("apply openai request overrides: %w", err) + } + _, _, promptCacheKey = extractOpenAIRequestMetaFromBody(responsesBody) + effectiveServiceTier := extractOpenAIServiceTierFromBody(responsesBody) + effectiveReasoningEffort := extractOpenAIReasoningEffortFromBody(responsesBody, upstreamModel) + // 5. 
Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { @@ -197,14 +205,8 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( // Propagate ServiceTier and ReasoningEffort to result for billing if handleErr == nil && result != nil { - if responsesReq.ServiceTier != "" { - st := responsesReq.ServiceTier - result.ServiceTier = &st - } - if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { - re := responsesReq.Reasoning.Effort - result.ReasoningEffort = &re - } + result.ServiceTier = effectiveServiceTier + result.ReasoningEffort = effectiveReasoningEffort } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 28c4b1f4f5..58f590ca33 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1808,6 +1808,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } + body, overridesApplied, err := applyOpenAIRequestOverridesToBody(body, account) + if err != nil { + return nil, err + } + if overridesApplied { + clearOpenAIRequestBodyCache(c) + } + originalBody := body reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel @@ -4876,6 +4884,13 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error return reqBody, nil } +func clearOpenAIRequestBodyCache(c *gin.Context) { + if c == nil || c.Keys == nil { + return + } + delete(c.Keys, OpenAIParsedRequestBodyKey) +} + func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { diff --git a/backend/internal/service/openai_request_overrides.go 
b/backend/internal/service/openai_request_overrides.go new file mode 100644 index 0000000000..9a9eaddb82 --- /dev/null +++ b/backend/internal/service/openai_request_overrides.go @@ -0,0 +1,140 @@ +package service + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" +) + +const openAIRequestOverridesExtraKey = "openai_request_overrides" + +// GetOpenAIRequestOverrides returns the account-level OpenAI request overrides. +// Only top-level JSON objects are supported; invalid values are ignored. +func (a *Account) GetOpenAIRequestOverrides() map[string]any { + if a == nil || a.Platform != PlatformOpenAI || a.Extra == nil { + return nil + } + + raw, ok := a.Extra[openAIRequestOverridesExtraKey] + if !ok || raw == nil { + return nil + } + + overrides, ok := openAIRequestOverridesMap(raw) + if !ok || len(overrides) == 0 { + return nil + } + + return sanitizeOpenAIRequestOverrides(overrides) +} + +func applyOpenAIRequestOverridesToBody(body []byte, account *Account) ([]byte, bool, error) { + overrides := account.GetOpenAIRequestOverrides() + if len(body) == 0 || len(overrides) == 0 { + return body, false, nil + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, false, fmt.Errorf("parse request body for openai request overrides: %w", err) + } + if reqBody == nil { + reqBody = make(map[string]any) + } + + if !mergeOpenAIOverrideObjects(reqBody, overrides) { + return body, false, nil + } + + updatedBody, err := json.Marshal(reqBody) + if err != nil { + return nil, false, fmt.Errorf("serialize request body with openai request overrides: %w", err) + } + + return updatedBody, true, nil +} + +func mergeOpenAIOverrideObjects(dst map[string]any, src map[string]any) bool { + changed := false + + for key, value := range src { + if srcMap, ok := openAIRequestOverridesMap(value); ok { + if dstMap, ok := openAIRequestOverridesMap(dst[key]); ok { + if mergeOpenAIOverrideObjects(dstMap, srcMap) { + changed = true + } + dst[key] 
= dstMap + continue + } + + dst[key] = cloneOpenAIOverrideValue(srcMap) + changed = true + continue + } + + cloned := cloneOpenAIOverrideValue(value) + if existing, ok := dst[key]; !ok || !reflect.DeepEqual(existing, cloned) { + dst[key] = cloned + changed = true + } + } + + return changed +} + +func openAIRequestOverridesMap(value any) (map[string]any, bool) { + obj, ok := value.(map[string]any) + if !ok { + return nil, false + } + return obj, true +} + +func sanitizeOpenAIRequestOverrides(overrides map[string]any) map[string]any { + if len(overrides) == 0 { + return nil + } + + sanitized := make(map[string]any, len(overrides)) + for key, value := range overrides { + if !isAllowedOpenAIRequestOverrideKey(key) { + continue + } + sanitized[key] = cloneOpenAIOverrideValue(value) + } + + if len(sanitized) == 0 { + return nil + } + + return sanitized +} + +func isAllowedOpenAIRequestOverrideKey(key string) bool { + switch strings.ToLower(strings.TrimSpace(key)) { + case "model": + return false + default: + return true + } +} + +func cloneOpenAIOverrideValue(value any) any { + switch typed := value.(type) { + case map[string]any: + cloned := make(map[string]any, len(typed)) + for key, nested := range typed { + cloned[key] = cloneOpenAIOverrideValue(nested) + } + return cloned + case []any: + cloned := make([]any, len(typed)) + for i, nested := range typed { + cloned[i] = cloneOpenAIOverrideValue(nested) + } + return cloned + default: + return typed + } +} diff --git a/backend/internal/service/openai_request_overrides_forward_test.go b/backend/internal/service/openai_request_overrides_forward_test.go new file mode 100644 index 0000000000..12eb60e5aa --- /dev/null +++ b/backend/internal/service/openai_request_overrides_forward_test.go @@ -0,0 +1,275 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + 
"github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIRequestOverrides_ForwardPassthroughOverridesServiceTier(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 1, + Name: "openai-passthrough", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{ + "openai_passthrough": true, + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + }, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "fast", gjson.GetBytes(upstream.lastBody, "service_tier").String()) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) +} + +func TestOpenAIRequestOverrides_ForwardNonPassthroughOverridesServiceTier(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + 
c.Request.Header.Set("User-Agent", "curl/8.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 2, + Name: "openai-http", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + }, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "fast", gjson.GetBytes(upstream.lastBody, "service_tier").String()) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) +} + +func TestOpenAIRequestOverrides_ForwardNonPassthroughInvalidatesStaleParsedRequestCache(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Set(OpenAIParsedRequestBodyKey, map[string]any{ + "model": "gpt-5.2", + "stream": false, + "input": []any{ + map[string]any{ + "type": "function_call_output", + "call_id": "call_123", + "output": "done", + }, + }, + "previous_response_id": "resp_123", + }) + + originalBody := 
[]byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"function_call_output","call_id":"call_123","output":"done"}],"previous_response_id":"resp_123"}`) + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 22, + Name: "openai-http", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + }, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "fast", gjson.GetBytes(upstream.lastBody, "service_tier").String()) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) +} + +func TestOpenAIRequestOverrides_ForwardAsAnthropicOverridesBetaFastMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-sonnet-4-20250514","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("anthropic-beta", "fast-mode-2026-02-01") + + upstreamBody := strings.Join([]string{ + `data: 
{"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 3, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "flex", + }, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.Equal(t, "flex", gjson.GetBytes(upstream.lastBody, "service_tier").String()) +} + +func TestOpenAIRequestOverrides_ForwardAsAnthropicUsesOverriddenDerivedMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-sonnet-4-20250514","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set("api_key", &APIKey{ID: 42}) + + upstreamBody := strings.Join([]string{ + `data: 
{"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 33, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "prompt_cache_key": "override-cache", + "service_tier": "flex", + "reasoning": map[string]any{ + "effort": "high", + }, + }, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "override-cache", gjson.GetBytes(upstream.lastBody, "prompt_cache_key").String()) + require.Equal(t, "flex", gjson.GetBytes(upstream.lastBody, "service_tier").String()) + require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String()) + require.Equal(t, generateSessionUUID(isolateOpenAISessionID(42, "override-cache")), upstream.lastReq.Header.Get("session_id")) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", 
*result.ReasoningEffort) +} diff --git a/backend/internal/service/openai_request_overrides_test.go b/backend/internal/service/openai_request_overrides_test.go new file mode 100644 index 0000000000..db9fbc5aba --- /dev/null +++ b/backend/internal/service/openai_request_overrides_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestAccountGetOpenAIRequestOverrides(t *testing.T) { + t.Run("returns overrides for openai account", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + }, + }, + } + + got := account.GetOpenAIRequestOverrides() + require.Equal(t, "fast", got["service_tier"]) + }) + + t.Run("strips disallowed top level model override", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "model": "gpt-5.4", + "service_tier": "fast", + }, + }, + } + + got := account.GetOpenAIRequestOverrides() + require.NotNil(t, got) + require.Equal(t, "fast", got["service_tier"]) + _, exists := got["model"] + require.False(t, exists) + }) + + t.Run("ignores non-object overrides", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: "fast", + }, + } + + require.Nil(t, account.GetOpenAIRequestOverrides()) + }) +} + +func TestApplyOpenAIRequestOverridesToBody(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + "reasoning": map[string]any{ + "effort": "high", + }, + }, + }, + } + + body := []byte(`{"model":"gpt-5.2","reasoning":{"summary":"auto"},"input":[{"type":"text","text":"hi"}]}`) + got, modified, err := applyOpenAIRequestOverridesToBody(body, account) + 
require.NoError(t, err) + require.True(t, modified) + require.Equal(t, "fast", gjson.GetBytes(got, "service_tier").String()) + require.Equal(t, "high", gjson.GetBytes(got, "reasoning.effort").String()) + require.Equal(t, "auto", gjson.GetBytes(got, "reasoning.summary").String()) + require.Equal(t, "hi", gjson.GetBytes(got, "input.0.text").String()) +} + +func TestApplyOpenAIRequestOverridesToBody_NoOverrideChange(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "service_tier": "fast", + }, + }, + } + + body := []byte(`{"service_tier":"fast"}`) + got, modified, err := applyOpenAIRequestOverridesToBody(body, account) + require.NoError(t, err) + require.False(t, modified) + require.Equal(t, body, got) +} + +func TestApplyOpenAIRequestOverridesToBody_IgnoresModelOverride(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Extra: map[string]any{ + openAIRequestOverridesExtraKey: map[string]any{ + "model": "gpt-5.4", + "service_tier": "fast", + }, + }, + } + + body := []byte(`{"model":"gpt-5.2"}`) + got, modified, err := applyOpenAIRequestOverridesToBody(body, account) + require.NoError(t, err) + require.True(t, modified) + require.Equal(t, "gpt-5.2", gjson.GetBytes(got, "model").String()) + require.Equal(t, "fast", gjson.GetBytes(got, "service_tier").String()) +} diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index ffe7915262..d2e9b5877a 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -73,7 +73,7 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ context.Context, _ int64, re return nil } -func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) 
([]Account, *pagination.PaginationResult, error) { +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string, proxyID *int64) ([]Account, *pagination.PaginationResult, error) { _ = platform _ = accountType _ = status @@ -492,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount( } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", nil) require.NoError(t, err) require.Equal(t, int64(1), total) require.Len(t, accounts, 1) diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index ad303d92fd..0f04534ec5 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "", 0, "") + }, platformFilter, "", "", "", 0, "", nil) if err != nil { return nil, err } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 4f5b57cc97..ad9b39ac7a 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -150,12 +150,19 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: - // OpenAI: token_invalidated / token_revoked 表示 token 被永久作废(非过期),直接标记 error + // OpenAI: token_invalidated / token_revoked / account_deactivated 表示认证或账号被永久作废,直接标记 error openai401Code := 
extractUpstreamErrorCode(responseBody) - if account.Platform == PlatformOpenAI && (openai401Code == "token_invalidated" || openai401Code == "token_revoked") { + if account.Platform == PlatformOpenAI && (openai401Code == "token_invalidated" || openai401Code == "token_revoked" || openai401Code == "account_deactivated") { msg := "Token revoked (401): account authentication permanently revoked" + if openai401Code == "account_deactivated" { + msg = "Account deactivated (401): account has been deactivated" + } if upstreamMsg != "" { - msg = "Token revoked (401): " + upstreamMsg + if openai401Code == "account_deactivated" { + msg = "Account deactivated (401): " + upstreamMsg + } else { + msg = "Token revoked (401): " + upstreamMsg + } } s.handleAuthError(ctx, account, msg) shouldDisable = true diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 67b22e5212..274e51135a 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -140,21 +140,25 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { require.Empty(t, invalidator.accounts) } -func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) { +func TestRateLimitService_HandleUpstreamError_OpenAIAccountDeactivatedUsesSetError(t *testing.T) { repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) account := &Account{ - ID: 103, + ID: 104, Platform: PlatformOpenAI, Type: AccountTypeOAuth, - Credentials: map[string]any{ - "access_token": "token", - }, } - shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + body := []byte(`{"error":{"code":"account_deactivated","message":"Your OpenAI account has been 
deactivated","type":"invalid_request_error"},"status":401}`) + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, body) require.True(t, shouldDisable) - require.Equal(t, 1, repo.updateCredentialsCalls) - require.NotEmpty(t, repo.lastCredentials["expires_at"]) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Equal(t, 0, repo.updateCredentialsCalls) + require.Empty(t, invalidator.accounts) + require.Contains(t, repo.lastErrorMsg, "Account deactivated (401)") } + diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index 7796a85e76..da56eda029 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic( func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { panic("unexpected") } -func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { +func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string, *int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected") } func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { diff --git a/backend/migrations/081_add_group_proxy_bucket_load_balance.sql b/backend/migrations/081_add_group_proxy_bucket_load_balance.sql new file mode 100644 index 0000000000..b4c9af436f --- /dev/null +++ b/backend/migrations/081_add_group_proxy_bucket_load_balance.sql @@ -0,0 +1 @@ +ALTER TABLE groups ADD COLUMN proxy_bucket_load_balance_enabled BOOLEAN NOT NULL DEFAULT false; diff --git 
a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index fd93fe7ef0..5e2bd40cd2 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -37,6 +37,7 @@ export async function list( group?: string search?: string privacy_mode?: string + proxy_id?: string lite?: string }, options?: { @@ -70,6 +71,7 @@ export async function listWithEtag( group?: string search?: string privacy_mode?: string + proxy_id?: string lite?: string }, options?: { diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 2934fbd94f..bc35b9f203 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -82,6 +82,47 @@ +
+ {{ t('admin.accounts.requestOverridesDesc') }} +
+{{ t('admin.accounts.requestOverridesDesc') }}
+{{ t('admin.accounts.requestOverridesDesc') }}
+{{ t('admin.groups.openaiMessages.proxyBucketHint') }}
+{{ t('admin.groups.openaiMessages.proxyBucketHint') }}
+