diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 513b7996db..fdc5c6acdb 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -102,12 +102,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) - soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -184,12 +183,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - soraS3Storage := service.NewSoraS3Storage(settingService) - settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) - soraGenerationRepository := repository.NewSoraGenerationRepository(db) - soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) - soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -223,16 +217,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) - soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) - soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) - soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) - soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -243,12 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService) application := &Application{ Server: httpServer, Cleanup: v, @@ -283,7 +271,6 @@ func provideCleanup( opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, opsSystemLogSink *service.OpsSystemLogSink, - soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, @@ -331,12 +318,6 @@ func provideCleanup( } return nil }}, - {"SoraMediaCleanupService", func() error { - if soraMediaCleanup != nil { - soraMediaCleanup.Stop() - } - return nil - }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 9d2a54b98b..6e4561c994 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { &service.OpsCleanupService{}, &service.OpsScheduledReportService{}, opsSystemLogSinkSvc, - &service.SoraMediaCleanupService{}, schedulerSnapshotSvc, tokenRefreshSvc, accountExpirySvc, diff --git a/backend/ent/group.go b/backend/ent/group.go index fc691a9b4c..b15ac15dd9 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,16 +52,6 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` - // SoraImagePrice360 holds the value of the "sora_image_price_360" field. - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - // SoraImagePrice540 holds the value of the "sora_image_price_540" field. - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` - // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. - SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` - // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -196,9 +186,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) @@ -335,40 +325,6 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } - case group.FieldSoraImagePrice360: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) - } else if value.Valid { - _m.SoraImagePrice360 = new(float64) - *_m.SoraImagePrice360 = value.Float64 - } - case group.FieldSoraImagePrice540: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) - } else if value.Valid { - _m.SoraImagePrice540 = new(float64) - *_m.SoraImagePrice540 = value.Float64 - } - case group.FieldSoraVideoPricePerRequest: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) - } else if value.Valid { - _m.SoraVideoPricePerRequest = new(float64) - *_m.SoraVideoPricePerRequest = value.Float64 - } - case group.FieldSoraVideoPricePerRequestHd: - if value, ok := values[i].(*sql.NullFloat64); !ok { - return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) - } else if value.Valid { - _m.SoraVideoPricePerRequestHd = new(float64) - *_m.SoraVideoPricePerRequestHd = value.Float64 - } - case group.FieldSoraStorageQuotaBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageQuotaBytes = value.Int64 - } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -590,29 +546,6 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") - if v := _m.SoraImagePrice360; v != nil { - builder.WriteString("sora_image_price_360=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraImagePrice540; v != nil { - builder.WriteString("sora_image_price_540=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraVideoPricePerRequest; v != nil { - builder.WriteString("sora_video_price_per_request=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - if v := _m.SoraVideoPricePerRequestHd; v != nil { - builder.WriteString("sora_video_price_per_request_hd=") - builder.WriteString(fmt.Sprintf("%v", *v)) - } - builder.WriteString(", ") - builder.WriteString("sora_storage_quota_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) - builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 352221275b..21a7c2cb76 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,16 +49,6 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" - // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. - FieldSoraImagePrice360 = "sora_image_price_360" - // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. - FieldSoraImagePrice540 = "sora_image_price_540" - // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. - FieldSoraVideoPricePerRequest = "sora_video_price_per_request" - // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. - FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" - // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. - FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -175,11 +165,6 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, - FieldSoraImagePrice360, - FieldSoraImagePrice540, - FieldSoraVideoPricePerRequest, - FieldSoraVideoPricePerRequestHd, - FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -247,8 +232,6 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int - // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. - DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -364,31 +347,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } -// BySoraImagePrice360 orders the results by the sora_image_price_360 field. -func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() -} - -// BySoraImagePrice540 orders the results by the sora_image_price_540 field. -func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() -} - -// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. -func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() -} - -// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. -func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() -} - -// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. -func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() -} - // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 41bd575a53..cba2ce5f0e 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } -// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. -func SoraImagePrice360(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. -func SoraImagePrice540(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) -} - -// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. -func SoraVideoPricePerRequest(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. -func SoraVideoPricePerRequestHd(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. -func SoraStorageQuotaBytes(v int64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1070,246 +1045,6 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } -// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. -func SoraImagePrice360EQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. -func SoraImagePrice360NEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. -func SoraImagePrice360In(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) -} - -// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. -func SoraImagePrice360NotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) -} - -// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. -func SoraImagePrice360GT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. -func SoraImagePrice360GTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. -func SoraImagePrice360LT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. -func SoraImagePrice360LTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) -} - -// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. -func SoraImagePrice360IsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) -} - -// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. -func SoraImagePrice360NotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) -} - -// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. -func SoraImagePrice540EQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. -func SoraImagePrice540NEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. -func SoraImagePrice540In(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) -} - -// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. -func SoraImagePrice540NotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) -} - -// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. -func SoraImagePrice540GT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. -func SoraImagePrice540GTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. -func SoraImagePrice540LT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. -func SoraImagePrice540LTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) -} - -// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. -func SoraImagePrice540IsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) -} - -// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. -func SoraImagePrice540NotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) -} - -// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) -} - -// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) -} - -// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestGT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestGTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestLT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestLTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) -} - -// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestIsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) -} - -// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. -func SoraVideoPricePerRequestNotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) -} - -// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) -} - -// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) -} - -// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) -} - -// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdIsNil() predicate.Group { - return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) -} - -// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. -func SoraVideoPricePerRequestHdNotNil() predicate.Group { - return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) -} - -// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesEQ(v int64) predicate.Group { - return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { - return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { - return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { - return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGT(v int64) predicate.Group { - return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGTE(v int64) predicate.Group { - return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLT(v int64) predicate.Group { - return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLTE(v int64) predicate.Group { - return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) -} - // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index a635dfd999..a8c30b184d 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,76 +258,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { - _c.mutation.SetSoraImagePrice360(v) - return _c -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraImagePrice360(*v) - } - return _c -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { - _c.mutation.SetSoraImagePrice540(v) - return _c -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraImagePrice540(*v) - } - return _c -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { - _c.mutation.SetSoraVideoPricePerRequest(v) - return _c -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraVideoPricePerRequest(*v) - } - return _c -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { - _c.mutation.SetSoraVideoPricePerRequestHd(v) - return _c -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { - if v != nil { - _c.SetSoraVideoPricePerRequestHd(*v) - } - return _c -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { - _c.mutation.SetSoraStorageQuotaBytes(v) - return _c -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { - if v != nil { - _c.SetSoraStorageQuotaBytes(*v) - } - return _c -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -645,10 +575,6 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - v := group.DefaultSoraStorageQuotaBytes - _c.mutation.SetSoraStorageQuotaBytes(v) - } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -737,9 +663,6 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} - } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -867,26 +790,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } - if value, ok := _c.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - _node.SoraImagePrice360 = &value - } - if value, ok := _c.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - _node.SoraImagePrice540 = &value - } - if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - _node.SoraVideoPricePerRequest = &value - } - if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - _node.SoraVideoPricePerRequestHd = &value - } - if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - _node.SoraStorageQuotaBytes = value - } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1379,120 +1282,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { - u.Set(group.FieldSoraImagePrice360, v) - return u -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { - u.SetExcluded(group.FieldSoraImagePrice360) - return u -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { - u.Add(group.FieldSoraImagePrice360, v) - return u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { - u.SetNull(group.FieldSoraImagePrice360) - return u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { - u.Set(group.FieldSoraImagePrice540, v) - return u -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { - u.SetExcluded(group.FieldSoraImagePrice540) - return u -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { - u.Add(group.FieldSoraImagePrice540, v) - return u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { - u.SetNull(group.FieldSoraImagePrice540) - return u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { - u.Set(group.FieldSoraVideoPricePerRequest, v) - return u -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { - u.SetExcluded(group.FieldSoraVideoPricePerRequest) - return u -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { - u.Add(group.FieldSoraVideoPricePerRequest, v) - return u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { - u.SetNull(group.FieldSoraVideoPricePerRequest) - return u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { - u.Set(group.FieldSoraVideoPricePerRequestHd, v) - return u -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { - u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) - return u -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { - u.Add(group.FieldSoraVideoPricePerRequestHd, v) - return u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { - u.SetNull(group.FieldSoraVideoPricePerRequestHd) - return u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { - u.Set(group.FieldSoraStorageQuotaBytes, v) - return u -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { - u.SetExcluded(group.FieldSoraStorageQuotaBytes) - return u -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { - u.Add(group.FieldSoraStorageQuotaBytes, v) - return u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -2054,139 +1843,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice360(v) - }) -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice360(v) - }) -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice360() - }) -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice360() - }) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice540(v) - }) -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice540(v) - }) -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice540() - }) -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice540() - }) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequest(v) - }) -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequest(v) - }) -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequest() - }) -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequest() - }) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequestHd(v) - }) -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequestHd(v) - }) -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequestHd() - }) -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequestHd() - }) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2944,139 +2600,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice360(v) - }) -} - -// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. -func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice360(v) - }) -} - -// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice360() - }) -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice360() - }) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraImagePrice540(v) - }) -} - -// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. -func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraImagePrice540(v) - }) -} - -// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraImagePrice540() - }) -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraImagePrice540() - }) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequest(v) - }) -} - -// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequest(v) - }) -} - -// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequest() - }) -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequest() - }) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraVideoPricePerRequestHd(v) - }) -} - -// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraVideoPricePerRequestHd(v) - }) -} - -// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraVideoPricePerRequestHd() - }) -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.ClearSoraVideoPricePerRequestHd() - }) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { - return u.Update(func(s *GroupUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index a9a4b9da80..aa1a83d421 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -355,135 +355,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { - _u.mutation.ResetSoraImagePrice360() - _u.mutation.SetSoraImagePrice360(v) - return _u -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraImagePrice360(*v) - } - return _u -} - -// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. -func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { - _u.mutation.AddSoraImagePrice360(v) - return _u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { - _u.mutation.ClearSoraImagePrice360() - return _u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { - _u.mutation.ResetSoraImagePrice540() - _u.mutation.SetSoraImagePrice540(v) - return _u -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraImagePrice540(*v) - } - return _u -} - -// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. -func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { - _u.mutation.AddSoraImagePrice540(v) - return _u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { - _u.mutation.ClearSoraImagePrice540() - return _u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { - _u.mutation.ResetSoraVideoPricePerRequest() - _u.mutation.SetSoraVideoPricePerRequest(v) - return _u -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraVideoPricePerRequest(*v) - } - return _u -} - -// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. -func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { - _u.mutation.AddSoraVideoPricePerRequest(v) - return _u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { - _u.mutation.ClearSoraVideoPricePerRequest() - return _u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { - _u.mutation.ResetSoraVideoPricePerRequestHd() - _u.mutation.SetSoraVideoPricePerRequestHd(v) - return _u -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { - if v != nil { - _u.SetSoraVideoPricePerRequestHd(*v) - } - return _u -} - -// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { - _u.mutation.AddSoraVideoPricePerRequestHd(v) - return _u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { - _u.mutation.ClearSoraVideoPricePerRequestHd() - return _u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -1082,48 +953,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } - if value, ok := _u.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { - _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice360Cleared() { - _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { - _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice540Cleared() { - _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestHdCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1817,135 +1646,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraImagePrice360() - _u.mutation.SetSoraImagePrice360(v) - return _u -} - -// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraImagePrice360(*v) - } - return _u -} - -// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. -func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { - _u.mutation.AddSoraImagePrice360(v) - return _u -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { - _u.mutation.ClearSoraImagePrice360() - return _u -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraImagePrice540() - _u.mutation.SetSoraImagePrice540(v) - return _u -} - -// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraImagePrice540(*v) - } - return _u -} - -// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. -func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { - _u.mutation.AddSoraImagePrice540(v) - return _u -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { - _u.mutation.ClearSoraImagePrice540() - return _u -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraVideoPricePerRequest() - _u.mutation.SetSoraVideoPricePerRequest(v) - return _u -} - -// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraVideoPricePerRequest(*v) - } - return _u -} - -// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { - _u.mutation.AddSoraVideoPricePerRequest(v) - return _u -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { - _u.mutation.ClearSoraVideoPricePerRequest() - return _u -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { - _u.mutation.ResetSoraVideoPricePerRequestHd() - _u.mutation.SetSoraVideoPricePerRequestHd(v) - return _u -} - -// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { - if v != nil { - _u.SetSoraVideoPricePerRequestHd(*v) - } - return _u -} - -// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { - _u.mutation.AddSoraVideoPricePerRequestHd(v) - return _u -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { - _u.mutation.ClearSoraVideoPricePerRequestHd() - return _u -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2574,48 +2274,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } - if value, ok := _u.mutation.SoraImagePrice360(); ok { - _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { - _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice360Cleared() { - _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraImagePrice540(); ok { - _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { - _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) - } - if _u.mutation.SoraImagePrice540Cleared() { - _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { - _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { - _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) - } - if _u.mutation.SoraVideoPricePerRequestHdCleared() { - _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) - } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index bdbb9fdddd..5400bf9319 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -395,11 +395,6 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, - {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -447,7 +442,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[30]}, + Columns: []*schema.Column{GroupsColumns[25]}, }, }, } @@ -770,7 +765,6 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, - {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, @@ -787,31 +781,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[35]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[36]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[37]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[38]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -820,32 +814,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[37]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[35]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[36]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[38]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_model", @@ -865,17 +859,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]}, }, }, } @@ -896,8 +890,6 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, - {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, - {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 28d9a0ef22..d206039af4 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8230,16 +8230,6 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 - sora_image_price_360 *float64 - addsora_image_price_360 *float64 - sora_image_price_540 *float64 - addsora_image_price_540 *float64 - sora_video_price_per_request *float64 - addsora_video_price_per_request *float64 - sora_video_price_per_request_hd *float64 - addsora_video_price_per_request_hd *float64 - sora_storage_quota_bytes *int64 - addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -9260,342 +9250,6 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } -// SetSoraImagePrice360 sets the "sora_image_price_360" field. -func (m *GroupMutation) SetSoraImagePrice360(f float64) { - m.sora_image_price_360 = &f - m.addsora_image_price_360 = nil -} - -// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. -func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { - v := m.sora_image_price_360 - if v == nil { - return - } - return *v, true -} - -// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) - } - return oldValue.SoraImagePrice360, nil -} - -// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. -func (m *GroupMutation) AddSoraImagePrice360(f float64) { - if m.addsora_image_price_360 != nil { - *m.addsora_image_price_360 += f - } else { - m.addsora_image_price_360 = &f - } -} - -// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. -func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { - v := m.addsora_image_price_360 - if v == nil { - return - } - return *v, true -} - -// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. -func (m *GroupMutation) ClearSoraImagePrice360() { - m.sora_image_price_360 = nil - m.addsora_image_price_360 = nil - m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} -} - -// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. -func (m *GroupMutation) SoraImagePrice360Cleared() bool { - _, ok := m.clearedFields[group.FieldSoraImagePrice360] - return ok -} - -// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. -func (m *GroupMutation) ResetSoraImagePrice360() { - m.sora_image_price_360 = nil - m.addsora_image_price_360 = nil - delete(m.clearedFields, group.FieldSoraImagePrice360) -} - -// SetSoraImagePrice540 sets the "sora_image_price_540" field. -func (m *GroupMutation) SetSoraImagePrice540(f float64) { - m.sora_image_price_540 = &f - m.addsora_image_price_540 = nil -} - -// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. -func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { - v := m.sora_image_price_540 - if v == nil { - return - } - return *v, true -} - -// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) - } - return oldValue.SoraImagePrice540, nil -} - -// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. -func (m *GroupMutation) AddSoraImagePrice540(f float64) { - if m.addsora_image_price_540 != nil { - *m.addsora_image_price_540 += f - } else { - m.addsora_image_price_540 = &f - } -} - -// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. -func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { - v := m.addsora_image_price_540 - if v == nil { - return - } - return *v, true -} - -// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. -func (m *GroupMutation) ClearSoraImagePrice540() { - m.sora_image_price_540 = nil - m.addsora_image_price_540 = nil - m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} -} - -// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. -func (m *GroupMutation) SoraImagePrice540Cleared() bool { - _, ok := m.clearedFields[group.FieldSoraImagePrice540] - return ok -} - -// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. -func (m *GroupMutation) ResetSoraImagePrice540() { - m.sora_image_price_540 = nil - m.addsora_image_price_540 = nil - delete(m.clearedFields, group.FieldSoraImagePrice540) -} - -// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. -func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { - m.sora_video_price_per_request = &f - m.addsora_video_price_per_request = nil -} - -// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. -func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { - v := m.sora_video_price_per_request - if v == nil { - return - } - return *v, true -} - -// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) - } - return oldValue.SoraVideoPricePerRequest, nil -} - -// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. -func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { - if m.addsora_video_price_per_request != nil { - *m.addsora_video_price_per_request += f - } else { - m.addsora_video_price_per_request = &f - } -} - -// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. -func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { - v := m.addsora_video_price_per_request - if v == nil { - return - } - return *v, true -} - -// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. -func (m *GroupMutation) ClearSoraVideoPricePerRequest() { - m.sora_video_price_per_request = nil - m.addsora_video_price_per_request = nil - m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} -} - -// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. -func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { - _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] - return ok -} - -// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. -func (m *GroupMutation) ResetSoraVideoPricePerRequest() { - m.sora_video_price_per_request = nil - m.addsora_video_price_per_request = nil - delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) -} - -// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { - m.sora_video_price_per_request_hd = &f - m.addsora_video_price_per_request_hd = nil -} - -// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. -func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { - v := m.sora_video_price_per_request_hd - if v == nil { - return - } - return *v, true -} - -// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) - } - return oldValue.SoraVideoPricePerRequestHd, nil -} - -// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { - if m.addsora_video_price_per_request_hd != nil { - *m.addsora_video_price_per_request_hd += f - } else { - m.addsora_video_price_per_request_hd = &f - } -} - -// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. -func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { - v := m.addsora_video_price_per_request_hd - if v == nil { - return - } - return *v, true -} - -// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { - m.sora_video_price_per_request_hd = nil - m.addsora_video_price_per_request_hd = nil - m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} -} - -// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. -func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { - _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] - return ok -} - -// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. -func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { - m.sora_video_price_per_request_hd = nil - m.addsora_video_price_per_request_hd = nil - delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) -} - -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { - m.sora_storage_quota_bytes = &i - m.addsora_storage_quota_bytes = nil -} - -// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. -func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { - v := m.sora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) - } - return oldValue.SoraStorageQuotaBytes, nil -} - -// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. -func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { - if m.addsora_storage_quota_bytes != nil { - *m.addsora_storage_quota_bytes += i - } else { - m.addsora_storage_quota_bytes = &i - } -} - -// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. -func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { - v := m.addsora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. -func (m *GroupMutation) ResetSoraStorageQuotaBytes() { - m.sora_storage_quota_bytes = nil - m.addsora_storage_quota_bytes = nil -} - // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -10502,7 +10156,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 34) + fields := make([]string, 0, 29) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10554,21 +10208,6 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } - if m.sora_image_price_360 != nil { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.sora_image_price_540 != nil { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.sora_video_price_per_request != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.sora_video_price_per_request_hd != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } - if m.sora_storage_quota_bytes != nil { - fields = append(fields, group.FieldSoraStorageQuotaBytes) - } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -10647,16 +10286,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() - case group.FieldSoraImagePrice360: - return m.SoraImagePrice360() - case group.FieldSoraImagePrice540: - return m.SoraImagePrice540() - case group.FieldSoraVideoPricePerRequest: - return m.SoraVideoPricePerRequest() - case group.FieldSoraVideoPricePerRequestHd: - return m.SoraVideoPricePerRequestHd() - case group.FieldSoraStorageQuotaBytes: - return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -10724,16 +10353,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) - case group.FieldSoraImagePrice360: - return m.OldSoraImagePrice360(ctx) - case group.FieldSoraImagePrice540: - return m.OldSoraImagePrice540(ctx) - case group.FieldSoraVideoPricePerRequest: - return m.OldSoraVideoPricePerRequest(ctx) - case group.FieldSoraVideoPricePerRequestHd: - return m.OldSoraVideoPricePerRequestHd(ctx) - case group.FieldSoraStorageQuotaBytes: - return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -10886,41 +10505,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil - case group.FieldSoraImagePrice360: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraImagePrice360(v) - return nil - case group.FieldSoraImagePrice540: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraImagePrice540(v) - return nil - case group.FieldSoraVideoPricePerRequest: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraVideoPricePerRequest(v) - return nil - case group.FieldSoraVideoPricePerRequestHd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraVideoPricePerRequestHd(v) - return nil - case group.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageQuotaBytes(v) - return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -11037,21 +10621,6 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } - if m.addsora_image_price_360 != nil { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.addsora_image_price_540 != nil { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.addsora_video_price_per_request != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.addsora_video_price_per_request_hd != nil { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } - if m.addsora_storage_quota_bytes != nil { - fields = append(fields, group.FieldSoraStorageQuotaBytes) - } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -11085,16 +10654,6 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() - case group.FieldSoraImagePrice360: - return m.AddedSoraImagePrice360() - case group.FieldSoraImagePrice540: - return m.AddedSoraImagePrice540() - case group.FieldSoraVideoPricePerRequest: - return m.AddedSoraVideoPricePerRequest() - case group.FieldSoraVideoPricePerRequestHd: - return m.AddedSoraVideoPricePerRequestHd() - case group.FieldSoraStorageQuotaBytes: - return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -11166,41 +10725,6 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil - case group.FieldSoraImagePrice360: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraImagePrice360(v) - return nil - case group.FieldSoraImagePrice540: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraImagePrice540(v) - return nil - case group.FieldSoraVideoPricePerRequest: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraVideoPricePerRequest(v) - return nil - case group.FieldSoraVideoPricePerRequestHd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraVideoPricePerRequestHd(v) - return nil - case group.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageQuotaBytes(v) - return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -11254,18 +10778,6 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } - if m.FieldCleared(group.FieldSoraImagePrice360) { - fields = append(fields, group.FieldSoraImagePrice360) - } - if m.FieldCleared(group.FieldSoraImagePrice540) { - fields = append(fields, group.FieldSoraImagePrice540) - } - if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { - fields = append(fields, group.FieldSoraVideoPricePerRequest) - } - if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { - fields = append(fields, group.FieldSoraVideoPricePerRequestHd) - } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -11313,18 +10825,6 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil - case group.FieldSoraImagePrice360: - m.ClearSoraImagePrice360() - return nil - case group.FieldSoraImagePrice540: - m.ClearSoraImagePrice540() - return nil - case group.FieldSoraVideoPricePerRequest: - m.ClearSoraVideoPricePerRequest() - return nil - case group.FieldSoraVideoPricePerRequestHd: - m.ClearSoraVideoPricePerRequestHd() - return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -11393,21 +10893,6 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil - case group.FieldSoraImagePrice360: - m.ResetSoraImagePrice360() - return nil - case group.FieldSoraImagePrice540: - m.ResetSoraImagePrice540() - return nil - case group.FieldSoraVideoPricePerRequest: - m.ResetSoraVideoPricePerRequest() - return nil - case group.FieldSoraVideoPricePerRequestHd: - m.ResetSoraVideoPricePerRequestHd() - return nil - case group.FieldSoraStorageQuotaBytes: - m.ResetSoraStorageQuotaBytes() - return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -19770,7 +19255,6 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string - media_type *string cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} @@ -21713,55 +21197,6 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } -// SetMediaType sets the "media_type" field. -func (m *UsageLogMutation) SetMediaType(s string) { - m.media_type = &s -} - -// MediaType returns the value of the "media_type" field in the mutation. -func (m *UsageLogMutation) MediaType() (r string, exists bool) { - v := m.media_type - if v == nil { - return - } - return *v, true -} - -// OldMediaType returns the old "media_type" field's value of the UsageLog entity. -// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMediaType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMediaType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMediaType: %w", err) - } - return oldValue.MediaType, nil -} - -// ClearMediaType clears the value of the "media_type" field. -func (m *UsageLogMutation) ClearMediaType() { - m.media_type = nil - m.clearedFields[usagelog.FieldMediaType] = struct{}{} -} - -// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. -func (m *UsageLogMutation) MediaTypeCleared() bool { - _, ok := m.clearedFields[usagelog.FieldMediaType] - return ok -} - -// ResetMediaType resets all changes to the "media_type" field. -func (m *UsageLogMutation) ResetMediaType() { - m.media_type = nil - delete(m.clearedFields, usagelog.FieldMediaType) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { m.cache_ttl_overridden = &b @@ -22003,7 +21438,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 38) + fields := make([]string, 0, 37) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -22109,9 +21544,6 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } - if m.media_type != nil { - fields = append(fields, usagelog.FieldMediaType) - } if m.cache_ttl_overridden != nil { fields = append(fields, usagelog.FieldCacheTTLOverridden) } @@ -22196,8 +21628,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() - case usagelog.FieldMediaType: - return m.MediaType() case usagelog.FieldCacheTTLOverridden: return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: @@ -22281,8 +21711,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) - case usagelog.FieldMediaType: - return m.OldMediaType(ctx) case usagelog.FieldCacheTTLOverridden: return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: @@ -22541,13 +21969,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil - case usagelog.FieldMediaType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMediaType(v) - return nil case usagelog.FieldCacheTTLOverridden: v, ok := value.(bool) if !ok { @@ -22865,9 +22286,6 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } - if m.FieldCleared(usagelog.FieldMediaType) { - fields = append(fields, usagelog.FieldMediaType) - } return fields } @@ -22924,9 +22342,6 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil - case usagelog.FieldMediaType: - m.ClearMediaType() - return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -23040,9 +22455,6 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil - case usagelog.FieldMediaType: - m.ResetMediaType() - return nil case usagelog.FieldCacheTTLOverridden: m.ResetCacheTTLOverridden() return nil @@ -23221,10 +22633,6 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time - sora_storage_quota_bytes *int64 - addsora_storage_quota_bytes *int64 - sora_storage_used_bytes *int64 - addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -23939,118 +23347,6 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { - m.sora_storage_quota_bytes = &i - m.addsora_storage_quota_bytes = nil -} - -// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. -func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { - v := m.sora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. -// If the User object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) - } - return oldValue.SoraStorageQuotaBytes, nil -} - -// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. -func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { - if m.addsora_storage_quota_bytes != nil { - *m.addsora_storage_quota_bytes += i - } else { - m.addsora_storage_quota_bytes = &i - } -} - -// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. -func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { - v := m.addsora_storage_quota_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. -func (m *UserMutation) ResetSoraStorageQuotaBytes() { - m.sora_storage_quota_bytes = nil - m.addsora_storage_quota_bytes = nil -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { - m.sora_storage_used_bytes = &i - m.addsora_storage_used_bytes = nil -} - -// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. -func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { - v := m.sora_storage_used_bytes - if v == nil { - return - } - return *v, true -} - -// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. -// If the User object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) - } - return oldValue.SoraStorageUsedBytes, nil -} - -// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. -func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { - if m.addsora_storage_used_bytes != nil { - *m.addsora_storage_used_bytes += i - } else { - m.addsora_storage_used_bytes = &i - } -} - -// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. -func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { - v := m.addsora_storage_used_bytes - if v == nil { - return - } - return *v, true -} - -// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. -func (m *UserMutation) ResetSoraStorageUsedBytes() { - m.sora_storage_used_bytes = nil - m.addsora_storage_used_bytes = nil -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -24571,7 +23867,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 16) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -24614,12 +23910,6 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } - if m.sora_storage_quota_bytes != nil { - fields = append(fields, user.FieldSoraStorageQuotaBytes) - } - if m.sora_storage_used_bytes != nil { - fields = append(fields, user.FieldSoraStorageUsedBytes) - } return fields } @@ -24656,10 +23946,6 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() - case user.FieldSoraStorageQuotaBytes: - return m.SoraStorageQuotaBytes() - case user.FieldSoraStorageUsedBytes: - return m.SoraStorageUsedBytes() } return nil, false } @@ -24697,10 +23983,6 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) - case user.FieldSoraStorageQuotaBytes: - return m.OldSoraStorageQuotaBytes(ctx) - case user.FieldSoraStorageUsedBytes: - return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -24808,20 +24090,6 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil - case user.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageQuotaBytes(v) - return nil - case user.FieldSoraStorageUsedBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSoraStorageUsedBytes(v) - return nil } return fmt.Errorf("unknown User field %s", name) } @@ -24836,12 +24104,6 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } - if m.addsora_storage_quota_bytes != nil { - fields = append(fields, user.FieldSoraStorageQuotaBytes) - } - if m.addsora_storage_used_bytes != nil { - fields = append(fields, user.FieldSoraStorageUsedBytes) - } return fields } @@ -24854,10 +24116,6 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() - case user.FieldSoraStorageQuotaBytes: - return m.AddedSoraStorageQuotaBytes() - case user.FieldSoraStorageUsedBytes: - return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -24881,20 +24139,6 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil - case user.FieldSoraStorageQuotaBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageQuotaBytes(v) - return nil - case user.FieldSoraStorageUsedBytes: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSoraStorageUsedBytes(v) - return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -24985,12 +24229,6 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil - case user.FieldSoraStorageQuotaBytes: - m.ResetSoraStorageQuotaBytes() - return nil - case user.FieldSoraStorageUsedBytes: - m.ResetSoraStorageUsedBytes() - return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 336b1f8243..803b7bc24b 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -430,44 +430,40 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) - // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. - groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() - // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. - group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[19].Descriptor() + groupDescClaudeCodeOnly := groupFields[14].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[23].Descriptor() + groupDescModelRoutingEnabled := groupFields[18].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[24].Descriptor() + groupDescMcpXMLInject := groupFields[19].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[25].Descriptor() + groupDescSupportedModelScopes := groupFields[20].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[26].Descriptor() + groupDescSortOrder := groupFields[21].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. - groupDescAllowMessagesDispatch := groupFields[27].Descriptor() + groupDescAllowMessagesDispatch := groupFields[22].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. - groupDescRequireOauthOnly := groupFields[28].Descriptor() + groupDescRequireOauthOnly := groupFields[23].Descriptor() // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. - groupDescRequirePrivacySet := groupFields[29].Descriptor() + groupDescRequirePrivacySet := groupFields[24].Descriptor() // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[30].Descriptor() + groupDescDefaultMappedModel := groupFields[25].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. @@ -963,16 +959,12 @@ func init() { usagelogDescImageSize := usagelogFields[34].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) - // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[35].Descriptor() - // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. - usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[37].Descriptor() + usagelogDescCreatedAt := usagelogFields[36].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() @@ -1064,14 +1056,6 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) - // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. - userDescSoraStorageQuotaBytes := userFields[11].Descriptor() - // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. - user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) - // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. - userDescSoraStorageUsedBytes := userFields[12].Descriptor() - // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. - user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index fd83bf26ad..0eb89c181b 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - // Sora 按次计费配置(阶段 1) - field.Float("sora_image_price_360"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_image_price_540"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_video_price_per_request"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - field.Float("sora_video_price_per_request_hd"). - Optional(). - Nillable(). - SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), - - // Sora 存储配额 - field.Int64("sora_storage_quota_bytes"). - Default(0), - // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index f6c725a2b1..bd3ebfcc3c 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -134,12 +134,6 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), - // 媒体类型字段(sora 使用) - field.String("media_type"). - MaxLen(16). - Optional(). - Nillable(), - // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) field.Bool("cache_ttl_overridden"). Default(false), diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 0a3b5d9ec2..d443ef455c 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,12 +72,6 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), - - // Sora 存储配额 - field.Int64("sora_storage_quota_bytes"). - Default(0), - field.Int64("sora_storage_used_bytes"). - Default(0), } } diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index b857afdbb5..a8e0cc6ce8 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -92,8 +92,6 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` - // MediaType holds the value of the "media_type" field. - MediaType *string `json:"media_type,omitempty"` // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. @@ -187,7 +185,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -436,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } - case usagelog.FieldMediaType: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field media_type", values[i]) - } else if value.Valid { - _m.MediaType = new(string) - *_m.MediaType = value.String - } case usagelog.FieldCacheTTLOverridden: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) @@ -649,11 +640,6 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") - if v := _m.MediaType; v != nil { - builder.WriteString("media_type=") - builder.WriteString(*v) - } - builder.WriteString(", ") builder.WriteString("cache_ttl_overridden=") builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 1567ad9b45..a7438e604f 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -84,8 +84,6 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" - // FieldMediaType holds the string denoting the media_type field in the database. - FieldMediaType = "media_type" // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. @@ -177,7 +175,6 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, - FieldMediaType, FieldCacheTTLOverridden, FieldCreatedAt, } @@ -245,8 +242,6 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error - // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. - MediaTypeValidator func(string) error // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. @@ -436,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } -// ByMediaType orders the results by the media_type field. -func ByMediaType(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldMediaType, opts...).ToFunc() -} - // ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index a1fb36cbaa..b8439a0397 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -230,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } -// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. -func MediaType(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) -} - // CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. func CacheTTLOverridden(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) @@ -1905,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } -// MediaTypeEQ applies the EQ predicate on the "media_type" field. -func MediaTypeEQ(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) -} - -// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. -func MediaTypeNEQ(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) -} - -// MediaTypeIn applies the In predicate on the "media_type" field. -func MediaTypeIn(vs ...string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) -} - -// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. -func MediaTypeNotIn(vs ...string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) -} - -// MediaTypeGT applies the GT predicate on the "media_type" field. -func MediaTypeGT(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) -} - -// MediaTypeGTE applies the GTE predicate on the "media_type" field. -func MediaTypeGTE(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) -} - -// MediaTypeLT applies the LT predicate on the "media_type" field. -func MediaTypeLT(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) -} - -// MediaTypeLTE applies the LTE predicate on the "media_type" field. -func MediaTypeLTE(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) -} - -// MediaTypeContains applies the Contains predicate on the "media_type" field. -func MediaTypeContains(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) -} - -// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. -func MediaTypeHasPrefix(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) -} - -// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. -func MediaTypeHasSuffix(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) -} - -// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. -func MediaTypeIsNil() predicate.UsageLog { - return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) -} - -// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. -func MediaTypeNotNil() predicate.UsageLog { - return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) -} - -// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. -func MediaTypeEqualFold(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) -} - -// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. -func MediaTypeContainsFold(v string) predicate.UsageLog { - return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) -} - // CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index d15e231d9c..fded364e0e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -477,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } -// SetMediaType sets the "media_type" field. -func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { - _c.mutation.SetMediaType(v) - return _c -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { - if v != nil { - _c.SetMediaType(*v) - } - return _c -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { _c.mutation.SetCacheTTLOverridden(v) @@ -768,11 +754,6 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _c.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _, ok := _c.mutation.CacheTTLOverridden(); !ok { return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} } @@ -935,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } - if value, ok := _c.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - _node.MediaType = &value - } if value, ok := _c.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _node.CacheTTLOverridden = value @@ -1702,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { - u.Set(usagelog.FieldMediaType, v) - return u -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { - u.SetExcluded(usagelog.FieldMediaType) - return u -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { - u.SetNull(usagelog.FieldMediaType) - return u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { u.Set(usagelog.FieldCacheTTLOverridden, v) @@ -2498,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.SetMediaType(v) - }) -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.UpdateMediaType() - }) -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { - return u.Update(func(s *UsageLogUpsert) { - s.ClearMediaType() - }) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -3465,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } -// SetMediaType sets the "media_type" field. -func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.SetMediaType(v) - }) -} - -// UpdateMediaType sets the "media_type" field to the value that was provided on create. -func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.UpdateMediaType() - }) -} - -// ClearMediaType clears the value of the "media_type" field. -func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { - return u.Update(func(s *UsageLogUpsert) { - s.ClearMediaType() - }) -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 52f5a606dc..bb5ac86c78 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -739,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } -// SetMediaType sets the "media_type" field. -func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { - _u.mutation.SetMediaType(v) - return _u -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { - if v != nil { - _u.SetMediaType(*v) - } - return _u -} - -// ClearMediaType clears the value of the "media_type" field. -func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { - _u.mutation.ClearMediaType() - return _u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { _u.mutation.SetCacheTTLOverridden(v) @@ -912,11 +892,6 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _u.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1124,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } - if value, ok := _u.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - } - if _u.mutation.MediaTypeCleared() { - _spec.ClearField(usagelog.FieldMediaType, field.TypeString) - } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } @@ -2005,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } -// SetMediaType sets the "media_type" field. -func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { - _u.mutation.SetMediaType(v) - return _u -} - -// SetNillableMediaType sets the "media_type" field if the given value is not nil. -func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { - if v != nil { - _u.SetMediaType(*v) - } - return _u -} - -// ClearMediaType clears the value of the "media_type" field. -func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { - _u.mutation.ClearMediaType() - return _u -} - // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { _u.mutation.SetCacheTTLOverridden(v) @@ -2191,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } - if v, ok := _u.mutation.MediaType(); ok { - if err := usagelog.MediaTypeValidator(v); err != nil { - return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} - } - } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -2420,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } - if value, ok := _u.mutation.MediaType(); ok { - _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) - } - if _u.mutation.MediaTypeCleared() { - _spec.ClearField(usagelog.FieldMediaType, field.TypeString) - } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } diff --git a/backend/ent/user.go b/backend/ent/user.go index b3f933f6fa..2435aa1b99 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,10 +45,6 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` - // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` - // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. - SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -181,7 +177,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: + case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -295,18 +291,6 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } - case user.FieldSoraStorageQuotaBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageQuotaBytes = value.Int64 - } - case user.FieldSoraStorageUsedBytes: - if value, ok := values[i].(*sql.NullInt64); !ok { - return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) - } else if value.Valid { - _m.SoraStorageUsedBytes = value.Int64 - } default: _m.selectValues.Set(columns[i], values[i]) } @@ -440,12 +424,6 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } - builder.WriteString(", ") - builder.WriteString("sora_storage_quota_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) - builder.WriteString(", ") - builder.WriteString("sora_storage_used_bytes=") - builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 155b916086..ae9418ff07 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,10 +43,6 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" - // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. - FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" - // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. - FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -156,8 +152,6 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, - FieldSoraStorageQuotaBytes, - FieldSoraStorageUsedBytes, } var ( @@ -214,10 +208,6 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool - // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. - DefaultSoraStorageQuotaBytes int64 - // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. - DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. @@ -298,16 +288,6 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } -// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. -func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() -} - -// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. -func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() -} - // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index e26afcf381..1de6103702 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,16 +125,6 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } -// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. -func SoraStorageQuotaBytes(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. -func SoraStorageUsedBytes(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) -} - // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -870,86 +860,6 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } -// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesEQ(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNEQ(v int64) predicate.User { - return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) -} - -// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGT(v int64) predicate.User { - return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesGTE(v int64) predicate.User { - return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLT(v int64) predicate.User { - return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. -func SoraStorageQuotaBytesLTE(v int64) predicate.User { - return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) -} - -// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesEQ(v int64) predicate.User { - return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesNEQ(v int64) predicate.User { - return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) -} - -// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { - return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) -} - -// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesGT(v int64) predicate.User { - return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesGTE(v int64) predicate.User { - return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesLT(v int64) predicate.User { - return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) -} - -// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. -func SoraStorageUsedBytesLTE(v int64) predicate.User { - return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) -} - // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index df0c6bcc1a..f862a580c5 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,34 +210,6 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { - _c.mutation.SetSoraStorageQuotaBytes(v) - return _c -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { - if v != nil { - _c.SetSoraStorageQuotaBytes(*v) - } - return _c -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { - _c.mutation.SetSoraStorageUsedBytes(v) - return _c -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { - if v != nil { - _c.SetSoraStorageUsedBytes(*v) - } - return _c -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -452,14 +424,6 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - v := user.DefaultSoraStorageQuotaBytes - _c.mutation.SetSoraStorageQuotaBytes(v) - } - if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { - v := user.DefaultSoraStorageUsedBytes - _c.mutation.SetSoraStorageUsedBytes(v) - } return nil } @@ -523,12 +487,6 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } - if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { - return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} - } - if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { - return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} - } return nil } @@ -612,14 +570,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } - if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - _node.SoraStorageQuotaBytes = value - } - if value, ok := _c.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - _node.SoraStorageUsedBytes = value - } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1006,42 +956,6 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { - u.Set(user.FieldSoraStorageQuotaBytes, v) - return u -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { - u.SetExcluded(user.FieldSoraStorageQuotaBytes) - return u -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { - u.Add(user.FieldSoraStorageQuotaBytes, v) - return u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { - u.Set(user.FieldSoraStorageUsedBytes, v) - return u -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { - u.SetExcluded(user.FieldSoraStorageUsedBytes) - return u -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { - u.Add(user.FieldSoraStorageUsedBytes, v) - return u -} - // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1304,48 +1218,6 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageUsedBytes(v) - }) -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageUsedBytes(v) - }) -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageUsedBytes() - }) -} - // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1774,48 +1646,6 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageQuotaBytes(v) - }) -} - -// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. -func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageQuotaBytes(v) - }) -} - -// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. -func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageQuotaBytes() - }) -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.SetSoraStorageUsedBytes(v) - }) -} - -// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. -func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.AddSoraStorageUsedBytes(v) - }) -} - -// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. -func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { - return u.Update(func(s *UserUpsert) { - s.UpdateSoraStorageUsedBytes() - }) -} - // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index f71f0cadfa..80222c92dc 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,48 +242,6 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { - _u.mutation.ResetSoraStorageUsedBytes() - _u.mutation.SetSoraStorageUsedBytes(v) - return _u -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { - if v != nil { - _u.SetSoraStorageUsedBytes(*v) - } - return _u -} - -// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. -func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { - _u.mutation.AddSoraStorageUsedBytes(v) - return _u -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -751,18 +709,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { - _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1406,48 +1352,6 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } -// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. -func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { - _u.mutation.ResetSoraStorageQuotaBytes() - _u.mutation.SetSoraStorageQuotaBytes(v) - return _u -} - -// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. -func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { - if v != nil { - _u.SetSoraStorageQuotaBytes(*v) - } - return _u -} - -// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. -func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { - _u.mutation.AddSoraStorageQuotaBytes(v) - return _u -} - -// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. -func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { - _u.mutation.ResetSoraStorageUsedBytes() - _u.mutation.SetSoraStorageUsedBytes(v) - return _u -} - -// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. -func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { - if v != nil { - _u.SetSoraStorageUsedBytes(*v) - } - return _u -} - -// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. -func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { - _u.mutation.AddSoraStorageUsedBytes(v) - return _u -} - // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1945,18 +1849,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } - if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { - _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { - _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { - _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } - if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { - _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) - } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3ee5d6cdca..9b43037789 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -77,7 +77,6 @@ type Config struct { UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - Sora SoraConfig `mapstructure:"sora"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` @@ -197,8 +196,6 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` - // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) - SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -303,59 +300,6 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } -// SoraConfig 直连 Sora 配置 -type SoraConfig struct { - Client SoraClientConfig `mapstructure:"client"` - Storage SoraStorageConfig `mapstructure:"storage"` -} - -// SoraClientConfig 直连 Sora 客户端配置 -type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` - CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` -} - -// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 -type SoraCurlCFFISidecarConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - Impersonate string `mapstructure:"impersonate"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` - SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` -} - -// SoraStorageConfig 媒体存储配置 -type SoraStorageConfig struct { - Type string `mapstructure:"type"` - LocalPath string `mapstructure:"local_path"` - FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` - MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` - DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` - MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` - Debug bool `mapstructure:"debug"` - Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` -} - -// SoraStorageCleanupConfig 媒体清理配置 -type SoraStorageCleanupConfig struct { - Enabled bool `mapstructure:"enabled"` - Schedule string `mapstructure:"schedule"` - RetentionDays int `mapstructure:"retention_days"` -} - // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -424,24 +368,6 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` - // Sora 专用配置 - // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) - SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` - // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) - SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` - // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) - SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` - // SoraStreamMode: stream 强制策略(force/error) - SoraStreamMode string `mapstructure:"sora_stream_mode"` - // SoraModelFilters: 模型列表过滤配置 - SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` - // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key - SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` - // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) - SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` - // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) - SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` - // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -639,12 +565,6 @@ type GatewayUsageRecordConfig struct { AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` } -// SoraModelFiltersConfig Sora 模型过滤配置 -type SoraModelFiltersConfig struct { - // HidePromptEnhance 是否隐藏 prompt-enhance 模型 - HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` -} - // TLSFingerprintConfig TLS指纹伪装配置 // 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 type TLSFingerprintConfig struct { @@ -1402,13 +1322,6 @@ func setDefaults() { viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) - viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) - viper.SetDefault("gateway.sora_request_timeout_seconds", 180) - viper.SetDefault("gateway.sora_stream_mode", "force") - viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) - viper.SetDefault("gateway.sora_media_require_api_key", true) - viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) @@ -1465,45 +1378,12 @@ func setDefaults() { viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) - // Sora 直连配置 - viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") - viper.SetDefault("sora.client.timeout_seconds", 120) - viper.SetDefault("sora.client.max_retries", 3) - viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) - viper.SetDefault("sora.client.poll_interval_seconds", 2) - viper.SetDefault("sora.client.max_poll_attempts", 600) - viper.SetDefault("sora.client.recent_task_limit", 50) - viper.SetDefault("sora.client.recent_task_limit_max", 200) - viper.SetDefault("sora.client.debug", false) - viper.SetDefault("sora.client.use_openai_token_provider", false) - viper.SetDefault("sora.client.headers", map[string]string{}) - viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - viper.SetDefault("sora.client.disable_tls_fingerprint", false) - viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) - viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") - viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") - viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) - viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) - viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) - - viper.SetDefault("sora.storage.type", "local") - viper.SetDefault("sora.storage.local_path", "") - viper.SetDefault("sora.storage.fallback_to_upstream", true) - viper.SetDefault("sora.storage.max_concurrent_downloads", 4) - viper.SetDefault("sora.storage.download_timeout_seconds", 120) - viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) - viper.SetDefault("sora.storage.debug", false) - viper.SetDefault("sora.storage.cleanup.enabled", true) - viper.SetDefault("sora.storage.cleanup.retention_days", 7) - viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") - // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 - viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -1879,86 +1759,6 @@ func (c *Config) Validate() error { if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") } - if c.Gateway.SoraMaxBodySize < 0 { - return fmt.Errorf("gateway.sora_max_body_size must be non-negative") - } - if c.Gateway.SoraStreamTimeoutSeconds < 0 { - return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") - } - if c.Gateway.SoraRequestTimeoutSeconds < 0 { - return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") - } - if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { - return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") - } - if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { - switch mode { - case "force", "error": - default: - return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") - } - } - if c.Sora.Client.TimeoutSeconds < 0 { - return fmt.Errorf("sora.client.timeout_seconds must be non-negative") - } - if c.Sora.Client.MaxRetries < 0 { - return fmt.Errorf("sora.client.max_retries must be non-negative") - } - if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { - return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") - } - if c.Sora.Client.PollIntervalSeconds < 0 { - return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") - } - if c.Sora.Client.MaxPollAttempts < 0 { - return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") - } - if c.Sora.Client.RecentTaskLimit < 0 { - return fmt.Errorf("sora.client.recent_task_limit must be non-negative") - } - if c.Sora.Client.RecentTaskLimitMax < 0 { - return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") - } - if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && - c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { - c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit - } - if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { - return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") - } - if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { - return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") - } - if !c.Sora.Client.CurlCFFISidecar.Enabled { - return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true") - } - if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { - return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") - } - if c.Sora.Storage.MaxConcurrentDownloads < 0 { - return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") - } - if c.Sora.Storage.DownloadTimeoutSeconds < 0 { - return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") - } - if c.Sora.Storage.MaxDownloadBytes < 0 { - return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") - } - if c.Sora.Storage.Cleanup.Enabled { - if c.Sora.Storage.Cleanup.RetentionDays <= 0 { - return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") - } - if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { - return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") - } - } else { - if c.Sora.Storage.Cleanup.RetentionDays < 0 { - return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") - } - } - if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { - return fmt.Errorf("sora.storage.type must be 'local'") - } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index abb76549da..2de5451ee0 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1554,94 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) { } } -func TestSoraCurlCFFISidecarDefaults(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - if !cfg.Sora.Client.CurlCFFISidecar.Enabled { - t.Fatalf("Sora curl_cffi sidecar should be enabled by default") - } - if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { - t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") - } - if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { - t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") - } - if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { - t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") - } - if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { - t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") - } - if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { - t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") - } -} - -func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.Enabled = false - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { - t.Fatalf("Validate() error = %v, want sidecar enabled error", err) - } -} - -func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { - t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) - } -} - -func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { - t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) - } -} - -func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { - resetViperWithJWTSecret(t) - - cfg, err := Load() - if err != nil { - t.Fatalf("Load() error: %v", err) - } - - cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 - err = cfg.Validate() - if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { - t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) - } -} - func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 4e69ca0252..429486c3bf 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,7 +22,6 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" - PlatformSora = "sora" ) // Account type constants diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 12139b5165..20cc09eebc 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -567,15 +567,15 @@ func defaultProxyName(name string) string { // enrichCredentialsFromIDToken performs best-effort extraction of user info fields // (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials. -// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently. +// Only applies to OpenAI OAuth accounts. Skips expired token errors silently. // Existing credential values are never overwritten — only missing fields are filled. func enrichCredentialsFromIDToken(item *DataAccount) { if item.Credentials == nil { return } - // Only enrich OpenAI/Sora OAuth accounts + // Only enrich OpenAI OAuth accounts platform := strings.ToLower(strings.TrimSpace(item.Platform)) - if platform != service.PlatformOpenAI && platform != service.PlatformSora { + if platform != service.PlatformOpenAI { return } if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 681da5e8fe..9aed64d592 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1875,12 +1875,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } - // Handle Sora accounts - if account.Platform == service.PlatformSora { - response.Success(c, service.DefaultSoraModels(nil)) - return - } - // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 9759cef5c0..60d68913e8 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -380,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se {Target: "openai", Status: "pass", HTTPStatus: 401}, {Target: "anthropic", Status: "pass", HTTPStatus: 401}, {Target: "gemini", Status: "pass", HTTPStatus: 200}, - {Target: "sora", Status: "pass", HTTPStatus: 401}, }, }, nil } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index caa27bc38c..458ed35d47 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -95,10 +95,6 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -108,8 +104,6 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` - // Sora 存储配额 - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` RequireOAuthOnly bool `json:"require_oauth_only"` @@ -123,7 +117,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -135,10 +129,6 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -148,8 +138,6 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` - // Sora 存储配额 - SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` RequireOAuthOnly *bool `json:"require_oauth_only"` @@ -258,10 +246,6 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -269,7 +253,6 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, @@ -313,10 +296,6 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -324,7 +303,6 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, AllowMessagesDispatch: req.AllowMessagesDispatch, RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index 4e6179dbe1..cc0c933792 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct { } func oauthPlatformFromPath(c *gin.Context) string { - if strings.Contains(c.FullPath(), "/admin/sora/") { - return service.PlatformSora - } return service.PlatformOpenAI } @@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct { // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token -// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// ExchangeSoraSessionToken exchanges Sora session token to access token -// POST /api/v1/admin/sora/st2at -func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { - var req struct { - SessionToken string `json:"session_token"` - ST string `json:"st"` - ProxyID *int64 `json:"proxy_id"` - } - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - sessionToken := strings.TrimSpace(req.SessionToken) - if sessionToken == "" { - sessionToken = strings.TrimSpace(req.ST) - } - if sessionToken == "" { - response.BadRequest(c, "session_token is required") - return - } - - tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, tokenInfo) -} - -// RefreshAccountToken refreshes token for a specific OpenAI/Sora account +// RefreshAccountToken refreshes token for a specific OpenAI account // POST /api/v1/admin/openai/accounts/:id/refresh -// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth -// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` @@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { name = tokenInfo.Email } if name == "" { - if platform == service.PlatformSora { - name = "Sora OAuth Account" - } else { - name = "OpenAI OAuth Account" - } + name = "OpenAI OAuth Account" } // Create account diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 397526a7ea..069169172d 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -41,17 +41,15 @@ type SettingHandler struct { emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService - soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, - soraS3Storage: soraS3Storage, } } @@ -108,7 +106,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, @@ -177,7 +174,6 @@ type UpdateSettingsRequest struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` @@ -566,7 +562,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, - SoraClientEnabled: req.SoraClientEnabled, CustomMenuItems: customMenuJSON, CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, @@ -676,7 +671,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, - SoraClientEnabled: updatedSettings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, @@ -1207,384 +1201,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } -func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { - if settings == nil { - return dto.SoraS3Settings{} - } - return dto.SoraS3Settings{ - Enabled: settings.Enabled, - Endpoint: settings.Endpoint, - Region: settings.Region, - Bucket: settings.Bucket, - AccessKeyID: settings.AccessKeyID, - SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, - Prefix: settings.Prefix, - ForcePathStyle: settings.ForcePathStyle, - CDNURL: settings.CDNURL, - DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, - } -} - -func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { - return dto.SoraS3Profile{ - ProfileID: profile.ProfileID, - Name: profile.Name, - IsActive: profile.IsActive, - Enabled: profile.Enabled, - Endpoint: profile.Endpoint, - Region: profile.Region, - Bucket: profile.Bucket, - AccessKeyID: profile.AccessKeyID, - SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, - Prefix: profile.Prefix, - ForcePathStyle: profile.ForcePathStyle, - CDNURL: profile.CDNURL, - DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, - UpdatedAt: profile.UpdatedAt, - } -} - -func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { - if !enabled { - return nil - } - if strings.TrimSpace(endpoint) == "" { - return fmt.Errorf("S3 Endpoint is required when enabled") - } - if strings.TrimSpace(bucket) == "" { - return fmt.Errorf("S3 Bucket is required when enabled") - } - if strings.TrimSpace(accessKeyID) == "" { - return fmt.Errorf("S3 Access Key ID is required when enabled") - } - if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { - return nil - } - return fmt.Errorf("S3 Secret Access Key is required when enabled") -} - -func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == profileID { - return &items[idx] - } - } - return nil -} - -// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) -// GET /api/v1/admin/settings/sora-s3 -func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { - settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3SettingsDTO(settings)) -} - -// ListSoraS3Profiles 获取 Sora S3 多配置 -// GET /api/v1/admin/settings/sora-s3/profiles -func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) { - result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - items := make([]dto.SoraS3Profile, 0, len(result.Items)) - for idx := range result.Items { - items = append(items, toSoraS3ProfileDTO(result.Items[idx])) - } - response.Success(c, dto.ListSoraS3ProfilesResponse{ - ActiveProfileID: result.ActiveProfileID, - Items: items, - }) -} - -// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) -type UpdateSoraS3SettingsRequest struct { - ProfileID string `json:"profile_id"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -type CreateSoraS3ProfileRequest struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - SetActive bool `json:"set_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -type UpdateSoraS3ProfileRequest struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// CreateSoraS3Profile 创建 Sora S3 配置 -// POST /api/v1/admin/settings/sora-s3/profiles -func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { - var req CreateSoraS3ProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if strings.TrimSpace(req.Name) == "" { - response.BadRequest(c, "Name is required") - return - } - if strings.TrimSpace(req.ProfileID) == "" { - response.BadRequest(c, "Profile ID is required") - return - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { - response.BadRequest(c, err.Error()) - return - } - - created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ - ProfileID: req.ProfileID, - Name: req.Name, - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - }, req.SetActive) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, toSoraS3ProfileDTO(*created)) -} - -// UpdateSoraS3Profile 更新 Sora S3 配置 -// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id -func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - - var req UpdateSoraS3ProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if strings.TrimSpace(req.Name) == "" { - response.BadRequest(c, "Name is required") - return - } - - existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - existing := findSoraS3ProfileByID(existingList.Items, profileID) - if existing == nil { - response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound) - return - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { - response.BadRequest(c, err.Error()) - return - } - - updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ - Name: req.Name, - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - }) - if updateErr != nil { - response.ErrorFrom(c, updateErr) - return - } - - response.Success(c, toSoraS3ProfileDTO(*updated)) -} - -// DeleteSoraS3Profile 删除 Sora S3 配置 -// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id -func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"deleted": true}) -} - -// SetActiveSoraS3Profile 切换激活 Sora S3 配置 -// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate -func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { - profileID := strings.TrimSpace(c.Param("profile_id")) - if profileID == "" { - response.BadRequest(c, "Profile ID is required") - return - } - active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3ProfileDTO(*active)) -} - -// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) -// PUT /api/v1/admin/settings/sora-s3 -func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) { - var req UpdateSoraS3SettingsRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - - if req.DefaultStorageQuotaBytes < 0 { - req.DefaultStorageQuotaBytes = 0 - } - if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { - response.BadRequest(c, err.Error()) - return - } - - settings := &service.SoraS3Settings{ - Enabled: req.Enabled, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, - } - if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { - response.ErrorFrom(c, err) - return - } - - updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, toSoraS3SettingsDTO(updatedSettings)) -} - -// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) -// POST /api/v1/admin/settings/sora-s3/test -func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { - if h.soraS3Storage == nil { - response.Error(c, 500, "S3 存储服务未初始化") - return - } - - var req UpdateSoraS3SettingsRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - if !req.Enabled { - response.BadRequest(c, "S3 未启用,无法测试连接") - return - } - - if req.SecretAccessKey == "" { - if req.ProfileID != "" { - profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) - if err == nil { - profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID) - if profile != nil { - req.SecretAccessKey = profile.SecretAccessKey - } - } - } - if req.SecretAccessKey == "" { - existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) - if err == nil { - req.SecretAccessKey = existing.SecretAccessKey - } - } - } - - testCfg := &service.SoraS3Settings{ - Enabled: true, - Endpoint: req.Endpoint, - Region: req.Region, - Bucket: req.Bucket, - AccessKeyID: req.AccessKeyID, - SecretAccessKey: req.SecretAccessKey, - Prefix: req.Prefix, - ForcePathStyle: req.ForcePathStyle, - CDNURL: req.CDNURL, - } - if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { - response.Error(c, 400, "S3 连接测试失败: "+err.Error()) - return - } - response.Success(c, gin.H{"message": "S3 连接成功"}) -} - // GetRectifierSettings 获取请求整流器配置 // GET /api/v1/admin/settings/rectifier func (h *SettingHandler) GetRectifierSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 998308dd99..a357657e20 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -34,14 +34,13 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups []int64 `json:"allowed_groups"` - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 `json:"allowed_groups"` } // UpdateUserRequest represents admin update user request @@ -57,8 +56,7 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` - SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` + GroupRates map[int64]*float64 `json:"group_rates"` } // UpdateBalanceRequest represents balance update request @@ -182,14 +180,13 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, }) if err != nil { response.ErrorFrom(c, err) @@ -216,16 +213,15 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, - SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index d9d657836d..2eab670e75 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, - SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, - SoraStorageUsedBytes: u.SoraStorageUsedBytes, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, } } @@ -172,14 +170,9 @@ func groupFromServiceBase(g *service.Group) Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, - SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, RequirePrivacySet: g.RequirePrivacySet, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 47bab091d7..acc1129cd2 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -61,7 +61,6 @@ type SystemSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` @@ -128,49 +127,10 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version"` } -// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) -type SoraS3Settings struct { - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) -type SoraS3Profile struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// ListSoraS3ProfilesResponse Sora S3 配置列表响应 -type ListSoraS3ProfilesResponse struct { - ActiveProfileID string `json:"active_profile_id"` - Items []SoraS3Profile `json:"items"` -} - // OverloadCooldownSettings 529过载冷却配置 DTO type OverloadCooldownSettings struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 56b67c8c4d..82065deb72 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,9 +26,7 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates map[int64]float64 `json:"group_rates,omitempty"` - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` - SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` } type APIKey struct { @@ -84,21 +82,12 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` - // Sora 按次计费配置 - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` - // Sora 存储配额 - SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` - // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) AllowMessagesDispatch bool `json:"allow_messages_dispatch"` diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index b120098875..a897bc4054 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -31,7 +31,7 @@ const ( // ────────────────────────────────────────────────────────── // NormalizeInboundEndpoint maps a raw request path (which may carry -// prefixes like /antigravity, /openai, /sora) to its canonical form. +// prefixes like /antigravity, /openai) to its canonical form. // // "/antigravity/v1/messages" → "/v1/messages" // "/v1/chat/completions" → "/v1/chat/completions" @@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string { // such as /v1/responses/compact preserved from the raw URL). // - Anthropic → /v1/messages // - Gemini → /v1beta/models -// - Sora → /v1/chat/completions +// - Antigravity → /v1/messages (Claude) or gemini (Gemini) // - Antigravity routes may target either Claude or Gemini, so the // inbound endpoint is used to distinguish. func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { @@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { case service.PlatformGemini: return EndpointGeminiModels - case service.PlatformSora: - return EndpointChatCompletions - case service.PlatformAntigravity: // Antigravity accounts serve both Claude and Gemini. if inbound == EndpointGeminiModels { diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index a3767ac499..1519bc9e62 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) { {"/v1/responses", EndpointResponses}, {"/v1beta/models", EndpointGeminiModels}, - // Prefixed paths (antigravity, openai, sora). + // Prefixed paths (antigravity, openai). {"/antigravity/v1/messages", EndpointMessages}, {"/openai/v1/responses", EndpointResponses}, {"/openai/v1/responses/compact", EndpointResponses}, - {"/sora/v1/chat/completions", EndpointChatCompletions}, {"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels}, // Gin route patterns with wildcards. @@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { // Gemini. {"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels}, - // Sora. - {"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions}, - // OpenAI — always /v1/responses. {"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses}, {"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"}, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index dfc9fb88b4..59619d508a 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -859,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) { platform = forcedPlatform } - if platform == service.PlatformSora { - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": service.DefaultSoraModels(h.cfg), - }) - return - } - // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index ebf8d5f674..d4c349fb65 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -45,8 +45,6 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler - SoraGateway *SoraGatewayHandler - SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2c999cf13b..977c2301bb 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: h.version, }) diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go deleted file mode 100644 index 80acc83349..0000000000 --- a/backend/internal/handler/sora_client_handler.go +++ /dev/null @@ -1,979 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" -) - -const ( - // 上游模型缓存 TTL - modelCacheTTL = 1 * time.Hour // 上游获取成功 - modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) -) - -// SoraClientHandler 处理 Sora 客户端 API 请求。 -type SoraClientHandler struct { - genService *service.SoraGenerationService - quotaService *service.SoraQuotaService - s3Storage *service.SoraS3Storage - soraGatewayService *service.SoraGatewayService - gatewayService *service.GatewayService - mediaStorage *service.SoraMediaStorage - apiKeyService *service.APIKeyService - - // 上游模型缓存 - modelCacheMu sync.RWMutex - cachedFamilies []service.SoraModelFamily - modelCacheTime time.Time - modelCacheUpstream bool // 是否来自上游(决定 TTL) -} - -// NewSoraClientHandler 创建 Sora 客户端 Handler。 -func NewSoraClientHandler( - genService *service.SoraGenerationService, - quotaService *service.SoraQuotaService, - s3Storage *service.SoraS3Storage, - soraGatewayService *service.SoraGatewayService, - gatewayService *service.GatewayService, - mediaStorage *service.SoraMediaStorage, - apiKeyService *service.APIKeyService, -) *SoraClientHandler { - return &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - s3Storage: s3Storage, - soraGatewayService: soraGatewayService, - gatewayService: gatewayService, - mediaStorage: mediaStorage, - apiKeyService: apiKeyService, - } -} - -// GenerateRequest 生成请求。 -type GenerateRequest struct { - Model string `json:"model" binding:"required"` - Prompt string `json:"prompt" binding:"required"` - MediaType string `json:"media_type"` // video / image,默认 video - VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) - ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) - APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID -} - -// Generate 异步生成 — 创建 pending 记录后立即返回。 -// POST /api/v1/sora/generate -func (h *SoraClientHandler) Generate(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - var req GenerateRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) - return - } - - if req.MediaType == "" { - req.MediaType = "video" - } - req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) - - // 并发数检查(最多 3 个) - activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - if activeCount >= 3 { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - - // 配额检查(粗略检查,实际文件大小在上传后才知道) - if h.quotaService != nil { - if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusForbidden, err.Error()) - return - } - } - - // 获取 API Key ID 和 Group ID - var apiKeyID *int64 - var groupID *int64 - - if req.APIKeyID != nil && h.apiKeyService != nil { - // 前端传递了 api_key_id,需要校验 - apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) - if err != nil { - response.Error(c, http.StatusBadRequest, "API Key 不存在") - return - } - if apiKey.UserID != userID { - response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") - return - } - if apiKey.Status != service.StatusAPIKeyActive { - response.Error(c, http.StatusForbidden, "API Key 不可用") - return - } - apiKeyID = &apiKey.ID - groupID = apiKey.GroupID - } else if id, ok := c.Get("api_key_id"); ok { - // 兼容 API Key 认证路径(/sora/v1/ 网关路由) - if v, ok := id.(int64); ok { - apiKeyID = &v - } - } - - gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) - if err != nil { - if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { - response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") - return - } - response.ErrorFrom(c, err) - return - } - - // 启动后台异步生成 goroutine - go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) - - response.Success(c, gin.H{ - "generation_id": gen.ID, - "status": gen.Status, - }) -} - -// processGeneration 后台异步执行 Sora 生成任务。 -// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 -func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - // 标记为生成中 - if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", - genID, - userID, - groupIDForLog(groupID), - model, - mediaType, - videoCount, - strings.TrimSpace(imageInput) != "", - len(strings.TrimSpace(prompt)), - ) - - // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 - if groupID == nil { - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - } - - if h.gatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") - return - } - - // 选择 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", - genID, - userID, - groupIDForLog(groupID), - model, - err, - ) - _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) - return - } - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", - genID, - userID, - groupIDForLog(groupID), - model, - account.ID, - account.Name, - account.Platform, - account.Type, - ) - - // 构建 chat completions 请求体(非流式) - body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) - - if h.soraGatewayService == nil { - _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") - return - } - - // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) - recorder := httptest.NewRecorder() - mockGinCtx, _ := gin.CreateTestContext(recorder) - mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) - - // 调用 Forward(非流式) - result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) - if err != nil { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - err, - ) - // 检查是否已取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - return - } - _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) - return - } - - // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) - mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) - if mediaURL == "" { - logger.LegacyPrintf( - "handler.sora_client", - "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", - genID, - account.ID, - model, - recorder.Code, - trimForLog(recorder.Body.String(), 400), - ) - _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") - return - } - - // 检查任务是否已被取消 - gen, _ := h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) - return - } - - // 三层降级存储:S3 → 本地 → 上游临时 URL - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) - - usageAdded := false - if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") - return - } - _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 - gen, _ = h.genService.GetByID(ctx, genID, userID) - if gen != nil && gen.Status == service.SoraGenStatusCancelled { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - - // 标记完成 - if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { - if errors.Is(err, service.ErrSoraGenerationStateConflict) { - h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) - } - return - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) - return - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) -} - -// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 -func (h *SoraClientHandler) storeMediaWithDegradation( - ctx context.Context, userID int64, mediaType string, - mediaURL string, mediaURLs []string, -) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { - urls := mediaURLs - if len(urls) == 0 { - urls = []string{mediaURL} - } - - // 第一层:尝试 S3 - if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { - keys := make([]string, 0, len(urls)) - var totalSize int64 - allOK := true - for _, u := range urls { - key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) - allOK = false - // 清理已上传的文件 - if len(keys) > 0 { - _ = h.s3Storage.DeleteObjects(ctx, keys) - } - break - } - keys = append(keys, key) - totalSize += size - } - if allOK && len(keys) > 0 { - accessURLs := make([]string, 0, len(keys)) - for _, key := range keys { - accessURL, err := h.s3Storage.GetAccessURL(ctx, key) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) - _ = h.s3Storage.DeleteObjects(ctx, keys) - allOK = false - break - } - accessURLs = append(accessURLs, accessURL) - } - if allOK && len(accessURLs) > 0 { - return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize - } - } - } - - // 第二层:尝试本地存储 - if h.mediaStorage != nil && h.mediaStorage.Enabled() { - storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) - if err == nil && len(storedPaths) > 0 { - firstPath := storedPaths[0] - totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) - if sizeErr != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) - } - return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize - } - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) - } - - // 第三层:保留上游临时 URL - return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 -} - -// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 -func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { - body := map[string]any{ - "model": model, - "messages": []map[string]string{ - {"role": "user", "content": prompt}, - }, - "stream": false, - } - if imageInput != "" { - body["image_input"] = imageInput - } - if videoCount > 1 { - body["video_count"] = videoCount - } - b, _ := json.Marshal(body) - return b -} - -func normalizeVideoCount(mediaType string, videoCount int) int { - if mediaType != "video" { - return 1 - } - if videoCount <= 0 { - return 1 - } - if videoCount > 3 { - return 3 - } - return videoCount -} - -// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 -// OAuth 路径:ForwardResult.MediaURL 已填充。 -// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 -func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { - // 优先从 ForwardResult 获取(OAuth 路径) - if result != nil && result.MediaURL != "" { - // 尝试从响应体获取完整 URL 列表 - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - return result.MediaURL, []string{result.MediaURL} - } - - // 从响应体解析(APIKey 路径) - if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { - return urls[0], urls - } - - return "", nil -} - -// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 -func parseMediaURLsFromBody(body []byte) []string { - if len(body) == 0 { - return nil - } - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return nil - } - - // 优先 media_urls(多图数组) - if rawURLs, ok := resp["media_urls"]; ok { - if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { - urls := make([]string, 0, len(arr)) - for _, item := range arr { - if s, ok := item.(string); ok && s != "" { - urls = append(urls, s) - } - } - if len(urls) > 0 { - return urls - } - } - } - - // 回退到 media_url(单个 URL) - if url, ok := resp["media_url"].(string); ok && url != "" { - return []string{url} - } - - return nil -} - -// ListGenerations 查询生成记录列表。 -// GET /api/v1/sora/generations -func (h *SoraClientHandler) ListGenerations(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - - params := service.SoraGenerationListParams{ - UserID: userID, - Status: c.Query("status"), - StorageType: c.Query("storage_type"), - MediaType: c.Query("media_type"), - Page: page, - PageSize: pageSize, - } - - gens, total, err := h.genService.List(c.Request.Context(), params) - if err != nil { - response.ErrorFrom(c, err) - return - } - - // 为 S3 记录动态生成预签名 URL - for _, gen := range gens { - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - } - - response.Success(c, gin.H{ - "data": gens, - "total": total, - "page": page, - }) -} - -// GetGeneration 查询生成记录详情。 -// GET /api/v1/sora/generations/:id -func (h *SoraClientHandler) GetGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) - response.Success(c, gen) -} - -// DeleteGeneration 删除生成记录。 -// DELETE /api/v1/sora/generations/:id -func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 - if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { - paths := gen.MediaURLs - if len(paths) == 0 && gen.MediaURL != "" { - paths = []string{gen.MediaURL} - } - if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) - } - } - - if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已删除"}) -} - -// GetQuota 查询用户存储配额。 -// GET /api/v1/sora/quota -func (h *SoraClientHandler) GetQuota(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - if h.quotaService == nil { - response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) - return - } - - quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, quota) -} - -// CancelGeneration 取消生成任务。 -// POST /api/v1/sora/generations/:id/cancel -func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - // 权限校验 - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - _ = gen - - if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { - if errors.Is(err, service.ErrSoraGenerationNotActive) { - response.Error(c, http.StatusConflict, "任务已结束,无法取消") - return - } - response.Error(c, http.StatusBadRequest, err.Error()) - return - } - - response.Success(c, gin.H{"message": "已取消"}) -} - -// SaveToStorage 手动保存 upstream 记录到 S3。 -// POST /api/v1/sora/generations/:id/save -func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { - userID := getUserIDFromContext(c) - if userID == 0 { - response.Error(c, http.StatusUnauthorized, "未登录") - return - } - - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.Error(c, http.StatusBadRequest, "无效的 ID") - return - } - - gen, err := h.genService.GetByID(c.Request.Context(), id, userID) - if err != nil { - response.Error(c, http.StatusNotFound, err.Error()) - return - } - - if gen.StorageType != service.SoraStorageTypeUpstream { - response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") - return - } - if gen.MediaURL == "" { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { - response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") - return - } - - sourceURLs := gen.MediaURLs - if len(sourceURLs) == 0 && gen.MediaURL != "" { - sourceURLs = []string{gen.MediaURL} - } - if len(sourceURLs) == 0 { - response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") - return - } - - uploadedKeys := make([]string, 0, len(sourceURLs)) - accessURLs := make([]string, 0, len(sourceURLs)) - var totalSize int64 - - for _, sourceURL := range sourceURLs { - objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) - if uploadErr != nil { - if len(uploadedKeys) > 0 { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - } - var upstreamErr *service.UpstreamDownloadError - if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { - response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") - return - } - response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) - return - } - accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) - if err != nil { - uploadedKeys = append(uploadedKeys, objectKey) - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) - return - } - uploadedKeys = append(uploadedKeys, objectKey) - accessURLs = append(accessURLs, accessURL) - totalSize += fileSize - } - - usageAdded := false - if totalSize > 0 && h.quotaService != nil { - if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - var quotaErr *service.QuotaExceededError - if errors.As(err, "aErr) { - response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") - return - } - response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) - return - } - usageAdded = true - } - - if err := h.genService.UpdateStorageForCompleted( - c.Request.Context(), - id, - accessURLs[0], - accessURLs, - service.SoraStorageTypeS3, - uploadedKeys, - totalSize, - ); err != nil { - _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) - if usageAdded && h.quotaService != nil { - _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) - } - response.ErrorFrom(c, err) - return - } - - response.Success(c, gin.H{ - "message": "已保存到 S3", - "object_key": uploadedKeys[0], - "object_keys": uploadedKeys, - }) -} - -// GetStorageStatus 返回存储状态。 -// GET /api/v1/sora/storage-status -func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { - s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) - s3Healthy := false - if s3Enabled { - s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) - } - localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() - response.Success(c, gin.H{ - "s3_enabled": s3Enabled, - "s3_healthy": s3Healthy, - "local_enabled": localEnabled, - }) -} - -func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { - switch storageType { - case service.SoraStorageTypeS3: - if h.s3Storage != nil && len(s3Keys) > 0 { - if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) - } - } - case service.SoraStorageTypeLocal: - if h.mediaStorage != nil && len(localPaths) > 0 { - if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) - } - } - } -} - -// getUserIDFromContext 从 gin 上下文中提取用户 ID。 -func getUserIDFromContext(c *gin.Context) int64 { - if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { - return subject.UserID - } - - if id, ok := c.Get("user_id"); ok { - switch v := id.(type) { - case int64: - return v - case float64: - return int64(v) - case string: - n, _ := strconv.ParseInt(v, 10, 64) - return n - } - } - // 尝试从 JWT claims 获取 - if id, ok := c.Get("userID"); ok { - if v, ok := id.(int64); ok { - return v - } - } - return 0 -} - -func groupIDForLog(groupID *int64) int64 { - if groupID == nil { - return 0 - } - return *groupID -} - -func trimForLog(raw string, maxLen int) string { - trimmed := strings.TrimSpace(raw) - if maxLen <= 0 || len(trimmed) <= maxLen { - return trimmed - } - return trimmed[:maxLen] + "...(truncated)" -} - -// GetModels 获取可用 Sora 模型家族列表。 -// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 -// GET /api/v1/sora/models -func (h *SoraClientHandler) GetModels(c *gin.Context) { - families := h.getModelFamilies(c.Request.Context()) - response.Success(c, families) -} - -// getModelFamilies 获取模型家族列表(带缓存)。 -func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { - // 读锁检查缓存 - h.modelCacheMu.RLock() - ttl := modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - families := h.cachedFamilies - h.modelCacheMu.RUnlock() - return families - } - h.modelCacheMu.RUnlock() - - // 写锁更新缓存 - h.modelCacheMu.Lock() - defer h.modelCacheMu.Unlock() - - // double-check - ttl = modelCacheTTL - if !h.modelCacheUpstream { - ttl = modelCacheFailedTTL - } - if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { - return h.cachedFamilies - } - - // 尝试从上游获取 - families, err := h.fetchUpstreamModels(ctx) - if err != nil { - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) - families = service.BuildSoraModelFamilies() - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = false - return families - } - - logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) - h.cachedFamilies = families - h.modelCacheTime = time.Now() - h.modelCacheUpstream = true - return families -} - -// fetchUpstreamModels 从上游 Sora API 获取模型列表。 -func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { - if h.gatewayService == nil { - return nil, fmt.Errorf("gatewayService 未初始化") - } - - // 设置 ForcePlatform 用于 Sora 账号选择 - ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) - - // 选择一个 Sora 账号 - account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") - if err != nil { - return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) - } - - // 仅支持 API Key 类型账号 - if account.Type != service.AccountTypeAPIKey { - return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) - } - - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return nil, fmt.Errorf("账号缺少 api_key") - } - - baseURL := account.GetBaseURL() - if baseURL == "" { - return nil, fmt.Errorf("账号缺少 base_url") - } - - // 构建上游模型列表请求 - modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" - - reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) - if err != nil { - return nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+apiKey) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("请求上游失败: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) - } - - // 解析 OpenAI 格式的模型列表 - var modelsResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.Unmarshal(body, &modelsResp); err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) - } - - if len(modelsResp.Data) == 0 { - return nil, fmt.Errorf("上游返回空模型列表") - } - - // 提取模型 ID - modelIDs := make([]string, 0, len(modelsResp.Data)) - for _, m := range modelsResp.Data { - modelIDs = append(modelIDs, m.ID) - } - - // 转换为模型家族 - families := service.BuildSoraModelFamiliesFromIDs(modelIDs) - if len(families) == 0 { - return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") - } - - return families, nil -} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go deleted file mode 100644 index 5705578660..0000000000 --- a/backend/internal/handler/sora_client_handler_test.go +++ /dev/null @@ -1,3179 +0,0 @@ -//go:build unit - -package handler - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -// ==================== Stub: SoraGenerationRepository ==================== - -var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) - -type stubSoraGenRepo struct { - gens map[int64]*service.SoraGeneration - nextID int64 - createErr error - getErr error - updateErr error - deleteErr error - listErr error - countErr error - countValue int64 - - // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 - updateCallCount *int32 - updateFailAfterN int32 - - // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus - getByIDCallCount int32 - getByIDOverrideAfterN int32 // 0 = 不覆盖 - getByIDOverrideStatus string -} - -func newStubSoraGenRepo() *stubSoraGenRepo { - return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} -} - -func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { - if r.createErr != nil { - return r.createErr - } - gen.ID = r.nextID - r.nextID++ - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { - if r.getErr != nil { - return nil, r.getErr - } - gen, ok := r.gens[id] - if !ok { - return nil, fmt.Errorf("not found") - } - // 条件性状态覆盖:模拟外部取消等场景 - if r.getByIDOverrideAfterN > 0 { - n := atomic.AddInt32(&r.getByIDCallCount, 1) - if n > r.getByIDOverrideAfterN { - cp := *gen - cp.Status = r.getByIDOverrideStatus - return &cp, nil - } - } - return gen, nil -} -func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { - // 条件性失败:前 N 次成功,之后失败 - if r.updateCallCount != nil { - n := atomic.AddInt32(r.updateCallCount, 1) - if n > r.updateFailAfterN { - return fmt.Errorf("conditional update error (call #%d)", n) - } - } - if r.updateErr != nil { - return r.updateErr - } - r.gens[gen.ID] = gen - return nil -} -func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { - if r.deleteErr != nil { - return r.deleteErr - } - delete(r.gens, id) - return nil -} -func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - if r.listErr != nil { - return nil, 0, r.listErr - } - var result []*service.SoraGeneration - for _, gen := range r.gens { - if gen.UserID != params.UserID { - continue - } - result = append(result, gen) - } - return result, int64(len(result)), nil -} -func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { - if r.countErr != nil { - return 0, r.countErr - } - return r.countValue, nil -} - -// ==================== 辅助函数 ==================== - -func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { - genService := service.NewSoraGenerationService(repo, nil, nil) - return &SoraClientHandler{genService: genService} -} - -func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - if body != "" { - c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) - c.Request.Header.Set("Content-Type", "application/json") - } else { - c.Request = httptest.NewRequest(method, path, nil) - } - if userID > 0 { - c.Set("user_id", userID) - } - return c, rec -} - -func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { - t.Helper() - var resp map[string]any - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - return resp -} - -// ==================== 纯函数测试: buildAsyncRequestBody ==================== - -func TestBuildAsyncRequestBody(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "sora2-landscape-10s", parsed["model"]) - require.Equal(t, false, parsed["stream"]) - - msgs := parsed["messages"].([]any) - require.Len(t, msgs, 1) - msg := msgs[0].(map[string]any) - require.Equal(t, "user", msg["role"]) - require.Equal(t, "一只猫在跳舞", msg["content"]) -} - -func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "", "", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "gpt-image", parsed["model"]) - msgs := parsed["messages"].([]any) - msg := msgs[0].(map[string]any) - require.Equal(t, "", msg["content"]) -} - -func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { - body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) -} - -func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { - body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) - var parsed map[string]any - require.NoError(t, json.Unmarshal(body, &parsed)) - require.Equal(t, float64(3), parsed["video_count"]) -} - -func TestNormalizeVideoCount(t *testing.T) { - require.Equal(t, 1, normalizeVideoCount("video", 0)) - require.Equal(t, 2, normalizeVideoCount("video", 2)) - require.Equal(t, 3, normalizeVideoCount("video", 5)) - require.Equal(t, 1, normalizeVideoCount("image", 3)) -} - -// ==================== 纯函数测试: parseMediaURLsFromBody ==================== - -func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) - require.Equal(t, []string{"https://a.com/video.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody(nil)) - require.Nil(t, parseMediaURLsFromBody([]byte{})) -} - -func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) -} - -func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) -} - -func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) -} - -func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { - body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` - urls := parseMediaURLsFromBody([]byte(body)) - require.Len(t, urls, 2) - require.Equal(t, "https://multi.com/a.mp4", urls[0]) -} - -func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { - urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) -} - -func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) -} - -func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { - // media_urls 不是 string 数组 - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) -} - -func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { - require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) -} - -// ==================== 纯函数测试: extractMediaURLsFromResult ==================== - -func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://oauth.com/video.mp4", url) - require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { - result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) - url, urls := extractMediaURLsFromResult(result, recorder) - require.Equal(t, "https://body.com/1.mp4", url) - require.Len(t, urls, 2) -} - -func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { - recorder := httptest.NewRecorder() - _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Equal(t, "https://upstream.com/video.mp4", url) - require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) -} - -func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(nil, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { - result := &service.ForwardResult{MediaURL: ""} - recorder := httptest.NewRecorder() - url, urls := extractMediaURLsFromResult(result, recorder) - require.Empty(t, url) - require.Nil(t, urls) -} - -// ==================== getUserIDFromContext ==================== - -func TestGetUserIDFromContext_Int64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", int64(42)) - require.Equal(t, int64(42), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_AuthSubject(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) - require.Equal(t, int64(777), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_Float64(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", float64(99)) - require.Equal(t, int64(99), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_String(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "123") - require.Equal(t, int64(123), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("userID", int64(55)) - require.Equal(t, int64(55), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_NoID(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -func TestGetUserIDFromContext_InvalidString(t *testing.T) { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("user_id", "not-a-number") - require.Equal(t, int64(0), getUserIDFromContext(c)) -} - -// ==================== Handler: Generate ==================== - -func TestGenerate_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) - h.Generate(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGenerate_BadRequest_MissingModel(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGenerate_TooManyRequests(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countValue = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -func TestGenerate_CountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.countErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_Success(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - require.Equal(t, "pending", data["status"]) -} - -func TestGenerate_DefaultMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "video", repo.gens[1].MediaType) -} - -func TestGenerate_ImageMediaType(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "image", repo.gens[1].MediaType) -} - -func TestGenerate_CreatePendingError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.createErr = fmt.Errorf("create failed") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestGenerate_APIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(42)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyInContext(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_ConcurrencyBoundary(t *testing.T) { - // activeCount == 2 应该允许 - repo := newStubSoraGenRepo() - repo.countValue = 2 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Handler: ListGenerations ==================== - -func TestListGenerations_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) - h.ListGenerations(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestListGenerations_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} - repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} - repo.nextID = 3 - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - items := data["data"].([]any) - require.Len(t, items, 2) - require.Equal(t, float64(2), data["total"]) -} - -func TestListGenerations_ListError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.listErr = fmt.Errorf("db error") - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -func TestListGenerations_DefaultPagination(t *testing.T) { - repo := newStubSoraGenRepo() - h := newTestSoraClientHandler(repo) - // 不传分页参数,应默认 page=1 page_size=20 - c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) - h.ListGenerations(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["page"]) -} - -// ==================== Handler: GetGeneration ==================== - -func TestGetGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.GetGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestGetGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestGetGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.GetGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, float64(1), data["id"]) -} - -// ==================== Handler: DeleteGeneration ==================== - -func TestDeleteGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestDeleteGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestDeleteGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestDeleteGeneration_Success(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -// ==================== Handler: CancelGeneration ==================== - -func TestCancelGeneration_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestCancelGeneration_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestCancelGeneration_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestCancelGeneration_Pending(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Generating(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -func TestCancelGeneration_Completed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Failed(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -func TestCancelGeneration_Cancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} - -// ==================== Handler: GetQuota ==================== - -func TestGetQuota_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) - h.GetQuota(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestGetQuota_NilQuotaService(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, "unlimited", data["source"]) -} - -// ==================== Handler: GetModels ==================== - -func TestGetModels(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) - h.GetModels(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].([]any) - require.Len(t, data, 4) - // 验证类型分布 - videoCount, imageCount := 0, 0 - for _, item := range data { - m := item.(map[string]any) - if m["type"] == "video" { - videoCount++ - } else if m["type"] == "image" { - imageCount++ - } - } - require.Equal(t, 3, videoCount) - require.Equal(t, 1, imageCount) -} - -// ==================== Handler: GetStorageStatus ==================== - -func TestGetStorageStatus_NilS3(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, false, data["local_enabled"]) -} - -func TestGetStorageStatus_LocalEnabled(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, false, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) - require.Equal(t, true, data["local_enabled"]) -} - -// ==================== Handler: SaveToStorage ==================== - -func TestSaveToStorage_Unauthorized(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusUnauthorized, rec.Code) -} - -func TestSaveToStorage_InvalidID(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "abc"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_NotFound(t *testing.T) { - h := newTestSoraClientHandler(newStubSoraGenRepo()) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "999"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -func TestSaveToStorage_NotUpstream(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_EmptyMediaURL(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -func TestSaveToStorage_S3Nil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusServiceUnavailable, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "云存储") -} - -func TestSaveToStorage_WrongUser(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} - h := newTestSoraClientHandler(repo) - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== storeMediaWithDegradation — nil guard 路径 ==================== - -func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) - require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { - h := &SoraClientHandler{} - url, urls, storageType, keys, size := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://a.com/1.mp4", url) - require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) - require.Nil(t, keys) - require.Equal(t, int64(0), size) -} - -func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { - h := &SoraClientHandler{} - url, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Equal(t, "https://upstream.com/v.mp4", url) -} - -// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== - -var _ service.UserRepository = (*stubUserRepoForHandler)(nil) - -type stubUserRepoForHandler struct { - users map[int64]*service.User - updateErr error -} - -func newStubUserRepoForHandler() *stubUserRepoForHandler { - return &stubUserRepoForHandler{users: make(map[int64]*service.User)} -} - -func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { - if u, ok := r.users[id]; ok { - return u, nil - } - return nil, fmt.Errorf("user not found") -} -func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { - if r.updateErr != nil { - return r.updateErr - } - r.users[user.ID] = user - return nil -} -func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } -func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { - return nil, nil -} -func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } -func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } -func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - return nil -} -func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { - return nil -} - -// ==================== NewSoraClientHandler ==================== - -func TestNewSoraClientHandler(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) -} - -func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { - h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) - require.NotNil(t, h) - require.Nil(t, h.apiKeyService) -} - -// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== - -var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) - -type stubAPIKeyRepoForHandler struct { - keys map[int64]*service.APIKey - getErr error -} - -func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { - return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} -} - -func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { - if r.getErr != nil { - return nil, r.getErr - } - if k, ok := r.keys[id]; ok { - return k, nil - } - return nil, fmt.Errorf("api key not found: %d", id) -} -func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { - return "", 0, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } -func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { - return false, nil -} -func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) { - var updated int64 - for id, key := range r.keys { - if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID { - continue - } - clone := *key - gid := newGroupID - clone.GroupID = &gid - r.keys[id] = &clone - updated++ - } - return updated, nil -} -func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { - return nil, nil -} -func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { - return 0, nil -} -func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error { - return nil -} -func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) { - return nil, nil -} - -// newTestAPIKeyService 创建测试用的 APIKeyService -func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { - return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) -} - -// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== - -func TestGenerate_WithAPIKeyID_Success(t *testing.T) { - // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(5) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.NotZero(t, data["generation_id"]) - - // 验证 api_key_id 已关联到生成记录 - gen := repo.gens[1] - require.NotNil(t, gen.APIKeyID) - require.Equal(t, int64(42), *gen.APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { - // 前端传递不存在的 api_key_id → 400 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不存在") -} - -func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { - // 前端传递别人的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 999, // 属于 user 999 - Status: service.StatusAPIKeyActive, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不属于") -} - -func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { - // 前端传递已禁用的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyDisabled, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "不可用") -} - -func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { - // 前端传递配额耗尽的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyQuotaExhausted, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { - // 前端传递已过期的 api_key_id → 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyExpired, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { - // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - h := &SoraClientHandler{genService: genService} // apiKeyService = nil - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { - // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: nil, // 无分组 - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { - // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - require.Nil(t, repo.gens[1].APIKeyID) -} - -func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { - // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - groupID := int64(10) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyRepo.keys[42] = &service.APIKey{ - ID: 42, - UserID: 1, - Status: service.StatusAPIKeyActive, - GroupID: &groupID, - } - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) - c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(42), *repo.gens[1].APIKeyID) -} - -func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { - // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - c.Set("api_key_id", int64(99)) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) - // 应使用 context 中的 api_key_id=99 - require.NotNil(t, repo.gens[1].APIKeyID) - require.Equal(t, int64(99), *repo.gens[1].APIKeyID) -} - -func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { - // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - apiKeyRepo := newStubAPIKeyRepoForHandler() - apiKeyService := newTestAPIKeyService(apiKeyRepo) - - h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} - // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 - // api_key_id=0 不存在 → 400 - c, rec := makeGinContext("POST", "/api/v1/sora/generate", - `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) - h.Generate(c) - require.Equal(t, http.StatusBadRequest, rec.Code) -} - -// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== - -func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { - // groupID 不为 nil → 不设置 ForcePlatform - // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - gid := int64(5) - h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { - // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled - require.Equal(t, "cancelled", repo.gens[1].Status) -} - -// ==================== GenerateRequest JSON 解析 ==================== - -func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { - // 验证 api_key_id 在 JSON 中正确解析为 *int64 - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) - require.NoError(t, err) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(42), *req.APIKeyID) -} - -func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { - // 不传 api_key_id → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { - // api_key_id: null → 解析后为 nil - var req GenerateRequest - err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) - require.NoError(t, err) - require.Nil(t, req.APIKeyID) -} - -func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { - // 全字段解析 - var req GenerateRequest - err := json.Unmarshal([]byte(`{ - "model":"sora2-landscape-10s", - "prompt":"test prompt", - "media_type":"video", - "video_count":2, - "image_input":"data:image/png;base64,abc", - "api_key_id":100 - }`), &req) - require.NoError(t, err) - require.Equal(t, "sora2-landscape-10s", req.Model) - require.Equal(t, "test prompt", req.Prompt) - require.Equal(t, "video", req.MediaType) - require.Equal(t, 2, req.VideoCount) - require.Equal(t, "data:image/png;base64,abc", req.ImageInput) - require.NotNil(t, req.APIKeyID) - require.Equal(t, int64(100), *req.APIKeyID) -} - -func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { - // api_key_id 为 nil 时 JSON 序列化应省略 - req := GenerateRequest{Model: "sora2", Prompt: "test"} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - _, hasAPIKeyID := parsed["api_key_id"] - require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") -} - -func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { - // api_key_id 不为 nil 时 JSON 序列化应包含 - id := int64(42) - req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} - b, err := json.Marshal(req) - require.NoError(t, err) - var parsed map[string]any - require.NoError(t, json.Unmarshal(b, &parsed)) - require.Equal(t, float64(42), parsed["api_key_id"]) -} - -// ==================== GetQuota: 有配额服务 ==================== - -func TestGetQuota_WithQuotaService_Success(t *testing.T) { - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10 * 1024 * 1024, - SoraStorageUsedBytes: 3 * 1024 * 1024, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) - h.GetQuota(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, "user", data["source"]) - require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) - require.Equal(t, float64(3*1024*1024), data["used_bytes"]) -} - -func TestGetQuota_WithQuotaService_Error(t *testing.T) { - // 用户不存在时 GetQuota 返回错误 - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) - h.GetQuota(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== Generate: 配额检查 ==================== - -func TestGenerate_QuotaCheckFailed(t *testing.T) { - // 配额超限时返回 429 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 1024, - SoraStorageUsedBytes: 1025, // 已超限 - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -func TestGenerate_QuotaCheckPassed(t *testing.T) { - // 配额充足时允许生成 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{ - genService: genService, - quotaService: quotaService, - } - - c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) - h.Generate(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== - -var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) - -type stubSettingRepoForHandler struct { - values map[string]string -} - -func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { - if values == nil { - values = make(map[string]string) - } - return &stubSettingRepoForHandler{values: values} -} - -func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { - if v, ok := r.values[key]; ok { - return &service.Setting{Key: key, Value: v}, nil - } - return nil, service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { - if v, ok := r.values[key]; ok { - return v, nil - } - return "", service.ErrSettingNotFound -} -func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { - r.values[key] = value - return nil -} -func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { - result := make(map[string]string) - for _, k := range keys { - if v, ok := r.values[k]; ok { - result[k] = v - } - } - return result, nil -} -func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { - for k, v := range settings { - r.values[k] = v - } - return nil -} -func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { - return r.values, nil -} -func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { - delete(r.values, key) - return nil -} - -// ==================== S3 / MediaStorage 辅助函数 ==================== - -// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 -func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { - settingRepo := newStubSettingRepoForHandler(map[string]string{ - "sora_s3_enabled": "true", - "sora_s3_endpoint": endpoint, - "sora_s3_region": "us-east-1", - "sora_s3_bucket": "test-bucket", - "sora_s3_access_key_id": "AKIATEST", - "sora_s3_secret_access_key": "test-secret", - "sora_s3_prefix": "sora", - "sora_s3_force_path_style": "true", - }) - settingService := service.NewSettingService(settingRepo, &config.Config{}) - return service.NewSoraS3Storage(settingService) -} - -// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 -func newFakeSourceServer() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "video/mp4") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("fake video data for test")) - })) -} - -// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 -// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 -func newFakeS3Server(mode string) *httptest.Server { - var counter atomic.Int32 - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.Copy(io.Discard, r.Body) - _ = r.Body.Close() - - switch mode { - case "ok": - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - case "fail": - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - case "fail-second": - n := counter.Add(1) - if n <= 1 { - w.Header().Set("ETag", `"test-etag"`) - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`AccessDenied`)) - } - } - })) -} - -// ==================== processGeneration 直接调用测试 ==================== - -func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - repo.updateErr = fmt.Errorf("db error") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" - // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed - // 因此 ErrorMessage 为空(证明未调用 MarkFailed) - require.Equal(t, "generating", repo.gens[1].Status) - require.Empty(t, repo.gens[1].ErrorMessage) -} - -func TestProcessGeneration_GatewayServiceNil(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - // gatewayService 未设置 → MarkFailed - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") -} - -// ==================== storeMediaWithDegradation: S3 路径 ==================== - -func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 1) - require.NotEmpty(t, s3Keys[0]) - require.Len(t, storedURLs, 1) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURL, fakeS3.URL) - require.Contains(t, storedURL, "/test-bucket/") - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - require.Equal(t, service.SoraStorageTypeS3, storageType) - require.Len(t, s3Keys, 2) - require.Len(t, storedURLs, 2) - require.Equal(t, storedURL, storedURLs[0]) - require.Contains(t, storedURLs[0], fakeS3.URL) - require.Contains(t, storedURLs[1], fakeS3.URL) - require.Greater(t, fileSize, int64(0)) -} - -func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { - // 上游返回 404 → 下载失败 → S3 上传不会开始 - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) - defer badSource.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, - ) - // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) - require.Nil(t, s3Keys) -} - -// ==================== storeMediaWithDegradation: 本地存储路径 ==================== - -func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { - // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: "/dev/null/invalid_dir", - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, - ) - // 本地存储失败,降级到 upstream - require.Equal(t, service.SoraStorageTypeUpstream, storageType) -} - -func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - require.Equal(t, service.SoraStorageTypeLocal, storageType) - require.Nil(t, s3Keys) // 本地存储不返回 S3 keys -} - -func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - DownloadTimeoutSeconds: 5, - MaxDownloadBytes: 10 * 1024 * 1024, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{ - s3Storage: s3Storage, - mediaStorage: mediaStorage, - } - - _, _, storageType, _, _ := h.storeMediaWithDegradation( - context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, - ) - // S3 失败 → 本地存储成功 - require.Equal(t, service.SoraStorageTypeLocal, storageType) -} - -// ==================== SaveToStorage: S3 路径 ==================== - -func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "S3") -} - -func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { - expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer expiredServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: expiredServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusGone, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, fmt.Sprint(resp["message"]), "过期") -} - -func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Contains(t, data["message"], "S3") - require.NotEmpty(t, data["object_key"]) - // 验证记录已更新为 S3 存储 - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) -} - -func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{ - sourceServer.URL + "/v1.mp4", - sourceServer.URL + "/v2.mp4", - }, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Len(t, data["object_keys"].([]any), 2) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.Len(t, repo.gens[1].S3ObjectKeys, 2) - require.Len(t, repo.gens[1].MediaURLs, 2) -} - -func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 100 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusOK, rec.Code) - // 验证配额已累加 - require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) -} - -func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== GetStorageStatus: S3 路径 ==================== - -func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { - // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, false, data["s3_healthy"]) -} - -func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) - h.GetStorageStatus(c) - require.Equal(t, http.StatusOK, rec.Code) - resp := parseResponse(t, rec) - data := resp["data"].(map[string]any) - require.Equal(t, true, data["s3_enabled"]) - require.Equal(t, true, data["s3_healthy"]) -} - -// ==================== Stub: AccountRepository (用于 GatewayService) ==================== - -var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) - -type stubAccountRepoForHandler struct { - accounts []service.Account -} - -func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { - for i := range r.accounts { - if r.accounts[i].ID == id { - return &r.accounts[i], nil - } - } - return nil, fmt.Errorf("account not found") -} -func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { - return false, nil -} -func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } -func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) { - return nil, nil, nil -} -func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { - return nil, nil -} -func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } -func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { - return nil -} -func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { - return 0, nil -} -func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } -func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { - return r.accounts, nil -} -func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error { - return nil -} -func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { - return nil -} -func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } -func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { - return nil -} -func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { - return nil -} -func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { - return 0, nil -} - -func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error { - return nil -} - -func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error { - return nil -} - -// ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== - -var _ service.SoraClient = (*stubSoraClientForHandler)(nil) - -type stubSoraClientForHandler struct { - videoStatus *service.SoraVideoTaskStatus -} - -func (s *stubSoraClientForHandler) Enabled() bool { return true } -func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { - return "task-video", nil -} -func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { - return nil -} -func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { - return "", nil -} -func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { - return nil, nil -} -func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { - return s.videoStatus, nil -} - -// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== - -// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 -func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { - return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, - ) -} - -// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 -func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - return service.NewSoraGatewayService(soraClient, nil, nil, cfg) -} - -// ==================== processGeneration: 更多路径测试 ==================== - -func TestProcessGeneration_SelectAccountError(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - // 提供可用账号使 SelectAccountForModel 成功 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // soraGatewayService 为 nil - h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") -} - -func TestProcessGeneration_ForwardError(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回视频任务失败 - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "content policy violation", - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") -} - -func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration - // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - // SoraClient 返回 completed 但无 URL - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: nil, // 无 URL - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") -} - -func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) - // 第 2 次返回 "cancelled" 状态,模拟外部取消 - repo.getByIDOverrideAfterN = 1 - repo.getByIDOverrideStatus = "cancelled" - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) - require.Equal(t, "generating", repo.gens[1].Status) -} - -func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - // 无 S3 和本地存储,降级到 upstream - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].MediaURL) -} - -func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{sourceServer.URL + "/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - s3Storage := newS3StorageForHandler(fakeS3.URL) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - s3Storage: s3Storage, - quotaService: quotaService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - require.Equal(t, "completed", repo.gens[1].Status) - require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) - require.NotEmpty(t, repo.gens[1].S3ObjectKeys) - require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) - // 验证配额已累加 - require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) -} - -func TestProcessGeneration_MarkCompletedFails(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 - repo.updateCallCount = new(int32) - repo.updateFailAfterN = 1 - genService := service.NewSoraGenerationService(repo, nil, nil) - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - soraClient := &stubSoraClientForHandler{ - videoStatus: &service.SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/video.mp4"}, - }, - } - soraGatewayService := newMinimalSoraGatewayService(soraClient) - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) - // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 - // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 - // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 - require.Equal(t, "completed", repo.gens[1].Status) -} - -// ==================== cleanupStoredMedia 直接测试 ==================== - -func TestCleanupStoredMedia_S3Path(t *testing.T) { - // S3 清理路径:s3Storage 为 nil 时不 panic - h := &SoraClientHandler{} - // 不应 panic - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalPath(t *testing.T) { - // 本地清理路径:mediaStorage 为 nil 时不 panic - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) -} - -func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { - // upstream 类型不清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) -} - -func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { - // 空 keys 不触发清理 - h := &SoraClientHandler{} - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) -} - -// ==================== DeleteGeneration: 本地存储清理路径 ==================== - -func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: []string{"video/test.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) - _, exists := repo.gens[1] - require.False(t, exists) -} - -func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { - // MediaURLs 为空,使用 MediaURL 作为清理路径 - tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "video/test.mp4", - MediaURLs: nil, // 空 - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { - // 非本地存储类型 → 跳过清理 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, - UserID: 1, - Status: "completed", - StorageType: service.SoraStorageTypeUpstream, - MediaURL: "https://upstream.com/v.mp4", - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -func TestDeleteGeneration_DeleteError(t *testing.T) { - // repo.Delete 出错 - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} - repo.deleteErr = fmt.Errorf("delete failed") - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusNotFound, rec.Code) -} - -// ==================== fetchUpstreamModels 测试 ==================== - -func TestFetchUpstreamModels_NilGateway(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - h := &SoraClientHandler{} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "gatewayService 未初始化") -} - -func TestFetchUpstreamModels_NoAccounts(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "选择 Sora 账号失败") -} - -func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "不支持模型同步") -} - -func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"base_url": "https://sora.test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "api_key") -} - -func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" - // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test"}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) -} - -func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "状态码 500") -} - -func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("not json")) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "解析响应失败") -} - -func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "空模型列表") -} - -func TestFetchUpstreamModels_Success(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 验证请求头 - require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) - require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - families, err := h.fetchUpstreamModels(context.Background()) - require.NoError(t, err) - require.NotEmpty(t, families) -} - -func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - _, err := h.fetchUpstreamModels(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "未能从上游模型列表中识别") -} - -// ==================== getModelFamilies 缓存测试 ==================== - -func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { - // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 - h := &SoraClientHandler{} - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - - // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) - require.False(t, h.modelCacheUpstream) -} - -func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { - t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) - })) - defer ts.Close() - - accountRepo := &stubAccountRepoForHandler{ - accounts: []service.Account{ - {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, - Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, - }, - } - gatewayService := newMinimalGatewayService(accountRepo) - h := &SoraClientHandler{gatewayService: gatewayService} - - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - require.True(t, h.modelCacheUpstream) - - // 第二次调用命中缓存 - families2 := h.getModelFamilies(context.Background()) - require.Equal(t, families, families2) -} - -func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { - // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) - h := &SoraClientHandler{ - cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, - modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 - modelCacheUpstream: false, - } - // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 - families := h.getModelFamilies(context.Background()) - require.NotEmpty(t, families) - // 缓存已刷新,不再是 "old" - found := false - for _, f := range families { - if f.ID == "old" { - found = true - } - } - require.False(t, found, "过期缓存应被刷新") -} - -// ==================== processGeneration: groupID 与 ForcePlatform ==================== - -func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { - // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 空账号列表 → SelectAccountForModel 失败 - accountRepo := &stubAccountRepoForHandler{accounts: nil} - gatewayService := newMinimalGatewayService(accountRepo) - - h := &SoraClientHandler{ - genService: genService, - gatewayService: gatewayService, - } - - h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) - require.Equal(t, "failed", repo.gens[1].Status) - require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") -} - -// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== - -func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { - // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 - repo := newStubSoraGenRepo() - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - - h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusForbidden, rec.Code) -} - -// ==================== Generate: CreatePending 并发限制错误 ==================== - -// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 -type stubSoraGenRepoWithAtomicCreate struct { - stubSoraGenRepo - limitErr error -} - -func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { - if r.limitErr != nil { - return r.limitErr - } - return r.stubSoraGenRepo.Create(context.Background(), gen) -} - -func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { - repo := &stubSoraGenRepoWithAtomicCreate{ - stubSoraGenRepo: *newStubSoraGenRepo(), - limitErr: service.ErrSoraGenerationConcurrencyLimit, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) - - body := `{"model":"sora2-landscape-10s","prompt":"test"}` - c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) - h.Generate(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "3") -} - -// ==================== SaveToStorage: 配额超限 ==================== - -func TestSaveToStorage_QuotaExceeded(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户配额已满 - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 10, - SoraStorageUsedBytes: 10, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusTooManyRequests, rec.Code) -} - -// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== - -func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error - userRepo := newStubUserRepoForHandler() - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== SaveToStorage: MediaURLs 全为空 ==================== - -func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: "", - MediaURLs: []string{}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusBadRequest, rec.Code) - resp := parseResponse(t, rec) - require.Contains(t, resp["message"], "已过期") -} - -// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== - -func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("fail-second") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v1.mp4", - MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, - } - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== - -func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { - sourceServer := newFakeSourceServer() - defer sourceServer.Close() - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: "upstream", - MediaURL: sourceServer.URL + "/v.mp4", - } - repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) - genService := service.NewSoraGenerationService(repo, nil, nil) - - userRepo := newStubUserRepoForHandler() - userRepo.users[1] = &service.User{ - ID: 1, - SoraStorageQuotaBytes: 100 * 1024 * 1024, - SoraStorageUsedBytes: 0, - } - quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.SaveToStorage(c) - require.Equal(t, http.StatusInternalServerError, rec.Code) -} - -// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== - -func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { - fakeS3 := newFakeS3Server("ok") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) -} - -func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { - fakeS3 := newFakeS3Server("fail") - defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) -} - -func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - h := &SoraClientHandler{mediaStorage: mediaStorage} - - h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) -} - -// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== - -func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "sora-del-test-*") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Sora: config.SoraConfig{ - Storage: config.SoraStorageConfig{ - Type: "local", - LocalPath: tmpDir, - }, - }, - } - mediaStorage := service.NewSoraMediaStorage(cfg) - - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ - ID: 1, UserID: 1, Status: "completed", - StorageType: service.SoraStorageTypeLocal, - MediaURL: "nonexistent/video.mp4", - MediaURLs: []string{"nonexistent/video.mp4"}, - } - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} - - c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.DeleteGeneration(c) - require.Equal(t, http.StatusOK, rec.Code) -} - -// ==================== CancelGeneration: 任务已结束冲突 ==================== - -func TestCancelGeneration_AlreadyCompleted(t *testing.T) { - repo := newStubSoraGenRepo() - repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} - genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService} - - c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) - c.Params = gin.Params{{Key: "id", Value: "1"}} - h.CancelGeneration(c) - require.Equal(t, http.StatusConflict, rec.Code) -} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go deleted file mode 100644 index d1e7e00fe5..0000000000 --- a/backend/internal/handler/sora_gateway_handler.go +++ /dev/null @@ -1,697 +0,0 @@ -package handler - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" - "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" - - "github.com/gin-gonic/gin" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.uber.org/zap" -) - -// SoraGatewayHandler handles Sora chat completions requests -// -// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。 -type SoraGatewayHandler struct { - gatewayService *service.GatewayService - soraGatewayService *service.SoraGatewayService - billingCacheService *service.BillingCacheService - usageRecordWorkerPool *service.UsageRecordWorkerPool - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int - streamMode string - soraTLSEnabled bool - soraMediaSigningKey string - soraMediaRoot string -} - -// NewSoraGatewayHandler creates a new SoraGatewayHandler -func NewSoraGatewayHandler( - gatewayService *service.GatewayService, - soraGatewayService *service.SoraGatewayService, - concurrencyService *service.ConcurrencyService, - billingCacheService *service.BillingCacheService, - usageRecordWorkerPool *service.UsageRecordWorkerPool, - cfg *config.Config, -) *SoraGatewayHandler { - pingInterval := time.Duration(0) - maxAccountSwitches := 3 - streamMode := "force" - soraTLSEnabled := true - signKey := "" - mediaRoot := "/app/data/sora" - if cfg != nil { - pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second - if cfg.Gateway.MaxAccountSwitches > 0 { - maxAccountSwitches = cfg.Gateway.MaxAccountSwitches - } - if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { - streamMode = mode - } - soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint - signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) - if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { - mediaRoot = root - } - } - return &SoraGatewayHandler{ - gatewayService: gatewayService, - soraGatewayService: soraGatewayService, - billingCacheService: billingCacheService, - usageRecordWorkerPool: usageRecordWorkerPool, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, - streamMode: strings.ToLower(streamMode), - soraTLSEnabled: soraTLSEnabled, - soraMediaSigningKey: signKey, - soraMediaRoot: mediaRoot, - } -} - -// ChatCompletions handles Sora /v1/chat/completions endpoint -func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { - apiKey, ok := middleware2.GetAPIKeyFromContext(c) - if !ok { - h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") - return - } - - subject, ok := middleware2.GetAuthSubjectFromContext(c) - if !ok { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") - return - } - reqLog := requestLogger( - c, - "handler.sora_gateway.chat_completions", - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - ) - - body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) - if err != nil { - if maxErr, ok := extractMaxBytesError(err); ok { - h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) - return - } - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") - return - } - if len(body) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return - } - - setOpsRequestContext(c, "", false, body) - - // 校验请求体 JSON 合法性 - if !gjson.ValidBytes(body) { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") - return - } - - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - modelResult := gjson.GetBytes(body, "model") - if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") - return - } - reqModel := modelResult.String() - - msgsResult := gjson.GetBytes(body, "messages") - if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") - return - } - - clientStream := gjson.GetBytes(body, "stream").Bool() - reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) - if !clientStream { - if h.streamMode == "error" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") - return - } - var err error - body, err = sjson.SetBytes(body, "stream", true) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } - - setOpsRequestContext(c, reqModel, clientStream, body) - setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false))) - - platform := "" - if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { - platform = forced - } else if apiKey.Group != nil { - platform = apiKey.Group.Platform - } - if platform != service.PlatformSora { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") - return - } - - streamStarted := false - subscription, _ := middleware2.GetSubscriptionFromContext(c) - - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) - } else if !canWait { - reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() - - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) - if err != nil { - reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) - return - } - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) - if userReleaseFunc != nil { - defer userReleaseFunc() - } - - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) - h.handleStreamingAwareError(c, status, code, message, streamStarted) - return - } - - sessionHash := generateOpenAISessionHash(c, body) - - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 - var lastFailoverBody []byte - var lastFailoverHeaders http.Header - - for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0)) - if err != nil { - reqLog.Warn("sora.account_select_failed", - zap.Error(err), - zap.Int("excluded_account_count", len(failedAccountIDs)), - ) - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int("last_upstream_status", lastFailoverStatus), - } - if rayID != "" { - fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("last_upstream_content_type", contentType)) - } - reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID, account.Platform) - proxyBound := account.ProxyID != nil - proxyID := int64(0) - if account.ProxyID != nil { - proxyID = *account.ProxyID - } - tlsFingerprintEnabled := h.soraTLSEnabled - - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - reqLog.Warn("sora.account_wait_counter_increment_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - } else if !canWait { - reqLog.Info("sora.account_wait_queue_full", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), - ) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - clientStream, - &streamStarted, - ) - if err != nil { - reqLog.Warn("sora.account_slot_acquire_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_exhausted", fields...) - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) - return - } - lastFailoverStatus = failoverErr.StatusCode - lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) - lastFailoverBody = failoverErr.ResponseBody - switchCount++ - upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.String("upstream_error_code", upstreamErrCode), - zap.String("upstream_error_message", upstreamErrMsg), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - } - if rayID != "" { - fields = append(fields, zap.String("upstream_cf_ray", rayID)) - } - if mitigated != "" { - fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) - } - if contentType != "" { - fields = append(fields, zap.String("upstream_content_type", contentType)) - } - reqLog.Warn("sora.upstream_failover_switching", fields...) - continue - } - reqLog.Error("sora.forward_failed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Error(err), - ) - return - } - - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - requestPayloadHash := service.HashUsageRequestPayload(body) - inboundEndpoint := GetInboundEndpoint(c) - upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - - // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitUsageRecordTask(func(ctx context.Context) { - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: inboundEndpoint, - UpstreamEndpoint: upstreamEndpoint, - UserAgent: userAgent, - IPAddress: clientIP, - RequestPayloadHash: requestPayloadHash, - }); err != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Int64("user_id", subject.UserID), - zap.Int64("api_key_id", apiKey.ID), - zap.Any("group_id", apiKey.GroupID), - zap.String("model", reqModel), - zap.Int64("account_id", account.ID), - ).Error("sora.record_usage_failed", zap.Error(err)) - } - }) - reqLog.Debug("sora.request_completed", - zap.Int64("account_id", account.ID), - zap.Int64("proxy_id", proxyID), - zap.Bool("proxy_bound", proxyBound), - zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), - zap.Int("switch_count", switchCount), - ) - return - } -} - -func generateOpenAISessionHash(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } - if sessionID == "" { - return "" - } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) -} - -func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Any("panic", recovered), - ).Error("sora.usage_record_task_panic_recovered") - } - }() - task(ctx) -} - -func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", - fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) -} - -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { - upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) - service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") - - status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) - h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) -} - -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { - if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { - baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) - return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - - upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) - if strings.EqualFold(upstreamCode, "cf_shield_429") { - baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." - return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) - } - if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { - switch statusCode { - case 401, 403, 404, 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", upstreamMessage - case 429: - return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage - } - } - - switch statusCode { - case 401: - return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" - case 403: - return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" - case 404: - if strings.EqualFold(upstreamCode, "unsupported_country_code") { - return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" - } - return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" - case 429: - return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" - case 529: - return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" - case 500, 502, 503, 504: - return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" - default: - return http.StatusBadGateway, "upstream_error", "Upstream request failed" - } -} - -func cloneHTTPHeaders(headers http.Header) http.Header { - if headers == nil { - return nil - } - return headers.Clone() -} - -func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { - if headers != nil { - mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) - contentType = strings.TrimSpace(headers.Get("content-type")) - if contentType == "" { - contentType = strings.TrimSpace(headers.Get("Content-Type")) - } - } - rayID = soraerror.ExtractCloudflareRayID(headers, body) - return rayID, mitigated, contentType -} - -func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) -} - -func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { - message = strings.TrimSpace(message) - if message == "" { - return false - } - if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { - lower := strings.ToLower(message) - if strings.Contains(lower, "Just a moment...`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "upstream_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare challenge") - require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") -} - -func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { - gin.SetMode(gin.TestMode) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest(http.MethodGet, "/", nil) - - headers := http.Header{} - headers.Set("cf-ray", "9d03b68c086027a1-SEA") - body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) - - h := &SoraGatewayHandler{} - h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) - - lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") - require.Len(t, lines, 2) - jsonStr := strings.TrimPrefix(lines[1], "data: ") - - var parsed map[string]any - require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) - - errorObj, ok := parsed["error"].(map[string]any) - require.True(t, ok) - require.Equal(t, "rate_limit_error", errorObj["type"]) - msg, _ := errorObj["message"].(string) - require.Contains(t, msg, "Cloudflare shield") - require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") -} - -func TestExtractSoraFailoverHeaderInsights(t *testing.T) { - headers := http.Header{} - headers.Set("cf-mitigated", "challenge") - headers.Set("content-type", "text/html") - body := []byte(``) - - rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) - require.Equal(t, "9cff2d62d83bb98d", rayID) - require.Equal(t, "challenge", mitigated) - require.Equal(t, "text/html", contentType) -} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index c7c48e14bc..5c9458158a 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") } - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { - pool := newUsageRecordTestPool(t) - h := &SoraGatewayHandler{usageRecordWorkerPool: pool} - - done := make(chan struct{}) - h.submitUsageRecordTask(func(ctx context.Context) { - close(done) - }) - - select { - case <-done: - case <-time.After(time.Second): - t.Fatal("task not executed") - } -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { - h := &SoraGatewayHandler{} - var called atomic.Bool - - h.submitUsageRecordTask(func(ctx context.Context) { - if _, ok := ctx.Deadline(); !ok { - t.Fatal("expected deadline in fallback context") - } - called.Store(true) - }) - - require.True(t, called.Load()) -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { - h := &SoraGatewayHandler{} - require.NotPanics(t, func() { - h.submitUsageRecordTask(nil) - }) -} - -func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { - h := &SoraGatewayHandler{} - var called atomic.Bool - - require.NotPanics(t, func() { - h.submitUsageRecordTask(func(ctx context.Context) { - panic("usage task panic") - }) - }) - - h.submitUsageRecordTask(func(ctx context.Context) { - called.Store(true) - }) - require.True(t, called.Load(), "panic 后后续任务应仍可执行") -} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index c917f24a0d..d9622594f2 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -86,8 +86,6 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, - soraGatewayHandler *SoraGatewayHandler, - soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, _ *service.IdempotencyCoordinator, @@ -104,8 +102,6 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, - SoraGateway: soraGatewayHandler, - SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, - NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index 6b8521bdcd..618b6adb60 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,8 +17,6 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - // OAuth Client ID for Sora mobile flow (aligned with sora2api) - SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" @@ -39,8 +37,6 @@ const ( const ( // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. OAuthPlatformOpenAI = "openai" - // OAuthPlatformSora uses Sora OAuth client. - OAuthPlatformSora = "sora" ) // OAuthSession stores OAuth flow state for OpenAI @@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor } // OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. -// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), -// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { - switch strings.ToLower(strings.TrimSpace(platform)) { - case OAuthPlatformSora: - return ClientID, false - default: - return ClientID, true - } + return ClientID, true } // TokenRequest represents the token exchange request body diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go index 2970addff5..56b42fc9fb 100644 --- a/backend/internal/pkg/openai/oauth_test.go +++ b/backend/internal/pkg/openai/oauth_test.go @@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) } } - -// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, -// 但不启用 codex_cli_simplified_flow。 -func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { - authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) - parsed, err := url.Parse(authURL) - if err != nil { - t.Fatalf("Parse URL failed: %v", err) - } - q := parsed.Query() - if got := q.Get("client_id"); got != ClientID { - t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) - } - if got := q.Get("codex_cli_simplified_flow"); got != "" { - t.Fatalf("codex flow should be empty for sora, got=%q", got) - } - if got := q.Get("id_token_add_organizations"); got != "true" { - t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) - } -} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d45e8a1297..94bfb09d58 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1692,20 +1692,13 @@ func itoa(v int) string { } // FindByExtraField 根据 extra 字段中的键值对查找账号。 -// 该方法限定 platform='sora',避免误查询其他平台的账号。 // 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 // -// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 -// // FindByExtraField finds accounts by key-value pairs in the extra field. -// Limited to platform='sora' to avoid querying accounts from other platforms. // Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). -// -// Use case: Finding Sora accounts linked via linked_openai_account_id. func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { accounts, err := r.client.Account.Query(). Where( - dbaccount.PlatformEQ("sora"), // 限定平台为 sora dbaccount.DeletedAtIsNil(), func(s *entsql.Selector) { path := sqljson.Path(key) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ade0d46486..b3b12e8113 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, - group.FieldSoraImagePrice360, - group.FieldSoraImagePrice540, - group.FieldSoraVideoPricePerRequest, - group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, @@ -608,22 +604,20 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, - SoraStorageUsedBytes: u.SoraStorageUsedBytes, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, - SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 3cfd649bce..a075b586c4 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). - SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). - SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). - SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). @@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). - SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). - SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). - SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetMcpXMLInject(groupIn.MCPXMLInject). - SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 44fa291bed..c1901d71aa 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) } -// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 -func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { - var seenClientIDs []string - s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - clientID := r.PostForm.Get("client_id") - seenClientIDs = append(seenClientIDs, clientID) - if clientID == openai.SoraClientID { - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) - return - } - w.WriteHeader(http.StatusBadRequest) - })) - - resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) - require.NoError(s.T(), err, "RefreshTokenWithClientID") - require.Equal(s.T(), "at-sora", resp.AccessToken) - require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) -} - func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { const customClientID = "custom-client-id" var seenClientIDs []string @@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { } func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { - wantClientID := openai.SoraClientID + wantClientID := "custom-exchange-client-id" errCh := make(chan string, 1) s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _ = r.ParseForm() diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go deleted file mode 100644 index ad2ae638fe..0000000000 --- a/backend/internal/repository/sora_account_repo.go +++ /dev/null @@ -1,98 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "errors" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraAccountRepository 实现 service.SoraAccountRepository 接口。 -// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 -// -// 设计说明: -// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 -// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 -// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 -type soraAccountRepository struct { - sql *sql.DB -} - -// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 -func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { - return &soraAccountRepository{sql: sqlDB} -} - -// Upsert 创建或更新 Sora 账号扩展信息 -// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert -func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { - accessToken, accessOK := updates["access_token"].(string) - refreshToken, refreshOK := updates["refresh_token"].(string) - sessionToken, sessionOK := updates["session_token"].(string) - - if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { - if !sessionOK { - return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") - } - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_accounts - SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, - updated_at = NOW() - WHERE account_id = $1 - `, accountID, sessionToken) - if err != nil { - return err - } - rows, err := result.RowsAffected() - if err != nil { - return err - } - if rows == 0 { - return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") - } - return nil - } - - _, err := r.sql.ExecContext(ctx, ` - INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) - VALUES ($1, $2, $3, $4, NOW(), NOW()) - ON CONFLICT (account_id) DO UPDATE SET - access_token = EXCLUDED.access_token, - refresh_token = EXCLUDED.refresh_token, - session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, - updated_at = NOW() - `, accountID, accessToken, refreshToken, sessionToken) - return err -} - -// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 -func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { - rows, err := r.sql.QueryContext(ctx, ` - SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') - FROM sora_accounts - WHERE account_id = $1 - `, accountID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return nil, nil // 记录不存在 - } - - var sa service.SoraAccount - if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { - return nil, err - } - return &sa, nil -} - -// Delete 删除 Sora 账号扩展信息 -func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { - _, err := r.sql.ExecContext(ctx, ` - DELETE FROM sora_accounts WHERE account_id = $1 - `, accountID) - return err -} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go deleted file mode 100644 index aaf3cb2f54..0000000000 --- a/backend/internal/repository/sora_generation_repo.go +++ /dev/null @@ -1,419 +0,0 @@ -package repository - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 -// 使用原生 SQL 操作 sora_generations 表。 -type soraGenerationRepository struct { - sql *sql.DB -} - -// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 -func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { - return &soraGenerationRepository{sql: sqlDB} -} - -func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - err := r.sql.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt) - return err -} - -// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 -func (r *soraGenerationRepository) CreatePendingWithLimit( - ctx context.Context, - gen *service.SoraGeneration, - activeStatuses []string, - maxActive int64, -) error { - if gen == nil { - return fmt.Errorf("generation is nil") - } - if maxActive <= 0 { - return r.Create(ctx, gen) - } - if len(activeStatuses) == 0 { - activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} - } - - tx, err := r.sql.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - - // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 - if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { - return err - } - - placeholders := make([]string, len(activeStatuses)) - args := make([]any, 0, 1+len(activeStatuses)) - args = append(args, gen.UserID) - for i, s := range activeStatuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - countQuery := fmt.Sprintf( - `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, - strings.Join(placeholders, ","), - ) - var activeCount int64 - if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { - return err - } - if activeCount >= maxActive { - return service.ErrSoraGenerationConcurrencyLimit - } - - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - if err := tx.QueryRowContext(ctx, ` - INSERT INTO sora_generations ( - user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - RETURNING id, created_at - `, - gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, - gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, - ).Scan(&gen.ID, &gen.CreatedAt); err != nil { - return err - } - - return tx.Commit() -} - -func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - err := r.sql.QueryRowContext(ctx, ` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations WHERE id = $1 - `, id).Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("生成记录不存在") - } - return nil, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - return gen, nil -} - -func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { - mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) - s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) - - var completedAt *time.Time - if gen.CompletedAt != nil { - completedAt = gen.CompletedAt - } - - _, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations SET - status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5, - storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, - error_message = $9, completed_at = $10 - WHERE id = $1 - `, - gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, - gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, - gen.ErrorMessage, completedAt, - ) - return err -} - -// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 -func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, upstream_task_id = $3 - WHERE id = $1 AND status = $4 - `, - id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 -func (r *soraGenerationRepository) UpdateCompletedIfActive( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, - completedAt time.Time, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - media_url = $3, - media_urls = $4, - file_size_bytes = $5, - storage_type = $6, - s3_object_keys = $7, - error_message = '', - completed_at = $8 - WHERE id = $1 AND status IN ($9, $10) - `, - id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, - storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 -func (r *soraGenerationRepository) UpdateFailedIfActive( - ctx context.Context, - id int64, - errMsg string, - completedAt time.Time, -) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, - error_message = $3, - completed_at = $4 - WHERE id = $1 AND status IN ($5, $6) - `, - id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 -func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET status = $2, completed_at = $3 - WHERE id = $1 AND status IN ($4, $5) - `, - id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 -func (r *soraGenerationRepository) UpdateStorageIfCompleted( - ctx context.Context, - id int64, - mediaURL string, - mediaURLs []string, - storageType string, - s3Keys []string, - fileSizeBytes int64, -) (bool, error) { - mediaURLsJSON, _ := json.Marshal(mediaURLs) - s3KeysJSON, _ := json.Marshal(s3Keys) - result, err := r.sql.ExecContext(ctx, ` - UPDATE sora_generations - SET media_url = $2, - media_urls = $3, - file_size_bytes = $4, - storage_type = $5, - s3_object_keys = $6 - WHERE id = $1 AND status = $7 - `, - id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, - ) - if err != nil { - return false, err - } - affected, err := result.RowsAffected() - if err != nil { - return false, err - } - return affected > 0, nil -} - -func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { - _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) - return err -} - -func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { - // 构建 WHERE 条件 - conditions := []string{"user_id = $1"} - args := []any{params.UserID} - argIdx := 2 - - if params.Status != "" { - // 支持逗号分隔的多状态 - statuses := strings.Split(params.Status, ",") - placeholders := make([]string, len(statuses)) - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) - } - if params.StorageType != "" { - storageTypes := strings.Split(params.StorageType, ",") - placeholders := make([]string, len(storageTypes)) - for i, s := range storageTypes { - placeholders[i] = fmt.Sprintf("$%d", argIdx) - args = append(args, strings.TrimSpace(s)) - argIdx++ - } - conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) - } - if params.MediaType != "" { - conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) - args = append(args, params.MediaType) - argIdx++ - } - - whereClause := "WHERE " + strings.Join(conditions, " AND ") - - // 计数 - var total int64 - countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) - if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { - return nil, 0, err - } - - // 分页查询 - offset := (params.Page - 1) * params.PageSize - listQuery := fmt.Sprintf(` - SELECT id, user_id, api_key_id, model, prompt, media_type, - status, media_url, media_urls, file_size_bytes, - storage_type, s3_object_keys, upstream_task_id, error_message, - created_at, completed_at - FROM sora_generations %s - ORDER BY created_at DESC - LIMIT $%d OFFSET $%d - `, whereClause, argIdx, argIdx+1) - args = append(args, params.PageSize, offset) - - rows, err := r.sql.QueryContext(ctx, listQuery, args...) - if err != nil { - return nil, 0, err - } - defer func() { - _ = rows.Close() - }() - - var results []*service.SoraGeneration - for rows.Next() { - gen := &service.SoraGeneration{} - var mediaURLsJSON, s3KeysJSON []byte - var completedAt sql.NullTime - var apiKeyID sql.NullInt64 - - if err := rows.Scan( - &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, - &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, - &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, - &gen.CreatedAt, &completedAt, - ); err != nil { - return nil, 0, err - } - - if apiKeyID.Valid { - gen.APIKeyID = &apiKeyID.Int64 - } - if completedAt.Valid { - gen.CompletedAt = &completedAt.Time - } - _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) - _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) - results = append(results, gen) - } - - return results, total, rows.Err() -} - -func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { - if len(statuses) == 0 { - return 0, nil - } - - placeholders := make([]string, len(statuses)) - args := []any{userID} - for i, s := range statuses { - placeholders[i] = fmt.Sprintf("$%d", i+2) - args = append(args, s) - } - - var count int64 - query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) - err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) - return count, err -} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 66d0b4ec19..d7bcd0944d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{ "text", // ip_address "integer", // image_count "text", // image_size - "text", // media_type "text", // service_tier "text", // reasoning_effort "text", // inbound_endpoint @@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*47) + args := make([]any, 0, len(keys)*46) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*46) + args := make([]any, 0, len(preparedList)*45) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ip_address, image_count, image_size, - media_type, service_tier, reasoning_effort, inbound_endpoint, @@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) - mediaType := nullString(log.MediaType) serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) @@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ipAddress, log.ImageCount, imageSize, - mediaType, serviceTier, reasoningEffort, inboundEndpoint, @@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString - mediaType sql.NullString serviceTier sql.NullString reasoningEffort sql.NullString inboundEndpoint sql.NullString @@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, - &mediaType, &serviceTier, &reasoningEffort, &inboundEndpoint, @@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } - if mediaType.Valid { - log.MediaType = &mediaType.String - } if serviceTier.Valid { log.ServiceTier = &serviceTier.String } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 77f695e378..ce0c5f0040 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // ip_address log.ImageCount, sqlmock.AnyArg(), // image_size - sqlmock.AnyArg(), // media_type sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort sqlmock.AnyArg(), // inbound_endpoint @@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), log.ImageCount, sqlmock.AnyArg(), - sqlmock.AnyArg(), serviceTier, sqlmock.AnyArg(), sqlmock.AnyArg(), @@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, @@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "flex"}, sql.NullString{}, sql.NullString{}, @@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, - sql.NullString{}, sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 575754e03b..06c79113e2 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). - SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) @@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } -// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。 -func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) { - if deltaBytes <= 0 { - user, err := r.GetByID(ctx, userID) - if err != nil { - return 0, err - } - return user.SoraStorageUsedBytes, nil - } - var newUsed int64 - err := scanSingleRow(ctx, r.sql, ` - UPDATE users - SET sora_storage_used_bytes = sora_storage_used_bytes + $2 - WHERE id = $1 - AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3) - RETURNING sora_storage_used_bytes - `, []any{userID, deltaBytes, effectiveQuota}, &newUsed) - if err == nil { - return newUsed, nil - } - if errors.Is(err, sql.ErrNoRows) { - // 区分用户不存在和配额冲突 - exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx) - if existsErr != nil { - return 0, existsErr - } - if !exists { - return 0, service.ErrUserNotFound - } - return 0, service.ErrSoraStorageQuotaExceeded - } - return 0, err -} - -// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。 -func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) { - if deltaBytes <= 0 { - user, err := r.GetByID(ctx, userID) - if err != nil { - return 0, err - } - return user.SoraStorageUsedBytes, nil - } - var newUsed int64 - err := scanSingleRow(ctx, r.sql, ` - UPDATE users - SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0) - WHERE id = $1 - RETURNING sora_storage_used_bytes - `, []any{userID, deltaBytes}, &newUsed) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return 0, service.ErrUserNotFound - } - return 0, err - } - return newUsed, nil -} - func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 4548c02882..657e3ed66c 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, - NewSoraAccountRepository, // Sora 账号扩展表仓储 NewScheduledTestPlanRepository, // 定时测试计划仓储 NewScheduledTestResultRepository, // 定时测试结果仓储 NewProxyRepository, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 450c312265..d412ea34d6 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "sora_image_price_360": null, - "sora_image_price_540": null, - "sora_storage_quota_bytes": 0, - "sora_video_price_per_request": null, - "sora_video_price_per_request_hd": null, "claude_code_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, @@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) { "fallback_model_openai": "gpt-4o", "enable_identity_patch": true, "identity_patch_prompt": "", - "sora_client_enabled": false, "invitation_code_enabled": false, "home_content": "", "hide_ccs_import_button": false, @@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index d9ec951e77..73210bfcbb 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool { return strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/responses") } diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 997015317f..d60a142c87 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -109,7 +109,6 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) - routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 76f4c4b4d6..b921da9511 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,8 +34,6 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) - // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) - registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } -func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - sora := admin.Group("/sora") - { - sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) - sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) - sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) - sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) - sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) - sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) - sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) - } -} - func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { @@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Beta 策略配置 adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) - // Sora S3 存储配置 - adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) - adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) - adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) - adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) - adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) - adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) - adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) - adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) } } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 072cfdee37..cbf9829326 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -23,11 +23,6 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) - soraMaxBodySize := cfg.Gateway.SoraMaxBodySize - if soraMaxBodySize <= 0 { - soraMaxBodySize = cfg.Gateway.MaxBodySize - } - soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) endpointNorm := handler.InboundEndpointMiddleware() @@ -163,28 +158,6 @@ func RegisterGatewayRoutes( antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } - // Sora 专用路由(强制使用 sora 平台) - soraV1 := r.Group("/sora/v1") - soraV1.Use(soraBodyLimit) - soraV1.Use(clientRequestID) - soraV1.Use(opsErrorLogger) - soraV1.Use(endpointNorm) - soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) - soraV1.Use(gin.HandlerFunc(apiKeyAuth)) - soraV1.Use(requireGroupAnthropic) - { - soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) - soraV1.GET("/models", h.Gateway.Models) - } - - // Sora 媒体代理(可选 API Key 验证) - if cfg.Gateway.SoraMediaRequireAPIKey { - r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) - } else { - r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) - } - // Sora 媒体代理(签名 URL,无需 API Key) - r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } // getGroupPlatform extracts the group platform from the API Key stored in context. diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go index 00edd31b84..4d65a626f9 100644 --- a/backend/internal/server/routes/gateway_test.go +++ b/backend/internal/server/routes/gateway_test.go @@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine { &handler.Handlers{ Gateway: &handler.GatewayHandler{}, OpenAIGateway: &handler.OpenAIGatewayHandler{}, - SoraGateway: &handler.SoraGatewayHandler{}, }, servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { c.Next() diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go deleted file mode 100644 index 13fceb812c..0000000000 --- a/backend/internal/server/routes/sora_client.go +++ /dev/null @@ -1,36 +0,0 @@ -package routes - -import ( - "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 -func RegisterSoraClientRoutes( - v1 *gin.RouterGroup, - h *handler.Handlers, - jwtAuth middleware.JWTAuthMiddleware, - settingService *service.SettingService, -) { - if h.SoraClient == nil { - return - } - - authenticated := v1.Group("/sora") - authenticated.Use(gin.HandlerFunc(jwtAuth)) - authenticated.Use(middleware.BackendModeUserGuard(settingService)) - { - authenticated.POST("/generate", h.SoraClient.Generate) - authenticated.GET("/generations", h.SoraClient.ListGenerations) - authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) - authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) - authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) - authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) - authenticated.GET("/quota", h.SoraClient.GetQuota) - authenticated.GET("/models", h.SoraClient.GetModels) - authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) - } -} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 328790a87f..3189a7290f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -28,8 +28,7 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) - // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') - // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + // FindByExtraField 根据 extra 字段中的键值对查找账号 FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8218c2db0e..55865945c6 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -13,18 +13,14 @@ import ( "log" "net/http" "net/http/httptest" - "net/url" "regexp" "strings" - "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" - soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 - soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" - soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" - soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" - soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -71,13 +62,8 @@ type AccountTestService struct { httpUpstream HTTPUpstream cfg *config.Config tlsFPProfileService *TLSFingerprintProfileService - soraTestGuardMu sync.Mutex - soraTestLastRun map[int64]time.Time - soraTestCooldown time.Duration } -const defaultSoraTestCooldown = 10 * time.Second - // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -94,8 +80,6 @@ func NewAccountTestService( httpUpstream: httpUpstream, cfg: cfg, tlsFPProfileService: tlsFPProfileService, - soraTestLastRun: make(map[int64]time.Time), - soraTestCooldown: defaultSoraTestCooldown, } } @@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.routeAntigravityTest(c, account, modelID, prompt) } - if account.Platform == PlatformSora { - return s.testSoraAccountConnection(c, account) - } - return s.testClaudeAccountConnection(c, account, modelID) } @@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } -type soraProbeStep struct { - Name string `json:"name"` - Status string `json:"status"` - HTTPStatus int `json:"http_status,omitempty"` - ErrorCode string `json:"error_code,omitempty"` - Message string `json:"message,omitempty"` -} - -type soraProbeSummary struct { - Status string `json:"status"` - Steps []soraProbeStep `json:"steps"` -} - -type soraProbeRecorder struct { - steps []soraProbeStep -} - -func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { - r.steps = append(r.steps, soraProbeStep{ - Name: name, - Status: status, - HTTPStatus: httpStatus, - ErrorCode: strings.TrimSpace(errorCode), - Message: strings.TrimSpace(message), - }) -} - -func (r *soraProbeRecorder) finalize() soraProbeSummary { - meSuccess := false - partial := false - for _, step := range r.steps { - if step.Name == "me" { - meSuccess = strings.EqualFold(step.Status, "success") - continue - } - if strings.EqualFold(step.Status, "failed") { - partial = true - } - } - - status := "success" - if !meSuccess { - status = "failed" - } else if partial { - status = "partial_success" - } - - return soraProbeSummary{ - Status: status, - Steps: append([]soraProbeStep(nil), r.steps...), - } -} - -func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { - if rec == nil { - return - } - summary := rec.finalize() - code := "" - for _, step := range summary.Steps { - if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { - code = step.ErrorCode - break - } - } - s.sendEvent(c, TestEvent{ - Type: "sora_test_result", - Status: summary.Status, - Code: code, - Data: summary, - }) -} - -func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { - if accountID <= 0 { - return 0, true - } - s.soraTestGuardMu.Lock() - defer s.soraTestGuardMu.Unlock() - - if s.soraTestLastRun == nil { - s.soraTestLastRun = make(map[int64]time.Time) - } - cooldown := s.soraTestCooldown - if cooldown <= 0 { - cooldown = defaultSoraTestCooldown - } - - now := time.Now() - if lastRun, ok := s.soraTestLastRun[accountID]; ok { - elapsed := now.Sub(lastRun) - if elapsed < cooldown { - return cooldown - elapsed, false - } - } - s.soraTestLastRun[accountID] = now - return 0, true -} - -func ceilSeconds(d time.Duration) int { - if d <= 0 { - return 1 - } - sec := int(d / time.Second) - if d%time.Second != 0 { - sec++ - } - if sec < 1 { - sec = 1 - } - return sec -} - -// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 -// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 -func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { - ctx := c.Request.Context() - - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") - } - - baseURL := account.GetBaseURL() - if baseURL == "" { - return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") - } - - // 验证 base_url 格式 - normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) - } - upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" - - // 设置 SSE 头 - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { - msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) - return s.sendErrorAndEnd(c, msg) - } - - s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"}) - - // 构建轻量级 prompt-enhance 请求作为连通性测试 - testPayload := map[string]any{ - "model": "prompt-enhance-short-10s", - "messages": []map[string]string{{"role": "user", "content": "test"}}, - "stream": false, - } - payloadBytes, _ := json.Marshal(testPayload) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) - if err != nil { - return s.sendErrorAndEnd(c, "构建测试请求失败") - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - - // 获取代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) - - if resp.StatusCode == http.StatusOK { - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) - } - - // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 - if resp.StatusCode == http.StatusBadRequest { - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256))) -} - -// testSoraAccountConnection 测试 Sora 账号的连接 -// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性 -// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性 -func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { - // apikey 类型走独立测试流程 - if account.Type == AccountTypeAPIKey { - return s.testSoraAPIKeyAccountConnection(c, account) - } - - ctx := c.Request.Context() - recorder := &soraProbeRecorder{} - - authToken := account.GetCredential("access_token") - if authToken == "" { - recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "No access token available") - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { - msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) - recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, msg) - } - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) - - req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) - if err != nil { - recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Failed to create request") - } - - // 使用 Sora 客户端标准请求头 - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - soraTLSProfile := s.resolveSoraTLSProfile() - - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile) - if err != nil { - recorder.addStep("me", "failed", 0, "network_error", err.Error()) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { - recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") - s.emitSoraProbeSummary(c, recorder) - s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) - return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) - switch { - case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): - recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") - case strings.EqualFold(upstreamCode, "unsupported_country_code"): - recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") - case strings.TrimSpace(upstreamMessage) != "": - recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) - default: - recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") - s.emitSoraProbeSummary(c, recorder) - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) - } - } - recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") - - // 解析 /me 响应,提取用户信息 - var meResp map[string]any - if err := json.Unmarshal(body, &meResp); err != nil { - // 能收到 200 就说明 token 有效 - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"}) - } else { - // 尝试提取用户名或邮箱信息 - info := "Sora connection OK" - if name, ok := meResp["name"].(string); ok && name != "" { - info = fmt.Sprintf("Sora connection OK - User: %s", name) - } else if email, ok := meResp["email"].(string); ok && email != "" { - info = fmt.Sprintf("Sora connection OK - Email: %s", email) - } - s.sendEvent(c, TestEvent{Type: "content", Text: info}) - } - - // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) - subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) - if err == nil { - subReq.Header.Set("Authorization", "Bearer "+authToken) - subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - subReq.Header.Set("Accept", "application/json") - subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") - subReq.Header.Set("Origin", "https://sora.chatgpt.com") - subReq.Header.Set("Referer", "https://sora.chatgpt.com/") - - subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile) - if subErr != nil { - recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) - } else { - subBody, _ := io.ReadAll(subResp.Body) - _ = subResp.Body.Close() - if subResp.StatusCode == http.StatusOK { - recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") - if summary := parseSoraSubscriptionSummary(subBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) - } - } else { - if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { - recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") - s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) - } else { - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) - recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage) - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) - } - } - } - } - - // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 - s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder) - - s.emitSoraProbeSummary(c, recorder) - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil -} - -func (s *AccountTestService) testSora2Capabilities( - c *gin.Context, - ctx context.Context, - account *Account, - authToken string, - proxyURL string, - tlsProfile *tlsfingerprint.Profile, - recorder *soraProbeRecorder, -) { - inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraInviteMineURL, - proxyURL, - tlsProfile, - ) - if err != nil { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) - return - } - - if inviteStatus == http.StatusUnauthorized { - bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraBootstrapURL, - proxyURL, - tlsProfile, - ) - if bootstrapErr == nil && bootstrapStatus == http.StatusOK { - if recorder != nil { - recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") - } - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) - inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraInviteMineURL, - proxyURL, - tlsProfile, - ) - if err != nil { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) - return - } - } else if recorder != nil { - code := "" - msg := "" - if bootstrapErr != nil { - code = "network_error" - msg = bootstrapErr.Error() - } - recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) - } - } - - if inviteStatus != http.StatusOK { - if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { - if recorder != nil { - recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") - } - s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) - return - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) - if recorder != nil { - recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) - return - } - if recorder != nil { - recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") - } - - if summary := parseSoraInviteSummary(inviteBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) - } - - remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( - ctx, - account, - authToken, - soraRemainingURL, - proxyURL, - tlsProfile, - ) - if remainingErr != nil { - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) - return - } - if remainingStatus != http.StatusOK { - if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") - } - s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) - return - } - upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) - if recorder != nil { - recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) - } - s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) - return - } - if recorder != nil { - recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") - } - if summary := parseSoraRemainingSummary(remainingBody); summary != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: summary}) - } else { - s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) - } -} - -func (s *AccountTestService) fetchSoraTestEndpoint( - ctx context.Context, - account *Account, - authToken string, - url string, - proxyURL string, - tlsProfile *tlsfingerprint.Profile, -) (int, http.Header, []byte, error) { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return 0, nil, nil, err - } - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - req.Header.Set("Accept", "application/json") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile) - if err != nil { - return 0, nil, nil, err - } - defer func() { _ = resp.Body.Close() }() - - body, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return resp.StatusCode, resp.Header, nil, readErr - } - return resp.StatusCode, resp.Header, body, nil -} - -func parseSoraSubscriptionSummary(body []byte) string { - var subResp struct { - Data []struct { - Plan struct { - ID string `json:"id"` - Title string `json:"title"` - } `json:"plan"` - EndTS string `json:"end_ts"` - } `json:"data"` - } - if err := json.Unmarshal(body, &subResp); err != nil { - return "" - } - if len(subResp.Data) == 0 { - return "" - } - - first := subResp.Data[0] - parts := make([]string, 0, 3) - if first.Plan.Title != "" { - parts = append(parts, first.Plan.Title) - } - if first.Plan.ID != "" { - parts = append(parts, first.Plan.ID) - } - if first.EndTS != "" { - parts = append(parts, "end="+first.EndTS) - } - if len(parts) == 0 { - return "" - } - return "Subscription: " + strings.Join(parts, " | ") -} - -func parseSoraInviteSummary(body []byte) string { - var inviteResp struct { - InviteCode string `json:"invite_code"` - RedeemedCount int64 `json:"redeemed_count"` - TotalCount int64 `json:"total_count"` - } - if err := json.Unmarshal(body, &inviteResp); err != nil { - return "" - } - - parts := []string{"Sora2: supported"} - if inviteResp.InviteCode != "" { - parts = append(parts, "invite="+inviteResp.InviteCode) - } - if inviteResp.TotalCount > 0 { - parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) - } - return strings.Join(parts, " | ") -} - -func parseSoraRemainingSummary(body []byte) string { - var remainingResp struct { - RateLimitAndCreditBalance struct { - EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` - RateLimitReached bool `json:"rate_limit_reached"` - AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` - } `json:"rate_limit_and_credit_balance"` - } - if err := json.Unmarshal(body, &remainingResp); err != nil { - return "" - } - info := remainingResp.RateLimitAndCreditBalance - parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} - if info.RateLimitReached { - parts = append(parts, "rate_limited=true") - } - if info.AccessResetsInSeconds > 0 { - parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) - } - return strings.Join(parts, " | ") -} - -func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile { - if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint { - // Sora TLS fingerprint enabled — use built-in default profile - return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"} - } - return nil // disabled -} - -func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) -} - -func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { - return soraerror.FormatCloudflareChallengeMessage(base, headers, body) -} - -func extractCloudflareRayID(headers http.Header, body []byte) string { - return soraerror.ExtractCloudflareRayID(headers, body) -} - -func extractSoraEgressIPHint(headers http.Header) string { - if headers == nil { - return "unknown" - } - candidates := []string{ - "x-openai-public-ip", - "x-envoy-external-address", - "cf-connecting-ip", - "x-forwarded-for", - } - for _, key := range candidates { - if value := strings.TrimSpace(headers.Get(key)); value != "" { - return value - } - } - return "unknown" -} - -func sanitizeProxyURLForLog(raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "" - } - u, err := url.Parse(raw) - if err != nil { - return "" - } - if u.User != nil { - u.User = nil - } - return u.String() -} - -func endpointPathForLog(endpoint string) string { - parsed, err := url.Parse(strings.TrimSpace(endpoint)) - if err != nil || parsed.Path == "" { - return endpoint - } - return parsed.Path -} - -func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { - accountID := int64(0) - platform := "" - proxyID := "none" - if account != nil { - accountID = account.ID - platform = account.Platform - if account.ProxyID != nil { - proxyID = fmt.Sprintf("%d", *account.ProxyID) - } - } - cfRay := extractCloudflareRayID(headers, body) - if cfRay == "" { - cfRay = "unknown" - } - log.Printf( - "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", - accountID, - platform, - endpoint, - endpointPathForLog(endpoint), - proxyID, - sanitizeProxyURLForLog(proxyURL), - cfRay, - extractSoraEgressIPHint(headers), - ) -} - -func truncateSoraErrorBody(body []byte, max int) string { - return soraerror.TruncateBody(body, max) -} - // routeAntigravityTest 路由 Antigravity 账号的测试请求。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go index 5ba04c69b7..f38264a215 100644 --- a/backend/internal/service/account_test_service_gemini_test.go +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) - ctx, recorder := newSoraTestContext() + ctx, recorder := newTestContext() svc := &AccountTestService{} stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index efa6f7da78..5125db5ba5 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -4,16 +4,61 @@ package service import ( "context" + "fmt" "io" "net/http" + "net/http/httptest" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" ) +// --- shared test helpers --- + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, profile != nil) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +// --- test functions --- + +func newTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + type openAIAccountTestRepo struct { mockAccountRepoForGemini updatedExtra map[string]any @@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) { gin.SetMode(gin.TestMode) - ctx, recorder := newSoraTestContext() + ctx, recorder := newTestContext() resp := newJSONResponse(http.StatusOK, "") resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"} @@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { gin.SetMode(gin.TestMode) - ctx, _ := newSoraTestContext() + ctx, _ := newTestContext() resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`) resp.Header.Set("x-codex-primary-used-percent", "100") diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go deleted file mode 100644 index 52f832a9c7..0000000000 --- a/backend/internal/service/account_test_service_sora_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package service - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -type queuedHTTPUpstream struct { - responses []*http.Response - requests []*http.Request - tlsFlags []bool -} - -func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { - return nil, fmt.Errorf("unexpected Do call") -} - -func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) { - u.requests = append(u.requests, req) - u.tlsFlags = append(u.tlsFlags, profile != nil) - if len(u.responses) == 0 { - return nil, fmt.Errorf("no mocked response") - } - resp := u.responses[0] - u.responses = u.responses[1:] - return resp, nil -} - -func newJSONResponse(status int, body string) *http.Response { - return &http.Response{ - StatusCode: status, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader(body)), - } -} - -func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { - resp := newJSONResponse(status, body) - resp.Header.Set(key, value) - return resp -} - -func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { - gin.SetMode(gin.TestMode) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) - return c, rec -} - -func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), - newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), - newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), - newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), - }, - } - svc := &AccountTestService{ - httpUpstream: upstream, - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - TLSFingerprint: config.TLSFingerprintConfig{ - Enabled: true, - }, - }, - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - DisableTLSFingerprint: false, - }, - }, - }, - } - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - require.Len(t, upstream.requests, 4) - require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) - require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) - require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) - require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) - require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) - require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) - require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) - - body := rec.Body.String() - require.Contains(t, body, `"type":"test_start"`) - require.Contains(t, body, "Sora connection OK - Email: demo@example.com") - require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") - require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") - require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"success"`) - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), - newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), - newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), - newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - require.Len(t, upstream.requests, 4) - body := rec.Body.String() - require.Contains(t, body, "Sora connection OK - User: demo-user") - require.Contains(t, body, "Subscription check returned 403") - require.Contains(t, body, "Sora2 invite check returned 401") - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"partial_success"`) - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "Cloudflare challenge") - require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") - body := rec.Body.String() - require.Contains(t, body, `"type":"error"`) - require.Contains(t, body, "Cloudflare challenge") - require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") -} - -func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "Cloudflare challenge") - require.Contains(t, err.Error(), "HTTP 429") - body := rec.Body.String() - require.Contains(t, body, "Cloudflare challenge") -} - -func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.Error(t, err) - require.Contains(t, err.Error(), "token_invalidated") - body := rec.Body.String() - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"status":"failed"`) - require.Contains(t, body, "token_invalidated") - require.NotContains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), - }, - } - svc := &AccountTestService{ - httpUpstream: upstream, - soraTestCooldown: time.Hour, - } - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c1, _ := newSoraTestContext() - err := svc.testSoraAccountConnection(c1, account) - require.NoError(t, err) - - c2, rec2 := newSoraTestContext() - err = svc.testSoraAccountConnection(c2, account) - require.Error(t, err) - require.Contains(t, err.Error(), "测试过于频繁") - body := rec2.Body.String() - require.Contains(t, body, `"type":"sora_test_result"`) - require.Contains(t, body, `"code":"test_rate_limited"`) - require.Contains(t, body, `"status":"failed"`) - require.NotContains(t, body, `"type":"test_complete","success":true`) -} - -func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { - upstream := &queuedHTTPUpstream{ - responses: []*http.Response{ - newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), - newJSONResponse(http.StatusForbidden, `Just a moment...`), - newJSONResponse(http.StatusForbidden, `Just a moment...`), - }, - } - svc := &AccountTestService{httpUpstream: upstream} - account := &Account{ - ID: 1, - Platform: PlatformSora, - Type: AccountTypeOAuth, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "test_token", - }, - } - - c, rec := newSoraTestContext() - err := svc.testSoraAccountConnection(c, account) - - require.NoError(t, err) - body := rec.Body.String() - require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") - require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") - require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") - require.Contains(t, body, `"type":"test_complete","success":true`) -} - -func TestSanitizeProxyURLForLog(t *testing.T) { - require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) - require.Equal(t, "", sanitizeProxyURLForLog("")) - require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) -} - -func TestExtractSoraEgressIPHint(t *testing.T) { - h := make(http.Header) - h.Set("x-openai-public-ip", "203.0.113.10") - require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) - - h2 := make(http.Header) - h2.Set("x-envoy-external-address", "198.51.100.9") - require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) - - require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) - require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) -} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index b6d7d634df..8032f8717a 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -15,7 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + "github.com/Wei-Shaw/sub2api/internal/util/httputil" ) // AdminService interface defines admin management operations @@ -104,14 +104,13 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 - SoraStorageQuotaBytes int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 } type UpdateUserInput struct { @@ -125,8 +124,7 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 - SoraStorageQuotaBytes *int64 + GroupRates map[int64]*float64 } type CreateGroupInput struct { @@ -140,16 +138,11 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - // Sora 按次计费配置 - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -158,8 +151,6 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string - // Sora 存储配额 - SoraStorageQuotaBytes int64 // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch bool DefaultMappedModel string @@ -181,16 +172,11 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - // Sora 按次计费配置 - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -199,8 +185,6 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string - // Sora 存储配额 - SoraStorageQuotaBytes *int64 // OpenAI Messages 调度配置(仅 openai 平台使用) AllowMessagesDispatch *bool DefaultMappedModel *string @@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{ http.StatusOK: {}, }, }, - { - Target: "sora", - URL: "https://sora.chatgpt.com/backend/me", - Method: http.MethodGet, - AllowedStatuses: map[int]struct{}{ - http.StatusUnauthorized: {}, - }, - }, } const ( @@ -448,7 +424,6 @@ type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository - soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -473,7 +448,6 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, - soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -492,7 +466,6 @@ func NewAdminService( userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, - soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -574,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, - SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } - if input.SoraStorageQuotaBytes != nil { - user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes - } - if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) - soraImagePrice360 := normalizePrice(input.SoraImagePrice360) - soraImagePrice540 := normalizePrice(input.SoraImagePrice540) - soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) - soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, - SoraImagePrice360: soraImagePrice360, - SoraImagePrice540: soraImagePrice540, - SoraVideoPricePerRequest: soraVideoPrice, - SoraVideoPricePerRequestHD: soraVideoPriceHD, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, - SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, AllowMessagesDispatch: input.AllowMessagesDispatch, RequireOAuthOnly: input.RequireOAuthOnly, RequirePrivacySet: input.RequirePrivacySet, @@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } - if input.SoraImagePrice360 != nil { - group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) - } - if input.SoraImagePrice540 != nil { - group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) - } - if input.SoraVideoPricePerRequest != nil { - group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) - } - if input.SoraVideoPricePerRequestHD != nil { - group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) - } - if input.SoraStorageQuotaBytes != nil { - group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes - } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } - // Sora apikey 账号的 base_url 必填校验 - if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { - baseURL, _ := input.Credentials["base_url"].(string) - baseURL = strings.TrimSpace(baseURL) - if baseURL == "" { - return nil, errors.New("sora apikey 账号必须设置 base_url") - } - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") - } - } - account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } - // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 - if account.Platform == PlatformSora && s.soraAccountRepo != nil { - soraUpdates := map[string]any{ - "access_token": account.GetCredential("access_token"), - "refresh_token": account.GetCredential("refresh_token"), - } - if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { - // 只记录警告日志,不阻塞账号创建 - logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) - } - } - // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { @@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } - // Sora apikey 账号的 base_url 必填校验 - if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { - baseURL, _ := account.Credentials["base_url"].(string) - baseURL = strings.TrimSpace(baseURL) - if baseURL == "" { - return nil, errors.New("sora apikey 账号必须设置 base_url") - } - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") - } - } - // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { @@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox body = body[:proxyQualityMaxBodyBytes] } - if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + // Cloudflare challenge 检测 + if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { item.Status = "challenge" - item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) - item.Message = "Sora 命中 Cloudflare challenge" + item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body) + item.Message = "命中 Cloudflare challenge" return item } diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go index 5a43cd9c29..d3b3f61bc6 100644 --- a/backend/internal/service/admin_service_proxy_quality_test.go +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { require.Contains(t, result.Summary, "挑战 1 项") } -func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { +func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") w.Header().Set("cf-ray", "test-ray-123") @@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { defer server.Close() target := proxyQualityTarget{ - Target: "sora", + Target: "openai", URL: server.URL, Method: http.MethodGet, AllowedStatuses: map[int]struct{}{ diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index ecaffcbcb7..e3b60a2790 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -5,13 +5,12 @@ package service import ( "bytes" "context" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/stretchr/testify/require" "io" "net/http" "strings" "testing" - - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/stretchr/testify/require" ) // stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock @@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI m.responseBodies[respIdx] = bodyBytes } - // 用缓存的 body 字节重建新的 reader - var body io.ReadCloser + // 用缓存的 body 重建 reader(支持重试场景多次读取) + cloned := *resp if m.responseBodies[respIdx] != nil { - body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx])) + cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx])) } - - return &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: body, - }, respErr + return &cloned, respErr } func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index e8ad5c9c32..ad6ba0e930 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f727ab10f3..64a70e8cf4 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, - SoraImagePrice360: apiKey.Group.SoraImagePrice360, - SoraImagePrice540: apiKey.Group.SoraImagePrice540, - SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, @@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, - SoraImagePrice360: snapshot.Group.SoraImagePrice360, - SoraImagePrice540: snapshot.Group.SoraImagePrice540, - SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 2fe13686a7..763abadbfc 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -808,14 +808,6 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } -// SoraPriceConfig Sora 按次计费配置 -type SoraPriceConfig struct { - ImagePrice360 *float64 - ImagePrice540 *float64 - VideoPricePerRequest *float64 - VideoPricePerRequestHD *float64 -} - // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } -// CalculateSoraImageCost 计算 Sora 图片按次费用 -func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { - if imageCount <= 0 { - return &CostBreakdown{} - } - - unitPrice := 0.0 - if groupConfig != nil { - switch imageSize { - case "540": - if groupConfig.ImagePrice540 != nil { - unitPrice = *groupConfig.ImagePrice540 - } - default: - if groupConfig.ImagePrice360 != nil { - unitPrice = *groupConfig.ImagePrice360 - } - } - } - - totalCost := unitPrice * float64(imageCount) - if rateMultiplier <= 0 { - rateMultiplier = 1.0 - } - actualCost := totalCost * rateMultiplier - - return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, - } -} - -// CalculateSoraVideoCost 计算 Sora 视频按次费用 -func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { - unitPrice := 0.0 - if groupConfig != nil { - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "sora2pro-hd") { - if groupConfig.VideoPricePerRequestHD != nil { - unitPrice = *groupConfig.VideoPricePerRequestHD - } - } - if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { - unitPrice = *groupConfig.VideoPricePerRequest - } - } - - totalCost := unitPrice - if rateMultiplier <= 0 { - rateMultiplier = 1.0 - } - actualCost := totalCost * rateMultiplier - - return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, - } -} - // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 1094342285..dd58502c01 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) { require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) } -func TestCalculateSoraVideoCost(t *testing.T) { - svc := newTestBillingService() - - price := 0.5 - cfg := &SoraPriceConfig{VideoPricePerRequest: &price} - cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) - - require.InDelta(t, 0.5, cost.TotalCost, 1e-10) -} - -func TestCalculateSoraVideoCost_HDModel(t *testing.T) { - svc := newTestBillingService() - - hdPrice := 1.0 - normalPrice := 0.5 - cfg := &SoraPriceConfig{ - VideoPricePerRequest: &normalPrice, - VideoPricePerRequestHD: &hdPrice, - } - cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) - require.InDelta(t, 1.0, cost.TotalCost, 1e-10) -} func TestIsModelSupported(t *testing.T) { svc := newTestBillingService() @@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) { require.Contains(t, err.Error(), "not initialized") } -func TestCalculateSoraImageCost(t *testing.T) { - svc := newTestBillingService() - - price360 := 0.05 - price540 := 0.08 - cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540} - - cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0) - require.InDelta(t, 0.10, cost.TotalCost, 1e-10) - - cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0) - require.InDelta(t, 0.08, cost540.TotalCost, 1e-10) - require.InDelta(t, 0.16, cost540.ActualCost, 1e-10) -} - -func TestCalculateSoraImageCost_ZeroCount(t *testing.T) { - svc := newTestBillingService() - cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0) - require.Equal(t, 0.0, cost.TotalCost) -} - -func TestCalculateSoraVideoCost_NilConfig(t *testing.T) { - svc := newTestBillingService() - cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0) - require.Equal(t, 0.0, cost.TotalCost) -} - func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) { // 使用空的 fallback prices 让 GetModelPricing 失败 svc := &BillingService{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ecac0db0cf..52df52d656 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,7 +24,6 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity - PlatformSora = domain.PlatformSora ) // Account type constants @@ -107,7 +106,6 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 - SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -199,27 +197,6 @@ const ( // SettingKeyBetaPolicySettings stores JSON config for beta policy rules. SettingKeyBetaPolicySettings = "beta_policy_settings" - // ========================= - // Sora S3 存储配置 - // ========================= - - SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 - SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 - SettingKeySoraS3Region = "sora_s3_region" // S3 区域 - SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 - SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID - SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) - SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 - SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) - SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) - SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) - - // ========================= - // Sora 用户存储配额 - // ========================= - - SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) - // ========================= // Claude Code Version Check // ========================= diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a95b62b133..16bfcd8e5a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -60,13 +60,6 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) -// MediaType 媒体类型常量 -const ( - MediaTypeImage = "image" - MediaTypeVideo = "video" - MediaTypePrompt = "prompt" -) - // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} @@ -511,9 +504,6 @@ type ForwardResult struct { ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" - // Sora 媒体字段 - MediaType string // image / video / prompt - MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -1971,9 +1961,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - if platform == PlatformSora { - return s.listSoraSchedulableAccounts(ctx, groupID) - } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2070,53 +2057,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } -func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { - const useMixed = false - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } else if groupID != nil { - accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "error", err) - return nil, useMixed, err - } - - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform != PlatformSora { - continue - } - if !s.isSoraAccountSchedulable(&acc) { - continue - } - filtered = append(filtered, acc) - } - slog.Debug("account_scheduling_list_sora", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "raw_count", len(accounts), - "filtered_count", len(filtered)) - for _, acc := range filtered { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) - } - return filtered, useMixed, nil -} - // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2141,33 +2081,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } -func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { - return s.soraUnschedulableReason(account) == "" -} - -func (s *GatewayService) soraUnschedulableReason(account *Account) string { - if account == nil { - return "account_nil" - } - if account.Status != StatusActive { - return fmt.Sprintf("status=%s", account.Status) - } - if !account.Schedulable { - return "schedulable=false" - } - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) - } - return "" -} - func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } - if account.Platform == PlatformSora { - return s.isSoraAccountSchedulable(account) - } return account.IsSchedulable() } @@ -2175,12 +2092,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { return false } - if account.Platform == PlatformSora { - if !s.isSoraAccountSchedulable(account) { - return false - } - return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 - } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -3357,9 +3268,6 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) - if platform == PlatformSora { - s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) - } return stats } @@ -3416,11 +3324,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - detail := "generic_unschedulable" - if acc.Platform == PlatformSora { - detail = s.soraUnschedulableReason(acc) - } - return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3444,57 +3348,6 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -func (s *GatewayService) logSoraSelectionFailureDetails( - ctx context.Context, - groupID *int64, - sessionHash string, - requestedModel string, - accounts []Account, - excludedIDs map[int64]struct{}, - allowMixedScheduling bool, -) { - const maxLines = 30 - logged := 0 - - for i := range accounts { - if logged >= maxLines { - break - } - acc := &accounts[i] - diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) - if diagnosis.Category == "eligible" { - continue - } - detail := diagnosis.Detail - if detail == "" { - detail = "-" - } - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - acc.ID, - acc.Platform, - diagnosis.Category, - detail, - ) - logged++ - } - if len(accounts) > maxLines { - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - len(accounts), - logged, - ) - } -} - func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3573,9 +3426,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } - if account.Platform == PlatformSora { - return s.isSoraModelSupportedByAccount(account, requestedModel) - } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok @@ -3588,143 +3438,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } -func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { - if account == nil { - return false - } - if strings.TrimSpace(requestedModel) == "" { - return true - } - - // 先走原始精确/通配符匹配。 - mapping := account.GetModelMapping() - if len(mapping) == 0 || account.IsModelSupported(requestedModel) { - return true - } - - aliases := buildSoraModelAliases(requestedModel) - if len(aliases) == 0 { - return false - } - - hasSoraSelector := false - for pattern := range mapping { - if !isSoraModelSelector(pattern) { - continue - } - hasSoraSelector = true - if matchPatternAnyAlias(pattern, aliases) { - return true - } - } - - // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), - // 此时不应误拦截 Sora 模型请求。 - if !hasSoraSelector { - return true - } - - return false -} - -func matchPatternAnyAlias(pattern string, aliases []string) bool { - normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) - if normalizedPattern == "" { - return false - } - for _, alias := range aliases { - if matchWildcard(normalizedPattern, alias) { - return true - } - } - return false -} - -func isSoraModelSelector(pattern string) bool { - p := strings.ToLower(strings.TrimSpace(pattern)) - if p == "" { - return false - } - - switch { - case strings.HasPrefix(p, "sora"), - strings.HasPrefix(p, "gpt-image"), - strings.HasPrefix(p, "prompt-enhance"), - strings.HasPrefix(p, "sy_"): - return true - } - - return p == "video" || p == "image" -} - -func buildSoraModelAliases(requestedModel string) []string { - modelID := strings.ToLower(strings.TrimSpace(requestedModel)) - if modelID == "" { - return nil - } - - aliases := make([]string, 0, 8) - addAlias := func(value string) { - v := strings.ToLower(strings.TrimSpace(value)) - if v == "" { - return - } - for _, existing := range aliases { - if existing == v { - return - } - } - aliases = append(aliases, v) - } - - addAlias(modelID) - cfg, ok := GetSoraModelConfig(modelID) - if ok { - addAlias(cfg.Model) - switch cfg.Type { - case "video": - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case "image": - addAlias("image") - addAlias("gpt-image") - case "prompt_enhance": - addAlias("prompt-enhance") - } - return aliases - } - - switch { - case strings.HasPrefix(modelID, "sora"): - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case strings.HasPrefix(modelID, "gpt-image"): - addAlias("image") - addAlias("gpt-image") - case strings.HasPrefix(modelID, "prompt-enhance"): - addAlias("prompt-enhance") - default: - return nil - } - - return aliases -} - -func soraVideoFamilyAlias(modelID string) string { - switch { - case strings.HasPrefix(modelID, "sora2pro-hd"): - return "sora2pro-hd" - case strings.HasPrefix(modelID, "sora2pro"): - return "sora2pro" - case strings.HasPrefix(modelID, "sora2"): - return "sora2" - default: - return "" - } -} - // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -7592,9 +7305,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount - if usageLog.MediaType != nil { - cmd.MediaType = *usageLog.MediaType - } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7750,8 +7460,6 @@ type recordUsageOpts struct { // EnableClaudePath 启用 Claude 路径特有逻辑: // - Claude Max 缓存计费策略 - // - Sora 媒体类型分支(image/video/prompt) - // - MediaType 字段写入使用日志 EnableClaudePath bool // 长上下文计费(仅 Gemini 路径需要) @@ -7842,7 +7550,6 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: // - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 -// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -7944,16 +7651,6 @@ func (s *GatewayService) calculateRecordUsageCost( multiplier float64, opts *recordUsageOpts, ) *CostBreakdown { - // Sora 媒体类型分支(仅 Claude 路径启用) - if opts.EnableClaudePath { - if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { - return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) - } - if result.MediaType == MediaTypePrompt { - return &CostBreakdown{} - } - } - // 图片生成计费 if result.ImageCount > 0 { return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) @@ -7963,28 +7660,6 @@ func (s *GatewayService) calculateRecordUsageCost( return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } -// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 -func (s *GatewayService) calculateSoraMediaCost( - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, -) *CostBreakdown { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == MediaTypeImage { - return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } - return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) -} - // resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -8133,13 +7808,12 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, - BillingMode: resolveBillingMode(opts, result, cost), + BillingMode: resolveBillingMode(result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), - MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), @@ -8163,13 +7837,7 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 -func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { - isSoraMedia := opts.EnableClaudePath && - (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) - if isSoraMedia { - return nil - } +func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { var mode string switch { case cost != nil && cost.BillingMode != "": @@ -8182,13 +7850,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost return &mode } -func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { - if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { - return &result.MediaType - } - return nil -} - func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { return &subscription.ID diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go index 743d70bbbf..ac8c6df67c 100644 --- a/backend/internal/service/gateway_service_selection_failure_stats_test.go +++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -9,35 +9,35 @@ import ( func TestCollectSelectionFailureStats(t *testing.T) { svc := &GatewayService{} - model := "sora2-landscape-10s" + model := "gpt-5.4" resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) accounts := []Account{ // excluded { ID: 1, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, }, // unschedulable { ID: 2, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: false, }, // platform filtered { ID: 3, - Platform: PlatformOpenAI, + Platform: PlatformAntigravity, Status: StatusActive, Schedulable: true, }, // model unsupported { ID: 4, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Credentials: map[string]any{ @@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) { // model rate limited { ID: 5, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Extra: map[string]any{ @@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) { // eligible { ID: 6, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, }, } excluded := map[int64]struct{}{1: {}} - stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false) if stats.Total != 6 { t.Fatalf("total=%d want=6", stats.Total) @@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) { } } -func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { +func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) { svc := &GatewayService{} acc := &Account{ ID: 7, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: false, } - diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false) if diagnosis.Category != "unschedulable" { t.Fatalf("category=%s want=unschedulable", diagnosis.Category) } - if diagnosis.Detail != "schedulable=false" { - t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + if diagnosis.Detail != "generic_unschedulable" { + t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail) } } -func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { +func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) { svc := &GatewayService{} - model := "sora2-landscape-10s" + model := "gpt-5.4" resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) acc := &Account{ ID: 8, - Platform: PlatformSora, + Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Extra: map[string]any{ @@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { }, } - diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false) + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false) if diagnosis.Category != "model_rate_limited" { t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) } diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go deleted file mode 100644 index 8ee2a960d1..0000000000 --- a/backend/internal/service/gateway_service_sora_model_support_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package service - -import "testing" - -func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{}, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected sora model to be supported when model_mapping is empty") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "gpt-4o": "gpt-4o", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected sora model to be supported when mapping has no sora selectors") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "sora2": "sora2", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { - t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "sy_8": "sy_8", - }, - }, - } - - if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") - } -} - -func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { - svc := &GatewayService{} - account := &Account{ - Platform: PlatformSora, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "gpt-image": "gpt-image", - }, - }, - } - - if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { - t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") - } -} diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go deleted file mode 100644 index 5178e68e40..0000000000 --- a/backend/internal/service/gateway_service_sora_scheduling_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package service - -import ( - "context" - "testing" - "time" -) - -func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { - svc := &GatewayService{} - now := time.Now() - past := now.Add(-1 * time.Minute) - future := now.Add(5 * time.Minute) - - acc := &Account{ - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - AutoPauseOnExpired: true, - ExpiresAt: &past, - OverloadUntil: &future, - RateLimitResetAt: &future, - } - - if !svc.isAccountSchedulableForSelection(acc) { - t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") - } -} - -func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { - svc := &GatewayService{} - future := time.Now().Add(5 * time.Minute) - - acc := &Account{ - Platform: PlatformAnthropic, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &future, - } - - if svc.isAccountSchedulableForSelection(acc) { - t.Fatalf("expected non-sora account to keep generic schedulable checks") - } -} - -func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { - svc := &GatewayService{} - model := "sora2-landscape-10s" - resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) - globalResetAt := time.Now().Add(2 * time.Minute) - - acc := &Account{ - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &globalResetAt, - Extra: map[string]any{ - "model_rate_limits": map[string]any{ - model: map[string]any{ - "rate_limit_reset_at": resetAt, - }, - }, - }, - } - - if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { - t.Fatalf("expected sora account to be blocked by model scope rate limit") - } -} - -func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { - svc := &GatewayService{} - future := time.Now().Add(3 * time.Minute) - - accounts := []Account{ - { - ID: 1, - Platform: PlatformSora, - Status: StatusActive, - Schedulable: true, - RateLimitResetAt: &future, - }, - } - - stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) - if stats.Unschedulable != 0 || stats.Eligible != 1 { - t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) - } -} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e0f81a39a4..d59af9e1c0 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,15 +26,6 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 - // Sora 按次计费配置(阶段 1) - SoraImagePrice360 *float64 - SoraImagePrice540 *float64 - SoraVideoPricePerRequest *float64 - SoraVideoPricePerRequestHD *float64 - - // Sora 存储配额 - SoraStorageQuotaBytes int64 - // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } -// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) -func (g *Group) GetSoraImagePrice(imageSize string) *float64 { - switch imageSize { - case "360": - return g.SoraImagePrice360 - case "540": - return g.SoraImagePrice540 - default: - return g.SoraImagePrice360 - } -} - // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index f575cd625e..dc094d43ce 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -3,30 +3,15 @@ package service import ( "context" "crypto/subtle" - "encoding/json" - "io" "log/slog" "net/http" - "regexp" - "sort" - "strconv" "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) -var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" - -var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) - -type soraSessionChunk struct { - index int - value string -} - // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") } -// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id. func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { @@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) } -// ExchangeSoraSessionToken exchanges Sora session_token to access_token. -func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { - sessionToken = normalizeSoraSessionTokenInput(sessionToken) - if strings.TrimSpace(sessionToken) == "" { - return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") - } - - proxyURL, err := s.resolveProxyURL(ctx, proxyID) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) - if err != nil { - return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) - } - req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) - req.Header.Set("Accept", "application/json") - req.Header.Set("Origin", "https://sora.chatgpt.com") - req.Header.Set("Referer", "https://sora.chatgpt.com/") - req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 120 * time.Second, - }) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) - } - resp, err := client.Do(req) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) - } - defer func() { _ = resp.Body.Close() }() - - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if resp.StatusCode != http.StatusOK { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - - var sessionResp struct { - AccessToken string `json:"accessToken"` - Expires string `json:"expires"` - User struct { - Email string `json:"email"` - Name string `json:"name"` - } `json:"user"` - } - if err := json.Unmarshal(body, &sessionResp); err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) - } - if strings.TrimSpace(sessionResp.AccessToken) == "" { - return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") - } - - expiresAt := time.Now().Add(time.Hour).Unix() - if strings.TrimSpace(sessionResp.Expires) != "" { - if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { - expiresAt = parsed.Unix() - } - } - expiresIn := expiresAt - time.Now().Unix() - if expiresIn < 0 { - expiresIn = 0 - } - - return &OpenAITokenInfo{ - AccessToken: strings.TrimSpace(sessionResp.AccessToken), - ExpiresIn: expiresIn, - ExpiresAt: expiresAt, - ClientID: openai.SoraClientID, - Email: strings.TrimSpace(sessionResp.User.Email), - }, nil -} - -func normalizeSoraSessionTokenInput(raw string) string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "" - } - - matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) - if len(matches) == 0 { - return sanitizeSessionToken(trimmed) - } - - chunkMatches := make([]soraSessionChunk, 0, len(matches)) - singleValues := make([]string, 0, len(matches)) - - for _, match := range matches { - if len(match) < 3 { - continue - } - - value := sanitizeSessionToken(match[2]) - if value == "" { - continue - } - - if strings.TrimSpace(match[1]) == "" { - singleValues = append(singleValues, value) - continue - } - - idx, err := strconv.Atoi(strings.TrimSpace(match[1])) - if err != nil || idx < 0 { - continue - } - chunkMatches = append(chunkMatches, soraSessionChunk{ - index: idx, - value: value, - }) - } - - if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { - return merged - } - - if len(singleValues) > 0 { - return singleValues[len(singleValues)-1] - } - - return "" -} - -func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { - if len(chunks) == 0 { - return "" - } - - byIndex := make(map[int]string, len(chunks)) - for _, chunk := range chunks { - byIndex[chunk.index] = chunk.value - } - - if _, ok := byIndex[0]; !ok { - return "" - } - if requireComplete { - for idx := 0; idx <= requiredMaxIndex; idx++ { - if _, ok := byIndex[idx]; !ok { - return "" - } - } - } - - orderedIndexes := make([]int, 0, len(byIndex)) - for idx := range byIndex { - orderedIndexes = append(orderedIndexes, idx) - } - sort.Ints(orderedIndexes) - - var builder strings.Builder - for _, idx := range orderedIndexes { - if _, err := builder.WriteString(byIndex[idx]); err != nil { - return "" - } - } - return sanitizeSessionToken(builder.String()) -} - -func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { - if len(chunks) == 0 { - return "" - } - - requiredMaxIndex := 0 - for _, chunk := range chunks { - if chunk.index > requiredMaxIndex { - requiredMaxIndex = chunk.index - } - } - - groupStarts := make([]int, 0, len(chunks)) - for idx, chunk := range chunks { - if chunk.index == 0 { - groupStarts = append(groupStarts, idx) - } - } - - if len(groupStarts) == 0 { - return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) - } - - for i := len(groupStarts) - 1; i >= 0; i-- { - start := groupStarts[i] - end := len(chunks) - if i+1 < len(groupStarts) { - end = groupStarts[i+1] - } - if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { - return merged - } - } - - return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) -} - -func sanitizeSessionToken(raw string) string { - token := strings.TrimSpace(raw) - token = strings.Trim(token, "\"'`") - token = strings.TrimSuffix(token, ";") - return strings.TrimSpace(token) -} - -// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +// RefreshAccountToken refreshes token for an OpenAI OAuth account func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + if account.Platform != PlatformOpenAI { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") } if account.Type != AccountTypeOAuth { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") @@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } -func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { - if proxyID == nil { - return "", nil - } - proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) - if err != nil { - return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) - } - if proxy == nil { - return "", nil - } - return proxy.URL(), nil -} - func normalizeOpenAIOAuthPlatform(platform string) string { - switch strings.ToLower(strings.TrimSpace(platform)) { - case PlatformSora: - return openai.OAuthPlatformSora - default: - return openai.OAuthPlatformOpenAI - } + return openai.OAuthPlatformOpenAI } diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go index 5f26903db1..f3b507ca02 100644 --- a/backend/internal/service/openai_oauth_service_auth_url_test.go +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { require.True(t, ok) require.Equal(t, openai.ClientID, session.ClientID) } - -// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 -// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 -func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) - defer svc.Stop() - - result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) - require.NoError(t, err) - require.NotEmpty(t, result.AuthURL) - require.NotEmpty(t, result.SessionID) - - parsed, err := url.Parse(result.AuthURL) - require.NoError(t, err) - q := parsed.Query() - require.Equal(t, openai.ClientID, q.Get("client_id")) - require.Empty(t, q.Get("codex_cli_simplified_flow")) - - session, ok := svc.sessionStore.Get(result.SessionID) - require.True(t, ok) - require.Equal(t, openai.ClientID, session.ClientID) -} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go deleted file mode 100644 index 08da85571c..0000000000 --- a/backend/internal/service/openai_oauth_service_sora_session_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package service - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/stretchr/testify/require" -) - -type openaiOAuthClientNoopStub struct{} - -func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { - return nil, errors.New("not implemented") -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) - require.NoError(t, err) - require.NotNil(t, info) - require.Equal(t, "at-token", info.AccessToken) - require.Equal(t, "demo@example.com", info.Email) - require.Greater(t, info.ExpiresAt, int64(0)) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) - require.Error(t, err) - require.Contains(t, err.Error(), "missing access token") -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", - "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} - -func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodGet, r.Method) - require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) - })) - defer server.Close() - - origin := openAISoraSessionAuthURL - openAISoraSessionAuthURL = server.URL - defer func() { openAISoraSessionAuthURL = origin }() - - svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) - defer svc.Stop() - - raw := strings.Join([]string{ - "set-cookie", - "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", - "set-cookie", - "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", - "set-cookie", - "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", - }, "\n") - info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) - require.NoError(t, err) - require.Equal(t, "at-token", info.AccessToken) -} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 69477ce7fc..e438588edb 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() { // OpenAITokenCache token cache interface. type OpenAITokenCache = GeminiTokenCache -// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts. +// OpenAITokenProvider manages access_token for OpenAI OAuth accounts. type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache @@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if account == nil { return "", errors.New("account is nil") } - if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai/sora oauth account") + if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou p.metrics.refreshRequests.Add(1) p.metrics.touchNow() - // Sora accounts skip OpenAI OAuth refresh and keep existing token path. - if account.Platform == PlatformSora { - slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) + if err != nil { + if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { + return "", err + } + slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true - } else { - result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew) - if err != nil { - if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn { - return "", err + } else if result.LockHeld { + if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr } - slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) - p.metrics.refreshFailure.Add(1) - refreshFailed = true - } else if result.LockHeld { - if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache { - p.metrics.lockContention.Add(1) - p.metrics.touchNow() - token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) - if waitErr != nil { - return "", waitErr - } - if strings.TrimSpace(token) != "" { - slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) - return token, nil - } + if strings.TrimSpace(token) != "" { + slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil } - } else if result.Refreshed { - p.metrics.refreshSuccess.Add(1) - account = result.Account - expiresAt = account.GetCredentialAsTime("expires_at") - } else { - account = result.Account - expiresAt = account.GetCredentialAsTime("expires_at") } + } else if result.Refreshed { + p.metrics.refreshSuccess.Add(1) + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") + } else { + account = result.Account + expiresAt = account.GetCredentialAsTime("expires_at") } } else if needsRefresh && p.tokenCache != nil { // Backward-compatible test path when refreshAPI is not injected. diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index 1cd923672a..e81fb46560 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai/sora oauth account") + require.Contains(t, err.Error(), "not an openai oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai/sora oauth account") + require.Contains(t, err.Error(), "not an openai oauth account") require.Empty(t, token) } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 1a24bad149..a85efabd6d 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -22,8 +22,6 @@ import ( var ( ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") - ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") - ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ErrDefaultSubGroupInvalid = infraerrors.BadRequest( "DEFAULT_SUBSCRIPTION_GROUP_INVALID", "default subscription group must exist and be subscription type", @@ -104,7 +102,6 @@ type SettingService struct { defaultSubGroupReader DefaultSubscriptionGroupReader cfg *config.Config onUpdate func() // Callback when settings are updated (for cache invalidation) - onS3Update func() // Callback when Sora S3 settings are updated version string // Application version } @@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, - SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, @@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, @@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } -// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 -func (s *SettingService) SetOnS3UpdateCallback(callback func()) { - s.onS3Update = callback -} - // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` @@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, @@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) - updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints @@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", - SettingKeySoraClientEnabled: "false", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), @@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", @@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } - -type soraS3ProfilesStore struct { - ActiveProfileID string `json:"active_profile_id"` - Items []soraS3ProfileStoreItem `json:"items"` -} - -type soraS3ProfileStoreItem struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) -func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - profiles, err := s.ListSoraS3Profiles(ctx) - if err != nil { - return nil, err - } - - activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) - if activeProfile == nil { - return &SoraS3Settings{}, nil - } - - return &SoraS3Settings{ - Enabled: activeProfile.Enabled, - Endpoint: activeProfile.Endpoint, - Region: activeProfile.Region, - Bucket: activeProfile.Bucket, - AccessKeyID: activeProfile.AccessKeyID, - SecretAccessKey: activeProfile.SecretAccessKey, - SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, - Prefix: activeProfile.Prefix, - ForcePathStyle: activeProfile.ForcePathStyle, - CDNURL: activeProfile.CDNURL, - DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, - }, nil -} - -// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) -func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { - if settings == nil { - return fmt.Errorf("settings cannot be nil") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return err - } - - now := time.Now().UTC().Format(time.RFC3339) - activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) - if activeIndex < 0 { - activeID := "default" - if hasSoraS3ProfileID(store.Items, activeID) { - activeID = fmt.Sprintf("default-%d", time.Now().Unix()) - } - store.Items = append(store.Items, soraS3ProfileStoreItem{ - ProfileID: activeID, - Name: "Default", - UpdatedAt: now, - }) - store.ActiveProfileID = activeID - activeIndex = len(store.Items) - 1 - } - - active := store.Items[activeIndex] - active.Enabled = settings.Enabled - active.Endpoint = strings.TrimSpace(settings.Endpoint) - active.Region = strings.TrimSpace(settings.Region) - active.Bucket = strings.TrimSpace(settings.Bucket) - active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) - active.Prefix = strings.TrimSpace(settings.Prefix) - active.ForcePathStyle = settings.ForcePathStyle - active.CDNURL = strings.TrimSpace(settings.CDNURL) - active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) - if settings.SecretAccessKey != "" { - active.SecretAccessKey = settings.SecretAccessKey - } - active.UpdatedAt = now - store.Items[activeIndex] = active - - return s.persistSoraS3ProfilesStore(ctx, store) -} - -// ListSoraS3Profiles 获取 Sora S3 多配置列表 -func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - return convertSoraS3ProfilesStore(store), nil -} - -// CreateSoraS3Profile 创建 Sora S3 配置 -func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { - if profile == nil { - return nil, fmt.Errorf("profile cannot be nil") - } - - profileID := strings.TrimSpace(profile.ProfileID) - if profileID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - name := strings.TrimSpace(profile.Name) - if name == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - if hasSoraS3ProfileID(store.Items, profileID) { - return nil, ErrSoraS3ProfileExists - } - - now := time.Now().UTC().Format(time.RFC3339) - store.Items = append(store.Items, soraS3ProfileStoreItem{ - ProfileID: profileID, - Name: name, - Enabled: profile.Enabled, - Endpoint: strings.TrimSpace(profile.Endpoint), - Region: strings.TrimSpace(profile.Region), - Bucket: strings.TrimSpace(profile.Bucket), - AccessKeyID: strings.TrimSpace(profile.AccessKeyID), - SecretAccessKey: profile.SecretAccessKey, - Prefix: strings.TrimSpace(profile.Prefix), - ForcePathStyle: profile.ForcePathStyle, - CDNURL: strings.TrimSpace(profile.CDNURL), - DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }) - - if setActive || store.ActiveProfileID == "" { - store.ActiveProfileID = profileID - } - - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - created := findSoraS3ProfileByID(profiles.Items, profileID) - if created == nil { - return nil, ErrSoraS3ProfileNotFound - } - return created, nil -} - -// UpdateSoraS3Profile 更新 Sora S3 配置 -func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { - if profile == nil { - return nil, fmt.Errorf("profile cannot be nil") - } - - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return nil, ErrSoraS3ProfileNotFound - } - - target := store.Items[targetIndex] - name := strings.TrimSpace(profile.Name) - if name == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") - } - target.Name = name - target.Enabled = profile.Enabled - target.Endpoint = strings.TrimSpace(profile.Endpoint) - target.Region = strings.TrimSpace(profile.Region) - target.Bucket = strings.TrimSpace(profile.Bucket) - target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID) - target.Prefix = strings.TrimSpace(profile.Prefix) - target.ForcePathStyle = profile.ForcePathStyle - target.CDNURL = strings.TrimSpace(profile.CDNURL) - target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0) - if profile.SecretAccessKey != "" { - target.SecretAccessKey = profile.SecretAccessKey - } - target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - store.Items[targetIndex] = target - - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - updated := findSoraS3ProfileByID(profiles.Items, targetID) - if updated == nil { - return nil, ErrSoraS3ProfileNotFound - } - return updated, nil -} - -// DeleteSoraS3Profile 删除 Sora S3 配置 -func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return ErrSoraS3ProfileNotFound - } - - store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) - if store.ActiveProfileID == targetID { - store.ActiveProfileID = "" - if len(store.Items) > 0 { - store.ActiveProfileID = store.Items[0].ProfileID - } - } - - return s.persistSoraS3ProfilesStore(ctx, store) -} - -// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 -func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { - targetID := strings.TrimSpace(profileID) - if targetID == "" { - return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") - } - - store, err := s.loadSoraS3ProfilesStore(ctx) - if err != nil { - return nil, err - } - - targetIndex := findSoraS3ProfileIndex(store.Items, targetID) - if targetIndex < 0 { - return nil, ErrSoraS3ProfileNotFound - } - - store.ActiveProfileID = targetID - store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { - return nil, err - } - - profiles := convertSoraS3ProfilesStore(store) - active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) - if active == nil { - return nil, ErrSoraS3ProfileNotFound - } - return active, nil -} - -func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { - raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) - if err == nil { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return &soraS3ProfilesStore{}, nil - } - var store soraS3ProfilesStore - if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil - } - normalized := normalizeSoraS3ProfilesStore(store) - return &normalized, nil - } - - if !errors.Is(err, ErrSettingNotFound) { - return nil, fmt.Errorf("get sora s3 profiles: %w", err) - } - - legacy, legacyErr := s.getLegacySoraS3Settings(ctx) - if legacyErr != nil { - return nil, legacyErr - } - if isEmptyLegacySoraS3Settings(legacy) { - return &soraS3ProfilesStore{}, nil - } - - now := time.Now().UTC().Format(time.RFC3339) - return &soraS3ProfilesStore{ - ActiveProfileID: "default", - Items: []soraS3ProfileStoreItem{ - { - ProfileID: "default", - Name: "Default", - Enabled: legacy.Enabled, - Endpoint: strings.TrimSpace(legacy.Endpoint), - Region: strings.TrimSpace(legacy.Region), - Bucket: strings.TrimSpace(legacy.Bucket), - AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), - SecretAccessKey: legacy.SecretAccessKey, - Prefix: strings.TrimSpace(legacy.Prefix), - ForcePathStyle: legacy.ForcePathStyle, - CDNURL: strings.TrimSpace(legacy.CDNURL), - DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), - UpdatedAt: now, - }, - }, - }, nil -} - -func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { - if store == nil { - return fmt.Errorf("sora s3 profiles store cannot be nil") - } - - normalized := normalizeSoraS3ProfilesStore(*store) - data, err := json.Marshal(normalized) - if err != nil { - return fmt.Errorf("marshal sora s3 profiles: %w", err) - } - - updates := map[string]string{ - SettingKeySoraS3Profiles: string(data), - } - - active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) - if active == nil { - updates[SettingKeySoraS3Enabled] = "false" - updates[SettingKeySoraS3Endpoint] = "" - updates[SettingKeySoraS3Region] = "" - updates[SettingKeySoraS3Bucket] = "" - updates[SettingKeySoraS3AccessKeyID] = "" - updates[SettingKeySoraS3Prefix] = "" - updates[SettingKeySoraS3ForcePathStyle] = "false" - updates[SettingKeySoraS3CDNURL] = "" - updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" - updates[SettingKeySoraS3SecretAccessKey] = "" - } else { - updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) - updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) - updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) - updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) - updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) - updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) - updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) - updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) - updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) - updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey - } - - if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { - return err - } - - if s.onUpdate != nil { - s.onUpdate() - } - if s.onS3Update != nil { - s.onS3Update() - } - return nil -} - -func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { - keys := []string{ - SettingKeySoraS3Enabled, - SettingKeySoraS3Endpoint, - SettingKeySoraS3Region, - SettingKeySoraS3Bucket, - SettingKeySoraS3AccessKeyID, - SettingKeySoraS3SecretAccessKey, - SettingKeySoraS3Prefix, - SettingKeySoraS3ForcePathStyle, - SettingKeySoraS3CDNURL, - SettingKeySoraDefaultStorageQuotaBytes, - } - - values, err := s.settingRepo.GetMultiple(ctx, keys) - if err != nil { - return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) - } - - result := &SoraS3Settings{ - Enabled: values[SettingKeySoraS3Enabled] == "true", - Endpoint: values[SettingKeySoraS3Endpoint], - Region: values[SettingKeySoraS3Region], - Bucket: values[SettingKeySoraS3Bucket], - AccessKeyID: values[SettingKeySoraS3AccessKeyID], - SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], - SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", - Prefix: values[SettingKeySoraS3Prefix], - ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", - CDNURL: values[SettingKeySoraS3CDNURL], - } - if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { - result.DefaultStorageQuotaBytes = v - } - return result, nil -} - -func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { - seen := make(map[string]struct{}, len(store.Items)) - normalized := soraS3ProfilesStore{ - ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), - Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), - } - now := time.Now().UTC().Format(time.RFC3339) - - for idx := range store.Items { - item := store.Items[idx] - item.ProfileID = strings.TrimSpace(item.ProfileID) - if item.ProfileID == "" { - item.ProfileID = fmt.Sprintf("profile-%d", idx+1) - } - if _, exists := seen[item.ProfileID]; exists { - continue - } - seen[item.ProfileID] = struct{}{} - - item.Name = strings.TrimSpace(item.Name) - if item.Name == "" { - item.Name = item.ProfileID - } - item.Endpoint = strings.TrimSpace(item.Endpoint) - item.Region = strings.TrimSpace(item.Region) - item.Bucket = strings.TrimSpace(item.Bucket) - item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) - item.Prefix = strings.TrimSpace(item.Prefix) - item.CDNURL = strings.TrimSpace(item.CDNURL) - item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) - item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) - if item.UpdatedAt == "" { - item.UpdatedAt = now - } - normalized.Items = append(normalized.Items, item) - } - - if len(normalized.Items) == 0 { - normalized.ActiveProfileID = "" - return normalized - } - - if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { - return normalized - } - - normalized.ActiveProfileID = normalized.Items[0].ProfileID - return normalized -} - -func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { - if store == nil { - return &SoraS3ProfileList{} - } - items := make([]SoraS3Profile, 0, len(store.Items)) - for idx := range store.Items { - item := store.Items[idx] - items = append(items, SoraS3Profile{ - ProfileID: item.ProfileID, - Name: item.Name, - IsActive: item.ProfileID == store.ActiveProfileID, - Enabled: item.Enabled, - Endpoint: item.Endpoint, - Region: item.Region, - Bucket: item.Bucket, - AccessKeyID: item.AccessKeyID, - SecretAccessKey: item.SecretAccessKey, - SecretAccessKeyConfigured: item.SecretAccessKey != "", - Prefix: item.Prefix, - ForcePathStyle: item.ForcePathStyle, - CDNURL: item.CDNURL, - DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, - UpdatedAt: item.UpdatedAt, - }) - } - return &SoraS3ProfileList{ - ActiveProfileID: store.ActiveProfileID, - Items: items, - } -} - -func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == activeProfileID { - return &items[idx] - } - } - if len(items) == 0 { - return nil - } - return &items[0] -} - -func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { - for idx := range items { - if items[idx].ProfileID == profileID { - return &items[idx] - } - } - return nil -} - -func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem { - for idx := range items { - if items[idx].ProfileID == activeProfileID { - return &items[idx] - } - } - if len(items) == 0 { - return nil - } - return &items[0] -} - -func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { - for idx := range items { - if items[idx].ProfileID == profileID { - return idx - } - } - return -1 -} - -func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { - return findSoraS3ProfileIndex(items, profileID) >= 0 -} - -func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { - if settings == nil { - return true - } - if settings.Enabled { - return false - } - if strings.TrimSpace(settings.Endpoint) != "" { - return false - } - if strings.TrimSpace(settings.Region) != "" { - return false - } - if strings.TrimSpace(settings.Bucket) != "" { - return false - } - if strings.TrimSpace(settings.AccessKeyID) != "" { - return false - } - if settings.SecretAccessKey != "" { - return false - } - if strings.TrimSpace(settings.Prefix) != "" { - return false - } - if strings.TrimSpace(settings.CDNURL) != "" { - return false - } - return settings.DefaultStorageQuotaBytes == 0 -} - -func maxInt64(value int64, min int64) int64 { - if value < min { - return min - } - return value -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 411939bb62..4b64267fec 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -41,7 +41,6 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string - SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints @@ -107,7 +106,6 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string - SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints @@ -116,46 +114,6 @@ type PublicSettings struct { Version string } -// SoraS3Settings Sora S3 存储配置 -type SoraS3Settings struct { - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` -} - -// SoraS3Profile Sora S3 多配置项(服务内部模型) -type SoraS3Profile struct { - ProfileID string `json:"profile_id"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - Enabled bool `json:"enabled"` - Endpoint string `json:"endpoint"` - Region string `json:"region"` - Bucket string `json:"bucket"` - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 - SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 - Prefix string `json:"prefix"` - ForcePathStyle bool `json:"force_path_style"` - CDNURL string `json:"cdn_url"` - DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` - UpdatedAt string `json:"updated_at"` -} - -// SoraS3ProfileList Sora S3 多配置列表 -type SoraS3ProfileList struct { - ActiveProfileID string `json:"active_profile_id"` - Items []SoraS3Profile `json:"items"` -} - // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go deleted file mode 100644 index eccc1acff7..0000000000 --- a/backend/internal/service/sora_account_service.go +++ /dev/null @@ -1,40 +0,0 @@ -package service - -import "context" - -// SoraAccountRepository Sora 账号扩展表仓储接口 -// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 -// -// 设计说明: -// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 -// - Sora gateway 优先读取此表的字段以获得更好的查询性能 -// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 -// - Token 刷新时需要同时更新两个表以保持数据一致性 -type SoraAccountRepository interface { - // Upsert 创建或更新 Sora 账号扩展信息 - // accountID: 关联的 accounts.id - // updates: 要更新的字段,支持 access_token、refresh_token、session_token - // - // 如果记录不存在则创建,存在则更新。 - // 用于: - // 1. 创建 Sora 账号时初始化扩展表 - // 2. Token 刷新时同步更新扩展表 - Upsert(ctx context.Context, accountID int64, updates map[string]any) error - - // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 - // 返回 nil, nil 表示记录不存在(非错误) - GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) - - // Delete 删除 Sora 账号扩展信息 - // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 - Delete(ctx context.Context, accountID int64) error -} - -// SoraAccount Sora 账号扩展信息 -// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 -type SoraAccount struct { - AccountID int64 // 关联的 accounts.id - AccessToken string // OAuth access_token - RefreshToken string // OAuth refresh_token - SessionToken string // Session token(可选,用于 ST→AT 兜底) -} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go deleted file mode 100644 index 0a914d2db4..0000000000 --- a/backend/internal/service/sora_client.go +++ /dev/null @@ -1,117 +0,0 @@ -package service - -import ( - "context" - "fmt" - "net/http" -) - -// SoraClient 定义直连 Sora 的任务操作接口。 -type SoraClient interface { - Enabled() bool - UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) - CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) - CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) - CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) - UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) - GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) - DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) - UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) - FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) - SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error - DeleteCharacter(ctx context.Context, account *Account, characterID string) error - PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) - DeletePost(ctx context.Context, account *Account, postID string) error - GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) - EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) - GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) - GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) -} - -// SoraImageRequest 图片生成请求参数 -type SoraImageRequest struct { - Prompt string - Width int - Height int - MediaID string -} - -// SoraVideoRequest 视频生成请求参数 -type SoraVideoRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - VideoCount int - MediaID string - RemixTargetID string - CameoIDs []string -} - -// SoraStoryboardRequest 分镜视频生成请求参数 -type SoraStoryboardRequest struct { - Prompt string - Orientation string - Frames int - Model string - Size string - MediaID string -} - -// SoraImageTaskStatus 图片任务状态 -type SoraImageTaskStatus struct { - ID string - Status string - ProgressPct float64 - URLs []string - ErrorMsg string -} - -// SoraVideoTaskStatus 视频任务状态 -type SoraVideoTaskStatus struct { - ID string - Status string - ProgressPct int - URLs []string - GenerationID string - ErrorMsg string -} - -// SoraCameoStatus 角色处理中间态 -type SoraCameoStatus struct { - Status string - StatusMessage string - DisplayNameHint string - UsernameHint string - ProfileAssetURL string - InstructionSetHint any - InstructionSet any -} - -// SoraCharacterFinalizeRequest 角色定稿请求参数 -type SoraCharacterFinalizeRequest struct { - CameoID string - Username string - DisplayName string - ProfileAssetPointer string - InstructionSet any -} - -// SoraUpstreamError 上游错误 -type SoraUpstreamError struct { - StatusCode int - Message string - Headers http.Header - Body []byte -} - -func (e *SoraUpstreamError) Error() string { - if e == nil { - return "sora upstream error" - } - if e.Message != "" { - return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) - } - return fmt.Sprintf("sora upstream error: %d", e.StatusCode) -} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go deleted file mode 100644 index e9d325f452..0000000000 --- a/backend/internal/service/sora_gateway_service.go +++ /dev/null @@ -1,1559 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "math" - "math/rand" - "mime" - "net" - "net/http" - "net/url" - "regexp" - "strconv" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/gin-gonic/gin" -) - -const soraImageInputMaxBytes = 20 << 20 -const soraImageInputMaxRedirects = 3 -const soraImageInputTimeout = 20 * time.Second -const soraVideoInputMaxBytes = 200 << 20 -const soraVideoInputMaxRedirects = 3 -const soraVideoInputTimeout = 60 * time.Second - -var soraImageSizeMap = map[string]string{ - "gpt-image": "360", - "gpt-image-landscape": "540", - "gpt-image-portrait": "540", -} - -var soraBlockedHostnames = map[string]struct{}{ - "localhost": {}, - "localhost.localdomain": {}, - "metadata.google.internal": {}, - "metadata.google.internal.": {}, -} - -var soraBlockedCIDRs = mustParseCIDRs([]string{ - "0.0.0.0/8", - "10.0.0.0/8", - "100.64.0.0/10", - "127.0.0.0/8", - "169.254.0.0/16", - "172.16.0.0/12", - "192.168.0.0/16", - "224.0.0.0/4", - "240.0.0.0/4", - "::/128", - "::1/128", - "fc00::/7", - "fe80::/10", -}) - -// SoraGatewayService handles forwarding requests to Sora upstream. -type SoraGatewayService struct { - soraClient SoraClient - rateLimitService *RateLimitService - httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 - cfg *config.Config -} - -type soraWatermarkOptions struct { - Enabled bool - ParseMethod string - ParseURL string - ParseToken string - FallbackOnFailure bool - DeletePost bool -} - -type soraCharacterOptions struct { - SetPublic bool - DeleteAfterGenerate bool -} - -type soraCharacterFlowResult struct { - CameoID string - CharacterID string - Username string - DisplayName string -} - -var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) -var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) -var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) -var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) - -type soraPreflightChecker interface { - PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error -} - -func NewSoraGatewayService( - soraClient SoraClient, - rateLimitService *RateLimitService, - httpUpstream HTTPUpstream, - cfg *config.Config, -) *SoraGatewayService { - return &SoraGatewayService{ - soraClient: soraClient, - rateLimitService: rateLimitService, - httpUpstream: httpUpstream, - cfg: cfg, - } -} - -func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { - startTime := time.Now() - - // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient - if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { - if s.httpUpstream == nil { - s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) - return nil, errors.New("httpUpstream not configured for sora apikey forwarding") - } - return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) - } - - if s.soraClient == nil || !s.soraClient.Enabled() { - if c != nil { - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "type": "api_error", - "message": "Sora 上游未配置", - }, - }) - } - return nil, errors.New("sora upstream not configured") - } - - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) - return nil, fmt.Errorf("parse request: %w", err) - } - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - if strings.TrimSpace(reqModel) == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) - return nil, errors.New("model is required") - } - originalModel := reqModel - - mappedModel := account.GetMappedModel(reqModel) - var upstreamModel string - if mappedModel != "" && mappedModel != reqModel { - reqModel = mappedModel - upstreamModel = mappedModel - } - - modelCfg, ok := GetSoraModelConfig(reqModel) - if !ok { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) - return nil, fmt.Errorf("unsupported model: %s", reqModel) - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) - prompt = strings.TrimSpace(prompt) - imageInput = strings.TrimSpace(imageInput) - videoInput = strings.TrimSpace(videoInput) - remixTargetID = strings.TrimSpace(remixTargetID) - - if videoInput != "" && modelCfg.Type != "video" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) - return nil, errors.New("video input only supports video models") - } - if videoInput != "" && imageInput != "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) - return nil, errors.New("image input and video input cannot be used together") - } - characterOnly := videoInput != "" && prompt == "" - if modelCfg.Type == "prompt_enhance" && prompt == "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) - return nil, errors.New("prompt is required") - } - - reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) - if cancel != nil { - defer cancel() - } - if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { - if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - } - - if modelCfg.Type == "prompt_enhance" { - enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - content := strings.TrimSpace(enhancedPrompt) - if content == "" { - content = prompt - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - - characterOpts := parseSoraCharacterOptions(reqBody) - watermarkOpts := parseSoraWatermarkOptions(reqBody) - var characterResult *soraCharacterFlowResult - if videoInput != "" { - videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) - if videoErr != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) - return nil, videoErr - } - characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) - if videoErr != nil { - return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) - } - if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { - characterID := strings.TrimSpace(characterResult.CharacterID) - defer func() { - cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) - defer cancelCleanup() - if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { - log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) - } - }() - } - if characterOnly { - content := "角色创建成功" - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) - } - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - resp := buildSoraNonStreamResponse(content, reqModel) - if characterResult != nil { - resp["character_id"] = characterResult.CharacterID - resp["cameo_id"] = characterResult.CameoID - resp["character_username"] = characterResult.Username - resp["character_display_name"] = characterResult.DisplayName - } - c.JSON(http.StatusOK, resp) - } - return &ForwardResult{ - RequestID: "", - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", - }, nil - } - if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { - prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) - } - } - - var imageData []byte - imageFilename := "" - if imageInput != "" { - decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) - if err != nil { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) - return nil, err - } - imageData = decoded - imageFilename = filename - } - - mediaID := "" - if len(imageData) > 0 { - uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - mediaID = uploadID - } - - taskID := "" - var err error - videoCount := parseSoraVideoCount(reqBody) - switch modelCfg.Type { - case "image": - taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ - Prompt: prompt, - Width: modelCfg.Width, - Height: modelCfg.Height, - MediaID: mediaID, - }) - case "video": - if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { - taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ - Prompt: formatSoraStoryboardPrompt(prompt), - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - MediaID: mediaID, - }) - } else { - taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ - Prompt: prompt, - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - VideoCount: videoCount, - MediaID: mediaID, - RemixTargetID: remixTargetID, - CameoIDs: extractSoraCameoIDs(reqBody), - }) - } - default: - err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) - } - if err != nil { - return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) - } - - if clientStream && c != nil { - s.prepareSoraStream(c, taskID) - } - - var mediaURLs []string - videoGenerationID := "" - mediaType := modelCfg.Type - imageCount := 0 - imageSize := "" - switch modelCfg.Type { - case "image": - urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - mediaURLs = urls - imageCount = len(urls) - imageSize = soraImageSizeFromModel(reqModel) - case "video": - videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) - if pollErr != nil { - return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) - } - if videoStatus != nil { - mediaURLs = videoStatus.URLs - videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) - } - default: - mediaType = "prompt" - } - - watermarkPostID := "" - if modelCfg.Type == "video" && watermarkOpts.Enabled { - watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) - if watermarkErr != nil { - if !watermarkOpts.FallbackOnFailure { - return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) - } - log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) - } else if strings.TrimSpace(watermarkURL) != "" { - mediaURLs = []string{strings.TrimSpace(watermarkURL)} - watermarkPostID = strings.TrimSpace(postID) - } - } - - // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 - // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 - finalURLs := s.normalizeSoraMediaURLs(mediaURLs) - if watermarkPostID != "" && watermarkOpts.DeletePost { - if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { - log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) - } - } - - content := buildSoraContent(mediaType, finalURLs) - var firstTokenMs *int - if clientStream { - ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) - if streamErr != nil { - return nil, streamErr - } - firstTokenMs = ms - } else if c != nil { - response := buildSoraNonStreamResponse(content, reqModel) - if len(finalURLs) > 0 { - response["media_url"] = finalURLs[0] - if len(finalURLs) > 1 { - response["media_urls"] = finalURLs - } - } - c.JSON(http.StatusOK, response) - } - - return &ForwardResult{ - RequestID: taskID, - Model: originalModel, - UpstreamModel: upstreamModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: imageSize, - }, nil -} - -func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { - if s == nil || s.cfg == nil { - return ctx, nil - } - timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds - if stream { - timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds - } - if timeoutSeconds <= 0 { - return ctx, nil - } - return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) -} - -func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { - opts := soraWatermarkOptions{ - Enabled: parseBoolWithDefault(body, "watermark_free", false), - ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), - ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), - ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), - FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), - DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), - } - if opts.ParseMethod == "" { - opts.ParseMethod = "third_party" - } - return opts -} - -func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { - return soraCharacterOptions{ - SetPublic: parseBoolWithDefault(body, "character_set_public", true), - DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), - } -} - -func parseSoraVideoCount(body map[string]any) int { - if body == nil { - return 1 - } - keys := []string{"video_count", "videos", "n_variants"} - for _, key := range keys { - count := parseIntWithDefault(body, key, 0) - if count > 0 { - return clampInt(count, 1, 3) - } - } - return 1 -} - -func parseBoolWithDefault(body map[string]any, key string, def bool) bool { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case bool: - return typed - case int: - return typed != 0 - case int32: - return typed != 0 - case int64: - return typed != 0 - case float64: - return typed != 0 - case string: - typed = strings.ToLower(strings.TrimSpace(typed)) - if typed == "true" || typed == "1" || typed == "yes" { - return true - } - if typed == "false" || typed == "0" || typed == "no" { - return false - } - } - return def -} - -func parseStringWithDefault(body map[string]any, key, def string) string { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - if str, ok := val.(string); ok { - return str - } - return def -} - -func parseIntWithDefault(body map[string]any, key string, def int) int { - if body == nil { - return def - } - val, ok := body[key] - if !ok { - return def - } - switch typed := val.(type) { - case int: - return typed - case int32: - return int(typed) - case int64: - return int(typed) - case float64: - return int(typed) - case string: - parsed, err := strconv.Atoi(strings.TrimSpace(typed)) - if err == nil { - return parsed - } - } - return def -} - -func clampInt(v, minVal, maxVal int) int { - if v < minVal { - return minVal - } - if v > maxVal { - return maxVal - } - return v -} - -func extractSoraCameoIDs(body map[string]any) []string { - if body == nil { - return nil - } - raw, ok := body["cameo_ids"] - if !ok { - return nil - } - switch typed := raw.(type) { - case []string: - out := make([]string, 0, len(typed)) - for _, item := range typed { - item = strings.TrimSpace(item) - if item != "" { - out = append(out, item) - } - } - return out - case []any: - out := make([]string, 0, len(typed)) - for _, item := range typed { - str, ok := item.(string) - if !ok { - continue - } - str = strings.TrimSpace(str) - if str != "" { - out = append(out, str) - } - } - return out - default: - return nil - } -} - -func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { - cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) - if err != nil { - return nil, err - } - - cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) - if err != nil { - return nil, err - } - username := processSoraCharacterUsername(cameoStatus.UsernameHint) - displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) - if displayName == "" { - displayName = "Character" - } - profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) - if profileAssetURL == "" { - return nil, errors.New("profile asset url not found in cameo status") - } - - avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) - if err != nil { - return nil, err - } - assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) - if err != nil { - return nil, err - } - instructionSet := cameoStatus.InstructionSetHint - if instructionSet == nil { - instructionSet = cameoStatus.InstructionSet - } - - characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ - CameoID: strings.TrimSpace(cameoID), - Username: username, - DisplayName: displayName, - ProfileAssetPointer: assetPointer, - InstructionSet: instructionSet, - }) - if err != nil { - return nil, err - } - - if opts.SetPublic { - if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { - return nil, err - } - } - - return &soraCharacterFlowResult{ - CameoID: strings.TrimSpace(cameoID), - CharacterID: strings.TrimSpace(characterID), - Username: strings.TrimSpace(username), - DisplayName: displayName, - }, nil -} - -func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - timeout := 10 * time.Minute - interval := 5 * time.Second - maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) - if maxAttempts < 1 { - maxAttempts = 1 - } - - var lastErr error - consecutiveErrors := 0 - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) - if err != nil { - lastErr = err - consecutiveErrors++ - if consecutiveErrors >= 3 { - break - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - consecutiveErrors = 0 - if status == nil { - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - continue - } - currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) - statusMessage := strings.TrimSpace(status.StatusMessage) - if currentStatus == "failed" { - if statusMessage == "" { - statusMessage = "character creation failed" - } - return nil, errors.New(statusMessage) - } - if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { - return status, nil - } - if attempt < maxAttempts-1 { - if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { - return nil, sleepErr - } - } - } - if lastErr != nil { - return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) - } - return nil, errors.New("cameo processing timeout") -} - -func processSoraCharacterUsername(usernameHint string) string { - usernameHint = strings.TrimSpace(usernameHint) - if usernameHint == "" { - usernameHint = "character" - } - if strings.Contains(usernameHint, ".") { - parts := strings.Split(usernameHint, ".") - usernameHint = strings.TrimSpace(parts[len(parts)-1]) - } - if usernameHint == "" { - usernameHint = "character" - } - return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) -} - -func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { - generationID = strings.TrimSpace(generationID) - if generationID == "" { - return "", "", errors.New("generation id is required for watermark-free mode") - } - postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) - if err != nil { - return "", "", err - } - postID = strings.TrimSpace(postID) - if postID == "" { - return "", "", errors.New("watermark-free publish returned empty post id") - } - - switch opts.ParseMethod { - case "custom": - urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) - if parseErr != nil { - return "", postID, parseErr - } - return strings.TrimSpace(urlVal), postID, nil - case "", "third_party": - return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil - default: - return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) - } -} - -func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { - switch statusCode { - case 401, 402, 403, 404, 429, 529: - return true - default: - return statusCode >= 500 - } -} - -func buildSoraNonStreamResponse(content, model string) map[string]any { - return map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "message": map[string]any{ - "role": "assistant", - "content": content, - }, - "finish_reason": "stop", - }, - }, - } -} - -func soraImageSizeFromModel(model string) string { - modelLower := strings.ToLower(model) - if size, ok := soraImageSizeMap[modelLower]; ok { - return size - } - if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") { - return "540" - } - return "360" -} - -func soraProErrorMessage(model, upstreamMsg string) string { - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "sora2pro-hd") { - return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" - } - if strings.Contains(modelLower, "sora2pro") { - return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" - } - return "" -} - -func firstMediaURL(urls []string) string { - if len(urls) == 0 { - return "" - } - return urls[0] -} - -func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { - if path == "" { - return path - } - prefix := "/sora/media" - values := url.Values{} - if rawQuery != "" { - if parsed, err := url.ParseQuery(rawQuery); err == nil { - values = parsed - } - } - - signKey := "" - ttlSeconds := 0 - if s != nil && s.cfg != nil { - signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) - ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds - } - values.Del("sig") - values.Del("expires") - signingQuery := values.Encode() - if signKey != "" && ttlSeconds > 0 { - expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() - signature := SignSoraMediaURL(path, signingQuery, expires, signKey) - if signature != "" { - values.Set("expires", strconv.FormatInt(expires, 10)) - values.Set("sig", signature) - prefix = "/sora/media-signed" - } - } - - encoded := values.Encode() - if encoded == "" { - return prefix + path - } - return prefix + path + "?" + encoded -} - -func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { - if c == nil { - return - } - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - if strings.TrimSpace(requestID) != "" { - c.Header("x-request-id", requestID) - } -} - -func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { - if c == nil { - return nil, nil - } - writer := c.Writer - flusher, _ := writer.(http.Flusher) - - chunk := map[string]any{ - "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{ - "content": content, - }, - }, - }, - } - encoded, _ := jsonMarshalRaw(chunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { - return nil, err - } - if flusher != nil { - flusher.Flush() - } - ms := int(time.Since(startTime).Milliseconds()) - finalChunk := map[string]any{ - "id": chunk["id"], - "object": "chat.completion.chunk", - "created": time.Now().Unix(), - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": map[string]any{}, - "finish_reason": "stop", - }, - }, - } - finalEncoded, _ := jsonMarshalRaw(finalChunk) - if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { - return &ms, err - } - if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { - return &ms, err - } - if flusher != nil { - flusher.Flush() - } - return &ms, nil -} - -func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { - if c == nil { - return - } - if stream { - flusher, _ := c.Writer.(http.Flusher) - errorData := map[string]any{ - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) - _, _ = fmt.Fprint(c.Writer, errorEvent) - _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") - if flusher != nil { - flusher.Flush() - } - return - } - c.JSON(status, gin.H{ - "error": gin.H{ - "type": errType, - "message": message, - }, - }) -} - -func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { - if err == nil { - return nil - } - var upstreamErr *SoraUpstreamError - if errors.As(err, &upstreamErr) { - accountID := int64(0) - if account != nil { - accountID = account.ID - } - logger.LegacyPrintf( - "service.sora", - "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", - accountID, - model, - upstreamErr.StatusCode, - strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), - strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), - strings.TrimSpace(upstreamErr.Message), - truncateForLog(upstreamErr.Body, 1024), - ) - if s.rateLimitService != nil && account != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) - } - if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - var responseHeaders http.Header - if upstreamErr.Headers != nil { - responseHeaders = upstreamErr.Headers.Clone() - } - return &UpstreamFailoverError{ - StatusCode: upstreamErr.StatusCode, - ResponseBody: upstreamErr.Body, - ResponseHeaders: responseHeaders, - } - } - msg := upstreamErr.Message - if override := soraProErrorMessage(model, msg); override != "" { - msg = override - } - s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) - return err - } - if errors.Is(err, context.DeadlineExceeded) { - s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) - return err - } - s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) - return err -} - -func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetImageTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "succeeded", "completed": - return status.URLs, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora image generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora image generation timeout") -} - -func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { - interval := s.pollInterval() - maxAttempts := s.pollMaxAttempts() - lastPing := time.Now() - for attempt := 0; attempt < maxAttempts; attempt++ { - status, err := s.soraClient.GetVideoTask(ctx, account, taskID) - if err != nil { - return nil, err - } - switch strings.ToLower(status.Status) { - case "completed", "succeeded": - return status, nil - case "failed": - if status.ErrorMsg != "" { - return nil, errors.New(status.ErrorMsg) - } - return nil, errors.New("sora video generation failed") - } - if stream { - s.maybeSendPing(c, &lastPing) - } - if err := sleepWithContext(ctx, interval); err != nil { - return nil, err - } - } - return nil, errors.New("sora video generation timeout") -} - -func (s *SoraGatewayService) pollInterval() time.Duration { - if s == nil || s.cfg == nil { - return 2 * time.Second - } - interval := s.cfg.Sora.Client.PollIntervalSeconds - if interval <= 0 { - interval = 2 - } - return time.Duration(interval) * time.Second -} - -func (s *SoraGatewayService) pollMaxAttempts() int { - if s == nil || s.cfg == nil { - return 600 - } - maxAttempts := s.cfg.Sora.Client.MaxPollAttempts - if maxAttempts <= 0 { - maxAttempts = 600 - } - return maxAttempts -} - -func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) { - if c == nil { - return - } - interval := 10 * time.Second - if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 { - interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second - } - if time.Since(*lastPing) < interval { - return - } - if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil { - if flusher, ok := c.Writer.(http.Flusher); ok { - flusher.Flush() - } - *lastPing = time.Now() - } -} - -func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string { - if len(urls) == 0 { - return urls - } - output := make([]string, 0, len(urls)) - for _, raw := range urls { - raw = strings.TrimSpace(raw) - if raw == "" { - continue - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - output = append(output, raw) - continue - } - pathVal := raw - if !strings.HasPrefix(pathVal, "/") { - pathVal = "/" + pathVal - } - output = append(output, s.buildSoraMediaURL(pathVal, "")) - } - return output -} - -// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符, -// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。 -func jsonMarshalRaw(v any) ([]byte, error) { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(false) - if err := enc.Encode(v); err != nil { - return nil, err - } - // Encode 会追加换行符,去掉它 - b := buf.Bytes() - if len(b) > 0 && b[len(b)-1] == '\n' { - b = b[:len(b)-1] - } - return b, nil -} - -func buildSoraContent(mediaType string, urls []string) string { - switch mediaType { - case "image": - parts := make([]string, 0, len(urls)) - for _, u := range urls { - parts = append(parts, fmt.Sprintf("![image](%s)", u)) - } - return strings.Join(parts, "\n") - case "video": - if len(urls) == 0 { - return "" - } - return fmt.Sprintf("```html\n\n```", urls[0]) - default: - return "" - } -} - -func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) { - if body == nil { - return "", "", "", "" - } - if v, ok := body["remix_target_id"].(string); ok { - remixTargetID = strings.TrimSpace(v) - } - if v, ok := body["image"].(string); ok { - imageInput = v - } - if v, ok := body["video"].(string); ok { - videoInput = v - } - if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" { - prompt = v - } - if messages, ok := body["messages"].([]any); ok { - builder := strings.Builder{} - for _, raw := range messages { - msg, ok := raw.(map[string]any) - if !ok { - continue - } - role, _ := msg["role"].(string) - if role != "" && role != "user" { - continue - } - content := msg["content"] - text, img, vid := parseSoraMessageContent(content) - if text != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(text) - } - if imageInput == "" && img != "" { - imageInput = img - } - if videoInput == "" && vid != "" { - videoInput = vid - } - } - if prompt == "" { - prompt = builder.String() - } - } - if remixTargetID == "" { - remixTargetID = extractRemixTargetIDFromPrompt(prompt) - } - prompt = cleanRemixLinkFromPrompt(prompt) - return prompt, imageInput, videoInput, remixTargetID -} - -func parseSoraMessageContent(content any) (text, imageInput, videoInput string) { - switch val := content.(type) { - case string: - return val, "", "" - case []any: - builder := strings.Builder{} - for _, item := range val { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := itemMap["type"].(string) - switch t { - case "text": - if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" { - if builder.Len() > 0 { - _, _ = builder.WriteString("\n") - } - _, _ = builder.WriteString(txt) - } - case "image_url": - if imageInput == "" { - if urlVal, ok := itemMap["image_url"].(map[string]any); ok { - imageInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["image_url"].(string); ok { - imageInput = urlStr - } - } - case "video_url": - if videoInput == "" { - if urlVal, ok := itemMap["video_url"].(map[string]any); ok { - videoInput = fmt.Sprintf("%v", urlVal["url"]) - } else if urlStr, ok := itemMap["video_url"].(string); ok { - videoInput = urlStr - } - } - } - } - return builder.String(), imageInput, videoInput - default: - return "", "", "" - } -} - -func isSoraStoryboardPrompt(prompt string) bool { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return false - } - return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 -} - -func formatSoraStoryboardPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) - if len(matches) == 0 { - return prompt - } - firstBracketPos := strings.Index(prompt, "[") - instructions := "" - if firstBracketPos > 0 { - instructions = strings.TrimSpace(prompt[:firstBracketPos]) - } - shots := make([]string, 0, len(matches)) - for i, match := range matches { - if len(match) < 3 { - continue - } - duration := strings.TrimSpace(match[1]) - scene := strings.TrimSpace(match[2]) - if scene == "" { - continue - } - shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) - } - if len(shots) == 0 { - return prompt - } - timeline := strings.Join(shots, "\n\n") - if instructions == "" { - return timeline - } - return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) -} - -func extractRemixTargetIDFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return "" - } - return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) -} - -func cleanRemixLinkFromPrompt(prompt string) string { - prompt = strings.TrimSpace(prompt) - if prompt == "" { - return prompt - } - cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") - cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") - cleaned = strings.Join(strings.Fields(cleaned), " ") - return strings.TrimSpace(cleaned) -} - -func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, "", errors.New("empty image input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, "", errors.New("invalid data url") - } - meta := parts[0] - payload := parts[1] - decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) - if err != nil { - return nil, "", err - } - ext := "" - if strings.HasPrefix(meta, "data:") { - metaParts := strings.SplitN(meta[5:], ";", 2) - if len(metaParts) > 0 { - if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { - ext = exts[0] - } - } - } - filename := "image" + ext - return decoded, filename, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraImageInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) - if err != nil { - return nil, "", errors.New("invalid base64 image") - } - return decoded, "image.png", nil -} - -func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { - raw := strings.TrimSpace(input) - if raw == "" { - return nil, errors.New("empty video input") - } - if strings.HasPrefix(raw, "data:") { - parts := strings.SplitN(raw, ",", 2) - if len(parts) != 2 { - return nil, errors.New("invalid video data url") - } - decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil - } - if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { - return downloadSoraVideoInput(ctx, raw) - } - decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) - if err != nil { - return nil, errors.New("invalid base64 video") - } - if len(decoded) == 0 { - return nil, errors.New("empty video data") - } - return decoded, nil -} - -func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, "", err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, "", err - } - client := &http.Client{ - Timeout: soraImageInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraImageInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, "", err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) - if err != nil { - return nil, "", err - } - ext := fileExtFromURL(parsed.String()) - if ext == "" { - ext = fileExtFromContentType(resp.Header.Get("Content-Type")) - } - filename := "image" + ext - return data, filename, nil -} - -func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { - parsed, err := validateSoraRemoteURL(rawURL) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) - if err != nil { - return nil, err - } - client := &http.Client{ - Timeout: soraVideoInputTimeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= soraVideoInputMaxRedirects { - return errors.New("too many redirects") - } - return validateSoraRemoteURLValue(req.URL) - }, - } - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) - } - data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, errors.New("empty video content") - } - return data, nil -} - -func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { - if maxBytes <= 0 { - return nil, errors.New("invalid max bytes limit") - } - decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) - limited := io.LimitReader(decoder, maxBytes+1) - data, err := io.ReadAll(limited) - if err != nil { - return nil, err - } - if int64(len(data)) > maxBytes { - return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) - } - return data, nil -} - -func validateSoraRemoteURL(raw string) (*url.URL, error) { - if strings.TrimSpace(raw) == "" { - return nil, errors.New("empty remote url") - } - parsed, err := url.Parse(raw) - if err != nil { - return nil, fmt.Errorf("invalid remote url: %w", err) - } - if err := validateSoraRemoteURLValue(parsed); err != nil { - return nil, err - } - return parsed, nil -} - -func validateSoraRemoteURLValue(parsed *url.URL) error { - if parsed == nil { - return errors.New("invalid remote url") - } - scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) - if scheme != "http" && scheme != "https" { - return errors.New("only http/https remote url is allowed") - } - if parsed.User != nil { - return errors.New("remote url cannot contain userinfo") - } - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - if host == "" { - return errors.New("remote url missing host") - } - if _, blocked := soraBlockedHostnames[host]; blocked { - return errors.New("remote url is not allowed") - } - if ip := net.ParseIP(host); ip != nil { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - return nil - } - ips, err := net.LookupIP(host) - if err != nil { - return fmt.Errorf("resolve remote url failed: %w", err) - } - for _, ip := range ips { - if isSoraBlockedIP(ip) { - return errors.New("remote url is not allowed") - } - } - return nil -} - -func isSoraBlockedIP(ip net.IP) bool { - if ip == nil { - return true - } - for _, cidr := range soraBlockedCIDRs { - if cidr.Contains(ip) { - return true - } - } - return false -} - -func mustParseCIDRs(values []string) []*net.IPNet { - out := make([]*net.IPNet, 0, len(values)) - for _, val := range values { - _, cidr, err := net.ParseCIDR(val) - if err != nil { - continue - } - out = append(out, cidr) - } - return out -} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go deleted file mode 100644 index 2fef600c53..0000000000 --- a/backend/internal/service/sora_gateway_service_test.go +++ /dev/null @@ -1,564 +0,0 @@ -//go:build unit - -package service - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" -) - -var _ SoraClient = (*stubSoraClientForPoll)(nil) - -type stubSoraClientForPoll struct { - imageStatus *SoraImageTaskStatus - videoStatus *SoraVideoTaskStatus - imageCalls int - videoCalls int - enhanced string - enhanceErr error - storyboard bool - videoReq SoraVideoRequest - parseErr error - postCalls int - deleteCalls int -} - -func (s *stubSoraClientForPoll) Enabled() bool { return true } -func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { - return "", nil -} -func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { - return "task-image", nil -} -func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { - s.videoReq = req - return "task-video", nil -} -func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { - s.storyboard = true - return "task-video", nil -} -func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { - return "cameo-1", nil -} -func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { - return &SoraCameoStatus{ - Status: "finalized", - StatusMessage: "Completed", - DisplayNameHint: "Character", - UsernameHint: "user.character", - ProfileAssetURL: "https://example.com/avatar.webp", - }, nil -} -func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { - return []byte("avatar"), nil -} -func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { - return "asset-pointer", nil -} -func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { - return "character-1", nil -} -func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { - return nil -} -func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { - return nil -} -func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { - s.postCalls++ - return "s_post", nil -} -func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { - s.deleteCalls++ - return nil -} -func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { - if s.parseErr != nil { - return "", s.parseErr - } - return "https://example.com/no-watermark.mp4", nil -} -func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { - if s.enhanced != "" { - return s.enhanced, s.enhanceErr - } - return "enhanced prompt", s.enhanceErr -} -func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { - s.imageCalls++ - return s.imageStatus, nil -} -func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { - s.videoCalls++ - return s.videoStatus, nil -} - -func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { - client := &stubSoraClientForPoll{ - imageStatus: &SoraImageTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/a.png"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) - require.NoError(t, err) - require.Equal(t, []string{"https://example.com/a.png"}, urls) - require.Equal(t, 1, client.imageCalls) -} - -func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { - client := &stubSoraClientForPoll{ - enhanced: "cinematic prompt", - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ - ID: 1, - Platform: PlatformSora, - Status: StatusActive, - Credentials: map[string]any{ - "model_mapping": map[string]any{ - "prompt-enhance-short-10s": "prompt-enhance-short-15s", - }, - }, - } - body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, "prompt-enhance-short-10s", result.Model) - require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) -} - -func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.True(t, client.storyboard) -} - -func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/v.mp4"}, - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 3, client.videoReq.VideoCount) -} - -func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { - client := &stubSoraClientForPoll{} - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "prompt", result.MediaType) - require.Equal(t, 0, client.videoCalls) -} - -func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - parseErr: errors.New("parse failed"), - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/original.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 0, client.deleteCalls) -} - -func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "completed", - URLs: []string{"https://example.com/original.mp4"}, - GenerationID: "gen_1", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - svc := NewSoraGatewayService(client, nil, nil, cfg) - account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} - body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) - - result, err := svc.Forward(context.Background(), nil, account, body, false) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) - require.Equal(t, 1, client.postCalls) - require.Equal(t, 1, client.deleteCalls) -} - -func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { - client := &stubSoraClientForPoll{ - videoStatus: &SoraVideoTaskStatus{ - Status: "failed", - ErrorMsg: "reject", - }, - } - cfg := &config.Config{ - Sora: config.SoraConfig{ - Client: config.SoraClientConfig{ - PollIntervalSeconds: 1, - MaxPollAttempts: 1, - }, - }, - } - service := NewSoraGatewayService(client, nil, nil, cfg) - - status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) - require.Error(t, err) - require.Nil(t, status) - require.Contains(t, err.Error(), "reject") - require.Equal(t, 1, client.videoCalls) -} - -func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { - cfg := &config.Config{ - Gateway: config.GatewayConfig{ - SoraMediaSigningKey: "test-key", - SoraMediaSignedURLTTLSeconds: 600, - }, - } - service := NewSoraGatewayService(nil, nil, nil, cfg) - - url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") - require.Contains(t, url, "/sora/media-signed") - require.Contains(t, url, "expires=") - require.Contains(t, url, "sig=") -} - -func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - result := svc.normalizeSoraMediaURLs(nil) - require.Empty(t, result) - - result = svc.normalizeSoraMediaURLs([]string{}) - require.Empty(t, result) -} - -func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Equal(t, urls, result) -} - -func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { - cfg := &config.Config{} - svc := NewSoraGatewayService(nil, nil, nil, cfg) - urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) - require.Contains(t, result[0], "/sora/media") - require.Contains(t, result[1], "/sora/media") -} - -func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { - svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) - urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} - result := svc.normalizeSoraMediaURLs(urls) - require.Len(t, result, 2) -} - -func TestBuildSoraContent_Image(t *testing.T) { - content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) - require.Contains(t, content, "![image](https://a.com/1.png)") - require.Contains(t, content, "![image](https://a.com/2.png)") -} - -func TestBuildSoraContent_Video(t *testing.T) { - content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) - require.Contains(t, content, "