diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 513b7996db..fdc5c6acdb 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -102,12 +102,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.NewSchedulerCache(redisClient)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
- soraAccountRepository := repository.NewSoraAccountRepository(db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -184,12 +183,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
- soraS3Storage := service.NewSoraS3Storage(settingService)
- settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
- soraGenerationRepository := repository.NewSoraGenerationRepository(db)
- soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
- soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -223,16 +217,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
- soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
- soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
- soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
- soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
- soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -243,12 +232,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -283,7 +271,6 @@ func provideCleanup(
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
- soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
@@ -331,12 +318,6 @@ func provideCleanup(
}
return nil
}},
- {"SoraMediaCleanupService", func() error {
- if soraMediaCleanup != nil {
- soraMediaCleanup.Stop()
- }
- return nil
- }},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
index 9d2a54b98b..6e4561c994 100644
--- a/backend/cmd/server/wire_gen_test.go
+++ b/backend/cmd/server/wire_gen_test.go
@@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
&service.OpsCleanupService{},
&service.OpsScheduledReportService{},
opsSystemLogSinkSvc,
- &service.SoraMediaCleanupService{},
schedulerSnapshotSvc,
tokenRefreshSvc,
accountExpirySvc,
diff --git a/backend/ent/group.go b/backend/ent/group.go
index fc691a9b4c..b15ac15dd9 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -52,16 +52,6 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
- // SoraImagePrice360 holds the value of the "sora_image_price_360" field.
- SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
- // SoraImagePrice540 holds the value of the "sora_image_price_540" field.
- SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
- // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
- // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
- SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
- // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
@@ -196,9 +186,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool)
- case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
+ case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
- case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
+ case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
values[i] = new(sql.NullString)
@@ -335,40 +325,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
- case group.FieldSoraImagePrice360:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
- } else if value.Valid {
- _m.SoraImagePrice360 = new(float64)
- *_m.SoraImagePrice360 = value.Float64
- }
- case group.FieldSoraImagePrice540:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
- } else if value.Valid {
- _m.SoraImagePrice540 = new(float64)
- *_m.SoraImagePrice540 = value.Float64
- }
- case group.FieldSoraVideoPricePerRequest:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
- } else if value.Valid {
- _m.SoraVideoPricePerRequest = new(float64)
- *_m.SoraVideoPricePerRequest = value.Float64
- }
- case group.FieldSoraVideoPricePerRequestHd:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
- } else if value.Valid {
- _m.SoraVideoPricePerRequestHd = new(float64)
- *_m.SoraVideoPricePerRequestHd = value.Float64
- }
- case group.FieldSoraStorageQuotaBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
- } else if value.Valid {
- _m.SoraStorageQuotaBytes = value.Int64
- }
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -590,29 +546,6 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
- if v := _m.SoraImagePrice360; v != nil {
- builder.WriteString("sora_image_price_360=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraImagePrice540; v != nil {
- builder.WriteString("sora_image_price_540=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraVideoPricePerRequest; v != nil {
- builder.WriteString("sora_video_price_per_request=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraVideoPricePerRequestHd; v != nil {
- builder.WriteString("sora_video_price_per_request_hd=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- builder.WriteString("sora_storage_quota_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
- builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index 352221275b..21a7c2cb76 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -49,16 +49,6 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
- // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
- FieldSoraImagePrice360 = "sora_image_price_360"
- // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
- FieldSoraImagePrice540 = "sora_image_price_540"
- // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
- FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
- // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
- FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
- // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
- FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -175,11 +165,6 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
- FieldSoraImagePrice360,
- FieldSoraImagePrice540,
- FieldSoraVideoPricePerRequest,
- FieldSoraVideoPricePerRequestHd,
- FieldSoraStorageQuotaBytes,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
@@ -247,8 +232,6 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
- // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
- DefaultSoraStorageQuotaBytes int64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@@ -364,31 +347,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
-// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
-func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
-}
-
-// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
-func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
-}
-
-// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
-func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
-}
-
-// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
-func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
-}
-
-// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
-func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
-}
-
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index 41bd575a53..cba2ce5f0e 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
-// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
-func SoraImagePrice360(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
-func SoraImagePrice540(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
-func SoraVideoPricePerRequest(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
-func SoraVideoPricePerRequestHd(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
-func SoraStorageQuotaBytes(v int64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1070,246 +1045,6 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
-// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
-func SoraImagePrice360EQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
-func SoraImagePrice360In(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
-}
-
-// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
-}
-
-// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
-func SoraImagePrice360GT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
-func SoraImagePrice360GTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
-func SoraImagePrice360LT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
-func SoraImagePrice360LTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
-func SoraImagePrice360IsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
-}
-
-// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
-}
-
-// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
-func SoraImagePrice540EQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
-func SoraImagePrice540In(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
-}
-
-// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
-}
-
-// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
-func SoraImagePrice540GT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
-func SoraImagePrice540GTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
-func SoraImagePrice540LT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
-func SoraImagePrice540LTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
-func SoraImagePrice540IsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
-}
-
-// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
-}
-
-// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
-}
-
-// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
-}
-
-// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestGT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestLT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestIsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
-}
-
-// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
-}
-
-// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
-}
-
-// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
-}
-
-// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdIsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
-}
-
-// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
-}
-
-// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGT(v int64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLT(v int64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
-}
-
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index a635dfd999..a8c30b184d 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -258,76 +258,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
- _c.mutation.SetSoraImagePrice360(v)
- return _c
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraImagePrice360(*v)
- }
- return _c
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
- _c.mutation.SetSoraImagePrice540(v)
- return _c
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraImagePrice540(*v)
- }
- return _c
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
- _c.mutation.SetSoraVideoPricePerRequest(v)
- return _c
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraVideoPricePerRequest(*v)
- }
- return _c
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
- _c.mutation.SetSoraVideoPricePerRequestHd(v)
- return _c
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraVideoPricePerRequestHd(*v)
- }
- return _c
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
- _c.mutation.SetSoraStorageQuotaBytes(v)
- return _c
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
- if v != nil {
- _c.SetSoraStorageQuotaBytes(*v)
- }
- return _c
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
@@ -645,10 +575,6 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- v := group.DefaultSoraStorageQuotaBytes
- _c.mutation.SetSoraStorageQuotaBytes(v)
- }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
@@ -737,9 +663,6 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
- }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
@@ -867,26 +790,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
- if value, ok := _c.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- _node.SoraImagePrice360 = &value
- }
- if value, ok := _c.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- _node.SoraImagePrice540 = &value
- }
- if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- _node.SoraVideoPricePerRequest = &value
- }
- if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- _node.SoraVideoPricePerRequestHd = &value
- }
- if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- _node.SoraStorageQuotaBytes = value
- }
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
@@ -1379,120 +1282,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
- u.Set(group.FieldSoraImagePrice360, v)
- return u
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
- u.SetExcluded(group.FieldSoraImagePrice360)
- return u
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
- u.Add(group.FieldSoraImagePrice360, v)
- return u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
- u.SetNull(group.FieldSoraImagePrice360)
- return u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
- u.Set(group.FieldSoraImagePrice540, v)
- return u
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
- u.SetExcluded(group.FieldSoraImagePrice540)
- return u
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
- u.Add(group.FieldSoraImagePrice540, v)
- return u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
- u.SetNull(group.FieldSoraImagePrice540)
- return u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
- u.Set(group.FieldSoraVideoPricePerRequest, v)
- return u
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
- u.SetExcluded(group.FieldSoraVideoPricePerRequest)
- return u
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
- u.Add(group.FieldSoraVideoPricePerRequest, v)
- return u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
- u.SetNull(group.FieldSoraVideoPricePerRequest)
- return u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
- u.Set(group.FieldSoraVideoPricePerRequestHd, v)
- return u
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
- u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
- return u
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
- u.Add(group.FieldSoraVideoPricePerRequestHd, v)
- return u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
- u.SetNull(group.FieldSoraVideoPricePerRequestHd)
- return u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
- u.Set(group.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
- u.SetExcluded(group.FieldSoraStorageQuotaBytes)
- return u
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
- u.Add(group.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
@@ -2054,139 +1843,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice360(v)
- })
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice360(v)
- })
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice360()
- })
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice360()
- })
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice540(v)
- })
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice540(v)
- })
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice540()
- })
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice540()
- })
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequest(v)
- })
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequest(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequest()
- })
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequest()
- })
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequestHd(v)
- })
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequestHd(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequestHd()
- })
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequestHd()
- })
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2944,139 +2600,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice360(v)
- })
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice360(v)
- })
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice360()
- })
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice360()
- })
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice540(v)
- })
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice540(v)
- })
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice540()
- })
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice540()
- })
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequest(v)
- })
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequest(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequest()
- })
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequest()
- })
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequestHd(v)
- })
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequestHd(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequestHd()
- })
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequestHd()
- })
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index a9a4b9da80..aa1a83d421 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -355,135 +355,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
- _u.mutation.ResetSoraImagePrice360()
- _u.mutation.SetSoraImagePrice360(v)
- return _u
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraImagePrice360(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
-func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
- _u.mutation.AddSoraImagePrice360(v)
- return _u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
- _u.mutation.ClearSoraImagePrice360()
- return _u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
- _u.mutation.ResetSoraImagePrice540()
- _u.mutation.SetSoraImagePrice540(v)
- return _u
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraImagePrice540(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
-func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
- _u.mutation.AddSoraImagePrice540(v)
- return _u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
- _u.mutation.ClearSoraImagePrice540()
- return _u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
- _u.mutation.ResetSoraVideoPricePerRequest()
- _u.mutation.SetSoraVideoPricePerRequest(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraVideoPricePerRequest(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
- _u.mutation.AddSoraVideoPricePerRequest(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
- _u.mutation.ClearSoraVideoPricePerRequest()
- return _u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
- _u.mutation.ResetSoraVideoPricePerRequestHd()
- _u.mutation.SetSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraVideoPricePerRequestHd(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
- _u.mutation.AddSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
- _u.mutation.ClearSoraVideoPricePerRequestHd()
- return _u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
@@ -1082,48 +953,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
- if value, ok := _u.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
- _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice360Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
- _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice540Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestHdCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -1817,135 +1646,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraImagePrice360()
- _u.mutation.SetSoraImagePrice360(v)
- return _u
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraImagePrice360(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraImagePrice360(v)
- return _u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
- _u.mutation.ClearSoraImagePrice360()
- return _u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraImagePrice540()
- _u.mutation.SetSoraImagePrice540(v)
- return _u
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraImagePrice540(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraImagePrice540(v)
- return _u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
- _u.mutation.ClearSoraImagePrice540()
- return _u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraVideoPricePerRequest()
- _u.mutation.SetSoraVideoPricePerRequest(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraVideoPricePerRequest(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraVideoPricePerRequest(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
- _u.mutation.ClearSoraVideoPricePerRequest()
- return _u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraVideoPricePerRequestHd()
- _u.mutation.SetSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraVideoPricePerRequestHd(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
- _u.mutation.ClearSoraVideoPricePerRequestHd()
- return _u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
@@ -2574,48 +2274,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
- if value, ok := _u.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
- _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice360Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
- _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice540Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestHdCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index bdbb9fdddd..5400bf9319 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -395,11 +395,6 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
@@ -447,7 +442,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
- Columns: []*schema.Column{GroupsColumns[30]},
+ Columns: []*schema.Column{GroupsColumns[25]},
},
},
}
@@ -770,7 +765,6 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
- {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@@ -787,31 +781,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[38]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -820,32 +814,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[38]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_model",
@@ -865,17 +859,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
},
},
}
@@ -896,8 +890,6 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
- {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
- {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 28d9a0ef22..d206039af4 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -8230,16 +8230,6 @@ type GroupMutation struct {
addimage_price_2k *float64
image_price_4k *float64
addimage_price_4k *float64
- sora_image_price_360 *float64
- addsora_image_price_360 *float64
- sora_image_price_540 *float64
- addsora_image_price_540 *float64
- sora_video_price_per_request *float64
- addsora_video_price_per_request *float64
- sora_video_price_per_request_hd *float64
- addsora_video_price_per_request_hd *float64
- sora_storage_quota_bytes *int64
- addsora_storage_quota_bytes *int64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
@@ -9260,342 +9250,6 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k)
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (m *GroupMutation) SetSoraImagePrice360(f float64) {
- m.sora_image_price_360 = &f
- m.addsora_image_price_360 = nil
-}
-
-// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation.
-func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) {
- v := m.sora_image_price_360
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err)
- }
- return oldValue.SoraImagePrice360, nil
-}
-
-// AddSoraImagePrice360 adds f to the "sora_image_price_360" field.
-func (m *GroupMutation) AddSoraImagePrice360(f float64) {
- if m.addsora_image_price_360 != nil {
- *m.addsora_image_price_360 += f
- } else {
- m.addsora_image_price_360 = &f
- }
-}
-
-// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation.
-func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) {
- v := m.addsora_image_price_360
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (m *GroupMutation) ClearSoraImagePrice360() {
- m.sora_image_price_360 = nil
- m.addsora_image_price_360 = nil
- m.clearedFields[group.FieldSoraImagePrice360] = struct{}{}
-}
-
-// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation.
-func (m *GroupMutation) SoraImagePrice360Cleared() bool {
- _, ok := m.clearedFields[group.FieldSoraImagePrice360]
- return ok
-}
-
-// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field.
-func (m *GroupMutation) ResetSoraImagePrice360() {
- m.sora_image_price_360 = nil
- m.addsora_image_price_360 = nil
- delete(m.clearedFields, group.FieldSoraImagePrice360)
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (m *GroupMutation) SetSoraImagePrice540(f float64) {
- m.sora_image_price_540 = &f
- m.addsora_image_price_540 = nil
-}
-
-// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation.
-func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) {
- v := m.sora_image_price_540
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err)
- }
- return oldValue.SoraImagePrice540, nil
-}
-
-// AddSoraImagePrice540 adds f to the "sora_image_price_540" field.
-func (m *GroupMutation) AddSoraImagePrice540(f float64) {
- if m.addsora_image_price_540 != nil {
- *m.addsora_image_price_540 += f
- } else {
- m.addsora_image_price_540 = &f
- }
-}
-
-// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation.
-func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) {
- v := m.addsora_image_price_540
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (m *GroupMutation) ClearSoraImagePrice540() {
- m.sora_image_price_540 = nil
- m.addsora_image_price_540 = nil
- m.clearedFields[group.FieldSoraImagePrice540] = struct{}{}
-}
-
-// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation.
-func (m *GroupMutation) SoraImagePrice540Cleared() bool {
- _, ok := m.clearedFields[group.FieldSoraImagePrice540]
- return ok
-}
-
-// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field.
-func (m *GroupMutation) ResetSoraImagePrice540() {
- m.sora_image_price_540 = nil
- m.addsora_image_price_540 = nil
- delete(m.clearedFields, group.FieldSoraImagePrice540)
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) {
- m.sora_video_price_per_request = &f
- m.addsora_video_price_per_request = nil
-}
-
-// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation.
-func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) {
- v := m.sora_video_price_per_request
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err)
- }
- return oldValue.SoraVideoPricePerRequest, nil
-}
-
-// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field.
-func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) {
- if m.addsora_video_price_per_request != nil {
- *m.addsora_video_price_per_request += f
- } else {
- m.addsora_video_price_per_request = &f
- }
-}
-
-// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation.
-func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) {
- v := m.addsora_video_price_per_request
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (m *GroupMutation) ClearSoraVideoPricePerRequest() {
- m.sora_video_price_per_request = nil
- m.addsora_video_price_per_request = nil
- m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{}
-}
-
-// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool {
- _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest]
- return ok
-}
-
-// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field.
-func (m *GroupMutation) ResetSoraVideoPricePerRequest() {
- m.sora_video_price_per_request = nil
- m.addsora_video_price_per_request = nil
- delete(m.clearedFields, group.FieldSoraVideoPricePerRequest)
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) {
- m.sora_video_price_per_request_hd = &f
- m.addsora_video_price_per_request_hd = nil
-}
-
-// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) {
- v := m.sora_video_price_per_request_hd
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err)
- }
- return oldValue.SoraVideoPricePerRequestHd, nil
-}
-
-// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) {
- if m.addsora_video_price_per_request_hd != nil {
- *m.addsora_video_price_per_request_hd += f
- } else {
- m.addsora_video_price_per_request_hd = &f
- }
-}
-
-// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation.
-func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) {
- v := m.addsora_video_price_per_request_hd
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() {
- m.sora_video_price_per_request_hd = nil
- m.addsora_video_price_per_request_hd = nil
- m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{}
-}
-
-// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool {
- _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd]
- return ok
-}
-
-// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() {
- m.sora_video_price_per_request_hd = nil
- m.addsora_video_price_per_request_hd = nil
- delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd)
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) {
- m.sora_storage_quota_bytes = &i
- m.addsora_storage_quota_bytes = nil
-}
-
-// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
-func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.sora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
- }
- return oldValue.SoraStorageQuotaBytes, nil
-}
-
-// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) {
- if m.addsora_storage_quota_bytes != nil {
- *m.addsora_storage_quota_bytes += i
- } else {
- m.addsora_storage_quota_bytes = &i
- }
-}
-
-// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
-func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.addsora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) ResetSoraStorageQuotaBytes() {
- m.sora_storage_quota_bytes = nil
- m.addsora_storage_quota_bytes = nil
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
@@ -10502,7 +10156,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 34)
+ fields := make([]string, 0, 29)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10554,21 +10208,6 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.sora_image_price_360 != nil {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.sora_image_price_540 != nil {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.sora_video_price_per_request != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.sora_video_price_per_request_hd != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
- if m.sora_storage_quota_bytes != nil {
- fields = append(fields, group.FieldSoraStorageQuotaBytes)
- }
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
@@ -10647,16 +10286,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k()
case group.FieldImagePrice4k:
return m.ImagePrice4k()
- case group.FieldSoraImagePrice360:
- return m.SoraImagePrice360()
- case group.FieldSoraImagePrice540:
- return m.SoraImagePrice540()
- case group.FieldSoraVideoPricePerRequest:
- return m.SoraVideoPricePerRequest()
- case group.FieldSoraVideoPricePerRequestHd:
- return m.SoraVideoPricePerRequestHd()
- case group.FieldSoraStorageQuotaBytes:
- return m.SoraStorageQuotaBytes()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
@@ -10724,16 +10353,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx)
- case group.FieldSoraImagePrice360:
- return m.OldSoraImagePrice360(ctx)
- case group.FieldSoraImagePrice540:
- return m.OldSoraImagePrice540(ctx)
- case group.FieldSoraVideoPricePerRequest:
- return m.OldSoraVideoPricePerRequest(ctx)
- case group.FieldSoraVideoPricePerRequestHd:
- return m.OldSoraVideoPricePerRequestHd(ctx)
- case group.FieldSoraStorageQuotaBytes:
- return m.OldSoraStorageQuotaBytes(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
@@ -10886,41 +10505,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetImagePrice4k(v)
return nil
- case group.FieldSoraImagePrice360:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraImagePrice360(v)
- return nil
- case group.FieldSoraImagePrice540:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraImagePrice540(v)
- return nil
- case group.FieldSoraVideoPricePerRequest:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraVideoPricePerRequest(v)
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraVideoPricePerRequestHd(v)
- return nil
- case group.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraStorageQuotaBytes(v)
- return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
@@ -11037,21 +10621,6 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.addsora_image_price_360 != nil {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.addsora_image_price_540 != nil {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.addsora_video_price_per_request != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.addsora_video_price_per_request_hd != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
- if m.addsora_storage_quota_bytes != nil {
- fields = append(fields, group.FieldSoraStorageQuotaBytes)
- }
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
@@ -11085,16 +10654,6 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k()
case group.FieldImagePrice4k:
return m.AddedImagePrice4k()
- case group.FieldSoraImagePrice360:
- return m.AddedSoraImagePrice360()
- case group.FieldSoraImagePrice540:
- return m.AddedSoraImagePrice540()
- case group.FieldSoraVideoPricePerRequest:
- return m.AddedSoraVideoPricePerRequest()
- case group.FieldSoraVideoPricePerRequestHd:
- return m.AddedSoraVideoPricePerRequestHd()
- case group.FieldSoraStorageQuotaBytes:
- return m.AddedSoraStorageQuotaBytes()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
@@ -11166,41 +10725,6 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddImagePrice4k(v)
return nil
- case group.FieldSoraImagePrice360:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraImagePrice360(v)
- return nil
- case group.FieldSoraImagePrice540:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraImagePrice540(v)
- return nil
- case group.FieldSoraVideoPricePerRequest:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraVideoPricePerRequest(v)
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraVideoPricePerRequestHd(v)
- return nil
- case group.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraStorageQuotaBytes(v)
- return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
@@ -11254,18 +10778,6 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.FieldCleared(group.FieldSoraImagePrice360) {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.FieldCleared(group.FieldSoraImagePrice540) {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.FieldCleared(group.FieldSoraVideoPricePerRequest) {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
@@ -11313,18 +10825,6 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k:
m.ClearImagePrice4k()
return nil
- case group.FieldSoraImagePrice360:
- m.ClearSoraImagePrice360()
- return nil
- case group.FieldSoraImagePrice540:
- m.ClearSoraImagePrice540()
- return nil
- case group.FieldSoraVideoPricePerRequest:
- m.ClearSoraVideoPricePerRequest()
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- m.ClearSoraVideoPricePerRequestHd()
- return nil
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
@@ -11393,21 +10893,6 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k:
m.ResetImagePrice4k()
return nil
- case group.FieldSoraImagePrice360:
- m.ResetSoraImagePrice360()
- return nil
- case group.FieldSoraImagePrice540:
- m.ResetSoraImagePrice540()
- return nil
- case group.FieldSoraVideoPricePerRequest:
- m.ResetSoraVideoPricePerRequest()
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- m.ResetSoraVideoPricePerRequestHd()
- return nil
- case group.FieldSoraStorageQuotaBytes:
- m.ResetSoraStorageQuotaBytes()
- return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
@@ -19770,7 +19255,6 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
- media_type *string
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
@@ -21713,55 +21197,6 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
-// SetMediaType sets the "media_type" field.
-func (m *UsageLogMutation) SetMediaType(s string) {
- m.media_type = &s
-}
-
-// MediaType returns the value of the "media_type" field in the mutation.
-func (m *UsageLogMutation) MediaType() (r string, exists bool) {
- v := m.media_type
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldMediaType returns the old "media_type" field's value of the UsageLog entity.
-// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMediaType is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMediaType requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldMediaType: %w", err)
- }
- return oldValue.MediaType, nil
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (m *UsageLogMutation) ClearMediaType() {
- m.media_type = nil
- m.clearedFields[usagelog.FieldMediaType] = struct{}{}
-}
-
-// MediaTypeCleared returns if the "media_type" field was cleared in this mutation.
-func (m *UsageLogMutation) MediaTypeCleared() bool {
- _, ok := m.clearedFields[usagelog.FieldMediaType]
- return ok
-}
-
-// ResetMediaType resets all changes to the "media_type" field.
-func (m *UsageLogMutation) ResetMediaType() {
- m.media_type = nil
- delete(m.clearedFields, usagelog.FieldMediaType)
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
@@ -22003,7 +21438,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 38)
+ fields := make([]string, 0, 37)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -22109,9 +21544,6 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
- if m.media_type != nil {
- fields = append(fields, usagelog.FieldMediaType)
- }
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
@@ -22196,8 +21628,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
- case usagelog.FieldMediaType:
- return m.MediaType()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
@@ -22281,8 +21711,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
- case usagelog.FieldMediaType:
- return m.OldMediaType(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
@@ -22541,13 +21969,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
- case usagelog.FieldMediaType:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetMediaType(v)
- return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
@@ -22865,9 +22286,6 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
- if m.FieldCleared(usagelog.FieldMediaType) {
- fields = append(fields, usagelog.FieldMediaType)
- }
return fields
}
@@ -22924,9 +22342,6 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
- case usagelog.FieldMediaType:
- m.ClearMediaType()
- return nil
}
return fmt.Errorf("unknown UsageLog nullable field %s", name)
}
@@ -23040,9 +22455,6 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
- case usagelog.FieldMediaType:
- m.ResetMediaType()
- return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil
@@ -23221,10 +22633,6 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
- sora_storage_quota_bytes *int64
- addsora_storage_quota_bytes *int64
- sora_storage_used_bytes *int64
- addsora_storage_used_bytes *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -23939,118 +23347,6 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) {
- m.sora_storage_quota_bytes = &i
- m.addsora_storage_quota_bytes = nil
-}
-
-// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
-func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.sora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity.
-// If the User object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
- }
- return oldValue.SoraStorageQuotaBytes, nil
-}
-
-// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
-func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) {
- if m.addsora_storage_quota_bytes != nil {
- *m.addsora_storage_quota_bytes += i
- } else {
- m.addsora_storage_quota_bytes = &i
- }
-}
-
-// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
-func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.addsora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
-func (m *UserMutation) ResetSoraStorageQuotaBytes() {
- m.sora_storage_quota_bytes = nil
- m.addsora_storage_quota_bytes = nil
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (m *UserMutation) SetSoraStorageUsedBytes(i int64) {
- m.sora_storage_used_bytes = &i
- m.addsora_storage_used_bytes = nil
-}
-
-// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation.
-func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) {
- v := m.sora_storage_used_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity.
-// If the User object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err)
- }
- return oldValue.SoraStorageUsedBytes, nil
-}
-
-// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field.
-func (m *UserMutation) AddSoraStorageUsedBytes(i int64) {
- if m.addsora_storage_used_bytes != nil {
- *m.addsora_storage_used_bytes += i
- } else {
- m.addsora_storage_used_bytes = &i
- }
-}
-
-// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation.
-func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) {
- v := m.addsora_storage_used_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field.
-func (m *UserMutation) ResetSoraStorageUsedBytes() {
- m.sora_storage_used_bytes = nil
- m.addsora_storage_used_bytes = nil
-}
-
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -24571,7 +23867,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 16)
+ fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -24614,12 +23910,6 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
- if m.sora_storage_quota_bytes != nil {
- fields = append(fields, user.FieldSoraStorageQuotaBytes)
- }
- if m.sora_storage_used_bytes != nil {
- fields = append(fields, user.FieldSoraStorageUsedBytes)
- }
return fields
}
@@ -24656,10 +23946,6 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
- case user.FieldSoraStorageQuotaBytes:
- return m.SoraStorageQuotaBytes()
- case user.FieldSoraStorageUsedBytes:
- return m.SoraStorageUsedBytes()
}
return nil, false
}
@@ -24697,10 +23983,6 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
- case user.FieldSoraStorageQuotaBytes:
- return m.OldSoraStorageQuotaBytes(ctx)
- case user.FieldSoraStorageUsedBytes:
- return m.OldSoraStorageUsedBytes(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -24808,20 +24090,6 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
- case user.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraStorageQuotaBytes(v)
- return nil
- case user.FieldSoraStorageUsedBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraStorageUsedBytes(v)
- return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -24836,12 +24104,6 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
- if m.addsora_storage_quota_bytes != nil {
- fields = append(fields, user.FieldSoraStorageQuotaBytes)
- }
- if m.addsora_storage_used_bytes != nil {
- fields = append(fields, user.FieldSoraStorageUsedBytes)
- }
return fields
}
@@ -24854,10 +24116,6 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
- case user.FieldSoraStorageQuotaBytes:
- return m.AddedSoraStorageQuotaBytes()
- case user.FieldSoraStorageUsedBytes:
- return m.AddedSoraStorageUsedBytes()
}
return nil, false
}
@@ -24881,20 +24139,6 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
- case user.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraStorageQuotaBytes(v)
- return nil
- case user.FieldSoraStorageUsedBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraStorageUsedBytes(v)
- return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -24985,12 +24229,6 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
- case user.FieldSoraStorageQuotaBytes:
- m.ResetSoraStorageQuotaBytes()
- return nil
- case user.FieldSoraStorageUsedBytes:
- m.ResetSoraStorageUsedBytes()
- return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 336b1f8243..803b7bc24b 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -430,44 +430,40 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
- // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
- groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
- // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
- group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
- groupDescClaudeCodeOnly := groupFields[19].Descriptor()
+ groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
- groupDescModelRoutingEnabled := groupFields[23].Descriptor()
+ groupDescModelRoutingEnabled := groupFields[18].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
- groupDescMcpXMLInject := groupFields[24].Descriptor()
+ groupDescMcpXMLInject := groupFields[19].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
- groupDescSupportedModelScopes := groupFields[25].Descriptor()
+ groupDescSupportedModelScopes := groupFields[20].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
- groupDescSortOrder := groupFields[26].Descriptor()
+ groupDescSortOrder := groupFields[21].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
- groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
+ groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
- groupDescRequireOauthOnly := groupFields[28].Descriptor()
+ groupDescRequireOauthOnly := groupFields[23].Descriptor()
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
- groupDescRequirePrivacySet := groupFields[29].Descriptor()
+ groupDescRequirePrivacySet := groupFields[24].Descriptor()
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
- groupDescDefaultMappedModel := groupFields[30].Descriptor()
+ groupDescDefaultMappedModel := groupFields[25].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
@@ -963,16 +959,12 @@ func init() {
usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
- // usagelogDescMediaType is the schema descriptor for media_type field.
- usagelogDescMediaType := usagelogFields[35].Descriptor()
- // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
- usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
- usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor()
+ usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[37].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[36].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
@@ -1064,14 +1056,6 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
- // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
- userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
- // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
- user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
- // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
- userDescSoraStorageUsedBytes := userFields[12].Descriptor()
- // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
- user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index fd83bf26ad..0eb89c181b 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- // Sora 按次计费配置(阶段 1)
- field.Float("sora_image_price_360").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_image_price_540").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_video_price_per_request").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_video_price_per_request_hd").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
-
- // Sora 存储配额
- field.Int64("sora_storage_quota_bytes").
- Default(0),
-
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index f6c725a2b1..bd3ebfcc3c 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -134,12 +134,6 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
- // 媒体类型字段(sora 使用)
- field.String("media_type").
- MaxLen(16).
- Optional().
- Nillable(),
-
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index 0a3b5d9ec2..d443ef455c 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,12 +72,6 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
-
- // Sora 存储配额
- field.Int64("sora_storage_quota_bytes").
- Default(0),
- field.Int64("sora_storage_used_bytes").
- Default(0),
}
}
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index b857afdbb5..a8e0cc6ce8 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -92,8 +92,6 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
- // MediaType holds the value of the "media_type" field.
- MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@@ -187,7 +185,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
- case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
+ case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -436,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
- case usagelog.FieldMediaType:
- if value, ok := values[i].(*sql.NullString); !ok {
- return fmt.Errorf("unexpected type %T for field media_type", values[i])
- } else if value.Valid {
- _m.MediaType = new(string)
- *_m.MediaType = value.String
- }
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@@ -649,11 +640,6 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
- if v := _m.MediaType; v != nil {
- builder.WriteString("media_type=")
- builder.WriteString(*v)
- }
- builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index 1567ad9b45..a7438e604f 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -84,8 +84,6 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
- // FieldMediaType holds the string denoting the media_type field in the database.
- FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@@ -177,7 +175,6 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
- FieldMediaType,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -245,8 +242,6 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
- // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
- MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@@ -436,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
-// ByMediaType orders the results by the media_type field.
-func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldMediaType, opts...).ToFunc()
-}
-
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index a1fb36cbaa..b8439a0397 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -230,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
-// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
-func MediaType(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
-}
-
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@@ -1905,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
-// MediaTypeEQ applies the EQ predicate on the "media_type" field.
-func MediaTypeEQ(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
-}
-
-// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
-func MediaTypeNEQ(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
-}
-
-// MediaTypeIn applies the In predicate on the "media_type" field.
-func MediaTypeIn(vs ...string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
-}
-
-// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
-func MediaTypeNotIn(vs ...string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
-}
-
-// MediaTypeGT applies the GT predicate on the "media_type" field.
-func MediaTypeGT(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
-}
-
-// MediaTypeGTE applies the GTE predicate on the "media_type" field.
-func MediaTypeGTE(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
-}
-
-// MediaTypeLT applies the LT predicate on the "media_type" field.
-func MediaTypeLT(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
-}
-
-// MediaTypeLTE applies the LTE predicate on the "media_type" field.
-func MediaTypeLTE(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
-}
-
-// MediaTypeContains applies the Contains predicate on the "media_type" field.
-func MediaTypeContains(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
-}
-
-// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
-func MediaTypeHasPrefix(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
-}
-
-// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
-func MediaTypeHasSuffix(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
-}
-
-// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
-func MediaTypeIsNil() predicate.UsageLog {
- return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
-}
-
-// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
-func MediaTypeNotNil() predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
-}
-
-// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
-func MediaTypeEqualFold(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
-}
-
-// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
-func MediaTypeContainsFold(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
-}
-
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index d15e231d9c..fded364e0e 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -477,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
-// SetMediaType sets the "media_type" field.
-func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
- _c.mutation.SetMediaType(v)
- return _c
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
- if v != nil {
- _c.SetMediaType(*v)
- }
- return _c
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@@ -768,11 +754,6 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _c.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@@ -935,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
- if value, ok := _c.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- _node.MediaType = &value
- }
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@@ -1702,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
- u.Set(usagelog.FieldMediaType, v)
- return u
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
- u.SetExcluded(usagelog.FieldMediaType)
- return u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
- u.SetNull(usagelog.FieldMediaType)
- return u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@@ -2498,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.SetMediaType(v)
- })
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.UpdateMediaType()
- })
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.ClearMediaType()
- })
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -3465,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.SetMediaType(v)
- })
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.UpdateMediaType()
- })
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.ClearMediaType()
- })
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index 52f5a606dc..bb5ac86c78 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -739,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
-// SetMediaType sets the "media_type" field.
-func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
- _u.mutation.SetMediaType(v)
- return _u
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
- if v != nil {
- _u.SetMediaType(*v)
- }
- return _u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
- _u.mutation.ClearMediaType()
- return _u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@@ -912,11 +892,6 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _u.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -1124,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
- if value, ok := _u.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- }
- if _u.mutation.MediaTypeCleared() {
- _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
- }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@@ -2005,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
-// SetMediaType sets the "media_type" field.
-func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
- _u.mutation.SetMediaType(v)
- return _u
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
- if v != nil {
- _u.SetMediaType(*v)
- }
- return _u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
- _u.mutation.ClearMediaType()
- return _u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@@ -2191,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _u.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -2420,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
- if value, ok := _u.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- }
- if _u.mutation.MediaTypeCleared() {
- _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
- }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index b3f933f6fa..2435aa1b99 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,10 +45,6 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
- // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
- // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
- SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -181,7 +177,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
- case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
+ case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
@@ -295,18 +291,6 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
- case user.FieldSoraStorageQuotaBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
- } else if value.Valid {
- _m.SoraStorageQuotaBytes = value.Int64
- }
- case user.FieldSoraStorageUsedBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
- } else if value.Valid {
- _m.SoraStorageUsedBytes = value.Int64
- }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -440,12 +424,6 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
- builder.WriteString(", ")
- builder.WriteString("sora_storage_quota_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
- builder.WriteString(", ")
- builder.WriteString("sora_storage_used_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 155b916086..ae9418ff07 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,10 +43,6 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
- // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
- FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
- // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
- FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -156,8 +152,6 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
- FieldSoraStorageQuotaBytes,
- FieldSoraStorageUsedBytes,
}
var (
@@ -214,10 +208,6 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
- // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
- DefaultSoraStorageQuotaBytes int64
- // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
- DefaultSoraStorageUsedBytes int64
)
// OrderOption defines the ordering options for the User queries.
@@ -298,16 +288,6 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
-// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
-func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
-}
-
-// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
-func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
-}
-
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index e26afcf381..1de6103702 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,16 +125,6 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
-// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
-func SoraStorageQuotaBytes(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
-func SoraStorageUsedBytes(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
-}
-
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -870,86 +860,6 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
-// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesEQ(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
- return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGT(v int64) predicate.User {
- return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGTE(v int64) predicate.User {
- return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLT(v int64) predicate.User {
- return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLTE(v int64) predicate.User {
- return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesEQ(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
-}
-
-// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesNEQ(v int64) predicate.User {
- return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
-}
-
-// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
-}
-
-// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
-}
-
-// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesGT(v int64) predicate.User {
- return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
-}
-
-// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesGTE(v int64) predicate.User {
- return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
-}
-
-// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesLT(v int64) predicate.User {
- return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
-}
-
-// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesLTE(v int64) predicate.User {
- return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
-}
-
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index df0c6bcc1a..f862a580c5 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -210,34 +210,6 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
- _c.mutation.SetSoraStorageQuotaBytes(v)
- return _c
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
- if v != nil {
- _c.SetSoraStorageQuotaBytes(*v)
- }
- return _c
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
- _c.mutation.SetSoraStorageUsedBytes(v)
- return _c
-}
-
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
- if v != nil {
- _c.SetSoraStorageUsedBytes(*v)
- }
- return _c
-}
-
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -452,14 +424,6 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- v := user.DefaultSoraStorageQuotaBytes
- _c.mutation.SetSoraStorageQuotaBytes(v)
- }
- if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
- v := user.DefaultSoraStorageUsedBytes
- _c.mutation.SetSoraStorageUsedBytes(v)
- }
return nil
}
@@ -523,12 +487,6 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
- }
- if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
- return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
- }
return nil
}
@@ -612,14 +570,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
- if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- _node.SoraStorageQuotaBytes = value
- }
- if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- _node.SoraStorageUsedBytes = value
- }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1006,42 +956,6 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
- u.Set(user.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
- u.SetExcluded(user.FieldSoraStorageQuotaBytes)
- return u
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
- u.Add(user.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
- u.Set(user.FieldSoraStorageUsedBytes, v)
- return u
-}
-
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
- u.SetExcluded(user.FieldSoraStorageUsedBytes)
- return u
-}
-
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
- u.Add(user.FieldSoraStorageUsedBytes, v)
- return u
-}
-
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1304,48 +1218,6 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageUsedBytes(v)
- })
-}
-
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageUsedBytes(v)
- })
-}
-
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
- return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageUsedBytes()
- })
-}
-
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1774,48 +1646,6 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageUsedBytes(v)
- })
-}
-
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageUsedBytes(v)
- })
-}
-
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
- return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageUsedBytes()
- })
-}
-
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index f71f0cadfa..80222c92dc 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -242,48 +242,6 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
- _u.mutation.ResetSoraStorageUsedBytes()
- _u.mutation.SetSoraStorageUsedBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
- if v != nil {
- _u.SetSoraStorageUsedBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
-func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
- _u.mutation.AddSoraStorageUsedBytes(v)
- return _u
-}
-
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -751,18 +709,6 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
- _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1406,48 +1352,6 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
- _u.mutation.ResetSoraStorageUsedBytes()
- _u.mutation.SetSoraStorageUsedBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
- if v != nil {
- _u.SetSoraStorageUsedBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
-func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
- _u.mutation.AddSoraStorageUsedBytes(v)
- return _u
-}
-
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1945,18 +1849,6 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
- _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 3ee5d6cdca..9b43037789 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -77,7 +77,6 @@ type Config struct {
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
- Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
@@ -197,8 +196,6 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
- // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
- SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -303,59 +300,6 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
-// SoraConfig 直连 Sora 配置
-type SoraConfig struct {
- Client SoraClientConfig `mapstructure:"client"`
- Storage SoraStorageConfig `mapstructure:"storage"`
-}
-
-// SoraClientConfig 直连 Sora 客户端配置
-type SoraClientConfig struct {
- BaseURL string `mapstructure:"base_url"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- MaxRetries int `mapstructure:"max_retries"`
- CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
- PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
- MaxPollAttempts int `mapstructure:"max_poll_attempts"`
- RecentTaskLimit int `mapstructure:"recent_task_limit"`
- RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
- Debug bool `mapstructure:"debug"`
- UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
- Headers map[string]string `mapstructure:"headers"`
- UserAgent string `mapstructure:"user_agent"`
- DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
- CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
-}
-
-// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
-type SoraCurlCFFISidecarConfig struct {
- Enabled bool `mapstructure:"enabled"`
- BaseURL string `mapstructure:"base_url"`
- Impersonate string `mapstructure:"impersonate"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
- SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
-}
-
-// SoraStorageConfig 媒体存储配置
-type SoraStorageConfig struct {
- Type string `mapstructure:"type"`
- LocalPath string `mapstructure:"local_path"`
- FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
- MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
- DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
- MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
- Debug bool `mapstructure:"debug"`
- Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
-}
-
-// SoraStorageCleanupConfig 媒体清理配置
-type SoraStorageCleanupConfig struct {
- Enabled bool `mapstructure:"enabled"`
- Schedule string `mapstructure:"schedule"`
- RetentionDays int `mapstructure:"retention_days"`
-}
-
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时
@@ -424,24 +368,6 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
- // Sora 专用配置
- // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
- SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
- // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
- SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
- // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
- SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
- // SoraStreamMode: stream 强制策略(force/error)
- SoraStreamMode string `mapstructure:"sora_stream_mode"`
- // SoraModelFilters: 模型列表过滤配置
- SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
- // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
- SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
- // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
- SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
- // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
- SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
-
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
@@ -639,12 +565,6 @@ type GatewayUsageRecordConfig struct {
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
-// SoraModelFiltersConfig Sora 模型过滤配置
-type SoraModelFiltersConfig struct {
- // HidePromptEnhance 是否隐藏 prompt-enhance 模型
- HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
-}
-
// TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type TLSFingerprintConfig struct {
@@ -1402,13 +1322,6 @@ func setDefaults() {
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
- viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
- viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
- viper.SetDefault("gateway.sora_stream_mode", "force")
- viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
- viper.SetDefault("gateway.sora_media_require_api_key", true)
- viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
@@ -1465,45 +1378,12 @@ func setDefaults() {
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
- // Sora 直连配置
- viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
- viper.SetDefault("sora.client.timeout_seconds", 120)
- viper.SetDefault("sora.client.max_retries", 3)
- viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
- viper.SetDefault("sora.client.poll_interval_seconds", 2)
- viper.SetDefault("sora.client.max_poll_attempts", 600)
- viper.SetDefault("sora.client.recent_task_limit", 50)
- viper.SetDefault("sora.client.recent_task_limit_max", 200)
- viper.SetDefault("sora.client.debug", false)
- viper.SetDefault("sora.client.use_openai_token_provider", false)
- viper.SetDefault("sora.client.headers", map[string]string{})
- viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- viper.SetDefault("sora.client.disable_tls_fingerprint", false)
- viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
- viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
- viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
- viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
- viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
- viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
-
- viper.SetDefault("sora.storage.type", "local")
- viper.SetDefault("sora.storage.local_path", "")
- viper.SetDefault("sora.storage.fallback_to_upstream", true)
- viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
- viper.SetDefault("sora.storage.download_timeout_seconds", 120)
- viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
- viper.SetDefault("sora.storage.debug", false)
- viper.SetDefault("sora.storage.cleanup.enabled", true)
- viper.SetDefault("sora.storage.cleanup.retention_days", 7)
- viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
-
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
- viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
@@ -1879,86 +1759,6 @@ func (c *Config) Validate() error {
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
- if c.Gateway.SoraMaxBodySize < 0 {
- return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
- }
- if c.Gateway.SoraStreamTimeoutSeconds < 0 {
- return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
- }
- if c.Gateway.SoraRequestTimeoutSeconds < 0 {
- return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
- }
- if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
- return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
- }
- if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
- switch mode {
- case "force", "error":
- default:
- return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
- }
- }
- if c.Sora.Client.TimeoutSeconds < 0 {
- return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
- }
- if c.Sora.Client.MaxRetries < 0 {
- return fmt.Errorf("sora.client.max_retries must be non-negative")
- }
- if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
- return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
- }
- if c.Sora.Client.PollIntervalSeconds < 0 {
- return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
- }
- if c.Sora.Client.MaxPollAttempts < 0 {
- return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimit < 0 {
- return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimitMax < 0 {
- return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
- c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
- c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
- }
- if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
- }
- if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
- }
- if !c.Sora.Client.CurlCFFISidecar.Enabled {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
- }
- if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
- }
- if c.Sora.Storage.MaxConcurrentDownloads < 0 {
- return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
- }
- if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
- return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
- }
- if c.Sora.Storage.MaxDownloadBytes < 0 {
- return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
- }
- if c.Sora.Storage.Cleanup.Enabled {
- if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
- return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
- }
- if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
- return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
- }
- } else {
- if c.Sora.Storage.Cleanup.RetentionDays < 0 {
- return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
- }
- }
- if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
- return fmt.Errorf("sora.storage.type must be 'local'")
- }
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index abb76549da..2de5451ee0 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -1554,94 +1554,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
}
}
-func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
- t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
- }
- if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
- t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
- t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
- t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
- }
- if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
- t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
- t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
- }
-}
-
-func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.Enabled = false
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
- t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
- }
-}
-
-func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
- t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
- }
-}
-
-func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
- t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
- }
-}
-
-func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
- t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
- }
-}
-
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 4e69ca0252..429486c3bf 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -22,7 +22,6 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
- PlatformSora = "sora"
)
// Account type constants
diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go
index 12139b5165..20cc09eebc 100644
--- a/backend/internal/handler/admin/account_data.go
+++ b/backend/internal/handler/admin/account_data.go
@@ -567,15 +567,15 @@ func defaultProxyName(name string) string {
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
-// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
+// Only applies to OpenAI OAuth accounts. Skips expired token errors silently.
// Existing credential values are never overwritten — only missing fields are filled.
func enrichCredentialsFromIDToken(item *DataAccount) {
if item.Credentials == nil {
return
}
- // Only enrich OpenAI/Sora OAuth accounts
+ // Only enrich OpenAI OAuth accounts
platform := strings.ToLower(strings.TrimSpace(item.Platform))
- if platform != service.PlatformOpenAI && platform != service.PlatformSora {
+ if platform != service.PlatformOpenAI {
return
}
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 681da5e8fe..9aed64d592 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -1875,12 +1875,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
- // Handle Sora accounts
- if account.Platform == service.PlatformSora {
- response.Success(c, service.DefaultSoraModels(nil))
- return
- }
-
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 9759cef5c0..60d68913e8 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -380,7 +380,6 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
- {Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index caa27bc38c..458ed35d47 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
- Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
+ Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -95,10 +95,6 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -108,8 +104,6 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
- // Sora 存储配额
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
RequireOAuthOnly bool `json:"require_oauth_only"`
@@ -123,7 +117,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
- Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
+ Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -135,10 +129,6 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -148,8 +138,6 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
- // Sora 存储配额
- SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
RequireOAuthOnly *bool `json:"require_oauth_only"`
@@ -258,10 +246,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
- SoraImagePrice360: req.SoraImagePrice360,
- SoraImagePrice540: req.SoraImagePrice540,
- SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -269,7 +253,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
@@ -313,10 +296,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
- SoraImagePrice360: req.SoraImagePrice360,
- SoraImagePrice540: req.SoraImagePrice540,
- SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -324,7 +303,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go
index 4e6179dbe1..cc0c933792 100644
--- a/backend/internal/handler/admin/openai_oauth_handler.go
+++ b/backend/internal/handler/admin/openai_oauth_handler.go
@@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct {
}
func oauthPlatformFromPath(c *gin.Context) string {
- if strings.Contains(c.FullPath(), "/admin/sora/") {
- return service.PlatformSora
- }
return service.PlatformOpenAI
}
@@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct {
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
-// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
-// ExchangeSoraSessionToken exchanges Sora session token to access token
-// POST /api/v1/admin/sora/st2at
-func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
- var req struct {
- SessionToken string `json:"session_token"`
- ST string `json:"st"`
- ProxyID *int64 `json:"proxy_id"`
- }
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- sessionToken := strings.TrimSpace(req.SessionToken)
- if sessionToken == "" {
- sessionToken = strings.TrimSpace(req.ST)
- }
- if sessionToken == "" {
- response.BadRequest(c, "session_token is required")
- return
- }
-
- tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, tokenInfo)
-}
-
-// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
+// RefreshAccountToken refreshes token for a specific OpenAI account
// POST /api/v1/admin/openai/accounts/:id/refresh
-// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
-// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
+// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
-// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
@@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
name = tokenInfo.Email
}
if name == "" {
- if platform == service.PlatformSora {
- name = "Sora OAuth Account"
- } else {
- name = "OpenAI OAuth Account"
- }
+ name = "OpenAI OAuth Account"
}
// Create account
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 397526a7ea..069169172d 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -41,17 +41,15 @@ type SettingHandler struct {
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
- soraS3Storage *service.SoraS3Storage
}
// NewSettingHandler 创建系统设置处理器
-func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
+func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
opsService: opsService,
- soraS3Storage: soraS3Storage,
}
}
@@ -108,7 +106,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
- SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
@@ -177,7 +174,6 @@ type UpdateSettingsRequest struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
@@ -566,7 +562,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
- SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
@@ -676,7 +671,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
- SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
@@ -1207,384 +1201,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
})
}
-func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
- if settings == nil {
- return dto.SoraS3Settings{}
- }
- return dto.SoraS3Settings{
- Enabled: settings.Enabled,
- Endpoint: settings.Endpoint,
- Region: settings.Region,
- Bucket: settings.Bucket,
- AccessKeyID: settings.AccessKeyID,
- SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
- Prefix: settings.Prefix,
- ForcePathStyle: settings.ForcePathStyle,
- CDNURL: settings.CDNURL,
- DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
- }
-}
-
-func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
- return dto.SoraS3Profile{
- ProfileID: profile.ProfileID,
- Name: profile.Name,
- IsActive: profile.IsActive,
- Enabled: profile.Enabled,
- Endpoint: profile.Endpoint,
- Region: profile.Region,
- Bucket: profile.Bucket,
- AccessKeyID: profile.AccessKeyID,
- SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
- Prefix: profile.Prefix,
- ForcePathStyle: profile.ForcePathStyle,
- CDNURL: profile.CDNURL,
- DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
- UpdatedAt: profile.UpdatedAt,
- }
-}
-
-func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
- if !enabled {
- return nil
- }
- if strings.TrimSpace(endpoint) == "" {
- return fmt.Errorf("S3 Endpoint is required when enabled")
- }
- if strings.TrimSpace(bucket) == "" {
- return fmt.Errorf("S3 Bucket is required when enabled")
- }
- if strings.TrimSpace(accessKeyID) == "" {
- return fmt.Errorf("S3 Access Key ID is required when enabled")
- }
- if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
- return nil
- }
- return fmt.Errorf("S3 Secret Access Key is required when enabled")
-}
-
-func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
- for idx := range items {
- if items[idx].ProfileID == profileID {
- return &items[idx]
- }
- }
- return nil
-}
-
-// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
-// GET /api/v1/admin/settings/sora-s3
-func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
- settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3SettingsDTO(settings))
-}
-
-// ListSoraS3Profiles 获取 Sora S3 多配置
-// GET /api/v1/admin/settings/sora-s3/profiles
-func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
- result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- items := make([]dto.SoraS3Profile, 0, len(result.Items))
- for idx := range result.Items {
- items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
- }
- response.Success(c, dto.ListSoraS3ProfilesResponse{
- ActiveProfileID: result.ActiveProfileID,
- Items: items,
- })
-}
-
-// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
-type UpdateSoraS3SettingsRequest struct {
- ProfileID string `json:"profile_id"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-type CreateSoraS3ProfileRequest struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- SetActive bool `json:"set_active"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-type UpdateSoraS3ProfileRequest struct {
- Name string `json:"name"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-// CreateSoraS3Profile 创建 Sora S3 配置
-// POST /api/v1/admin/settings/sora-s3/profiles
-func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
- var req CreateSoraS3ProfileRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if strings.TrimSpace(req.Name) == "" {
- response.BadRequest(c, "Name is required")
- return
- }
- if strings.TrimSpace(req.ProfileID) == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
- ProfileID: req.ProfileID,
- Name: req.Name,
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- }, req.SetActive)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- response.Success(c, toSoraS3ProfileDTO(*created))
-}
-
-// UpdateSoraS3Profile 更新 Sora S3 配置
-// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
-func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
-
- var req UpdateSoraS3ProfileRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if strings.TrimSpace(req.Name) == "" {
- response.BadRequest(c, "Name is required")
- return
- }
-
- existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- existing := findSoraS3ProfileByID(existingList.Items, profileID)
- if existing == nil {
- response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
- return
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
- Name: req.Name,
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- })
- if updateErr != nil {
- response.ErrorFrom(c, updateErr)
- return
- }
-
- response.Success(c, toSoraS3ProfileDTO(*updated))
-}
-
-// DeleteSoraS3Profile 删除 Sora S3 配置
-// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
-func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"deleted": true})
-}
-
-// SetActiveSoraS3Profile 切换激活 Sora S3 配置
-// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
-func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3ProfileDTO(*active))
-}
-
-// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
-// PUT /api/v1/admin/settings/sora-s3
-func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
- var req UpdateSoraS3SettingsRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- settings := &service.SoraS3Settings{
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- }
- if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3SettingsDTO(updatedSettings))
-}
-
-// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
-// POST /api/v1/admin/settings/sora-s3/test
-func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
- if h.soraS3Storage == nil {
- response.Error(c, 500, "S3 存储服务未初始化")
- return
- }
-
- var req UpdateSoraS3SettingsRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
- if !req.Enabled {
- response.BadRequest(c, "S3 未启用,无法测试连接")
- return
- }
-
- if req.SecretAccessKey == "" {
- if req.ProfileID != "" {
- profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err == nil {
- profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
- if profile != nil {
- req.SecretAccessKey = profile.SecretAccessKey
- }
- }
- }
- if req.SecretAccessKey == "" {
- existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err == nil {
- req.SecretAccessKey = existing.SecretAccessKey
- }
- }
- }
-
- testCfg := &service.SoraS3Settings{
- Enabled: true,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- }
- if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
- response.Error(c, 400, "S3 连接测试失败: "+err.Error())
- return
- }
- response.Success(c, gin.H{"message": "S3 连接成功"})
-}
-
// GetRectifierSettings 获取请求整流器配置
// GET /api/v1/admin/settings/rectifier
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 998308dd99..a357657e20 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -34,14 +34,13 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
- Email string `json:"email" binding:"required,email"`
- Password string `json:"password" binding:"required,min=6"`
- Username string `json:"username"`
- Notes string `json:"notes"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- AllowedGroups []int64 `json:"allowed_groups"`
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required,min=6"`
+ Username string `json:"username"`
+ Notes string `json:"notes"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ AllowedGroups []int64 `json:"allowed_groups"`
}
// UpdateUserRequest represents admin update user request
@@ -57,8 +56,7 @@ type UpdateUserRequest struct {
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64 `json:"group_rates"`
- SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
+ GroupRates map[int64]*float64 `json:"group_rates"`
}
// UpdateBalanceRequest represents balance update request
@@ -182,14 +180,13 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- AllowedGroups: req.AllowedGroups,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -216,16 +213,15 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递,nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- Status: req.Status,
- AllowedGroups: req.AllowedGroups,
- GroupRates: req.GroupRates,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ Status: req.Status,
+ AllowedGroups: req.AllowedGroups,
+ GroupRates: req.GroupRates,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index d9d657836d..2eab670e75 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -59,11 +59,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
- User: *base,
- Notes: u.Notes,
- GroupRates: u.GroupRates,
- SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
- SoraStorageUsedBytes: u.SoraStorageUsedBytes,
+ User: *base,
+ Notes: u.Notes,
+ GroupRates: u.GroupRates,
}
}
@@ -172,14 +170,9 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
- SoraImagePrice360: g.SoraImagePrice360,
- SoraImagePrice540: g.SoraImagePrice540,
- SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
- SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOAuthOnly,
RequirePrivacySet: g.RequirePrivacySet,
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 47bab091d7..acc1129cd2 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -61,7 +61,6 @@ type SystemSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
@@ -128,49 +127,10 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"`
}
-// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
-type SoraS3Settings struct {
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
-type SoraS3Profile struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- IsActive bool `json:"is_active"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
- UpdatedAt string `json:"updated_at"`
-}
-
-// ListSoraS3ProfilesResponse Sora S3 配置列表响应
-type ListSoraS3ProfilesResponse struct {
- ActiveProfileID string `json:"active_profile_id"`
- Items []SoraS3Profile `json:"items"`
-}
-
// OverloadCooldownSettings 529过载冷却配置 DTO
type OverloadCooldownSettings struct {
Enabled bool `json:"enabled"`
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 56b67c8c4d..82065deb72 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -26,9 +26,7 @@ type AdminUser struct {
Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
- GroupRates map[int64]float64 `json:"group_rates,omitempty"`
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
- SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
+ GroupRates map[int64]float64 `json:"group_rates,omitempty"`
}
type APIKey struct {
@@ -84,21 +82,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- // Sora 按次计费配置
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
-
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
- // Sora 存储配额
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
-
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go
index b120098875..a897bc4054 100644
--- a/backend/internal/handler/endpoint.go
+++ b/backend/internal/handler/endpoint.go
@@ -31,7 +31,7 @@ const (
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint maps a raw request path (which may carry
-// prefixes like /antigravity, /openai, /sora) to its canonical form.
+// prefixes like /antigravity, /openai) to its canonical form.
//
// "/antigravity/v1/messages" → "/v1/messages"
// "/v1/chat/completions" → "/v1/chat/completions"
@@ -61,7 +61,7 @@ func NormalizeInboundEndpoint(path string) string {
// such as /v1/responses/compact preserved from the raw URL).
// - Anthropic → /v1/messages
// - Gemini → /v1beta/models
-// - Sora → /v1/chat/completions
+// - Antigravity → /v1/messages (Claude) or gemini (Gemini)
// - Antigravity routes may target either Claude or Gemini, so the
// inbound endpoint is used to distinguish.
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
@@ -82,9 +82,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
case service.PlatformGemini:
return EndpointGeminiModels
- case service.PlatformSora:
- return EndpointChatCompletions
-
case service.PlatformAntigravity:
// Antigravity accounts serve both Claude and Gemini.
if inbound == EndpointGeminiModels {
diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go
index a3767ac499..1519bc9e62 100644
--- a/backend/internal/handler/endpoint_test.go
+++ b/backend/internal/handler/endpoint_test.go
@@ -27,11 +27,10 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{"/v1/responses", EndpointResponses},
{"/v1beta/models", EndpointGeminiModels},
- // Prefixed paths (antigravity, openai, sora).
+ // Prefixed paths (antigravity, openai).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
- {"/sora/v1/chat/completions", EndpointChatCompletions},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
@@ -68,9 +67,6 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
// Gemini.
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
- // Sora.
- {"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
-
// OpenAI — always /v1/responses.
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index dfc9fb88b4..59619d508a 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -859,14 +859,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
- if platform == service.PlatformSora {
- c.JSON(http.StatusOK, gin.H{
- "object": "list",
- "data": service.DefaultSoraModels(h.cfg),
- })
- return
- }
-
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index ebf8d5f674..d4c349fb65 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -45,8 +45,6 @@ type Handlers struct {
Admin *AdminHandlers
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
- SoraGateway *SoraGatewayHandler
- SoraClient *SoraClientHandler
Setting *SettingHandler
Totp *TotpHandler
}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 2c999cf13b..977c2301bb 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -54,7 +54,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
- SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version,
})
diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go
deleted file mode 100644
index 80acc83349..0000000000
--- a/backend/internal/handler/sora_client_handler.go
+++ /dev/null
@@ -1,979 +0,0 @@
-package handler
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "sync"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
- "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
- "github.com/Wei-Shaw/sub2api/internal/pkg/response"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/gin-gonic/gin"
-)
-
-const (
- // 上游模型缓存 TTL
- modelCacheTTL = 1 * time.Hour // 上游获取成功
- modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
-)
-
-// SoraClientHandler 处理 Sora 客户端 API 请求。
-type SoraClientHandler struct {
- genService *service.SoraGenerationService
- quotaService *service.SoraQuotaService
- s3Storage *service.SoraS3Storage
- soraGatewayService *service.SoraGatewayService
- gatewayService *service.GatewayService
- mediaStorage *service.SoraMediaStorage
- apiKeyService *service.APIKeyService
-
- // 上游模型缓存
- modelCacheMu sync.RWMutex
- cachedFamilies []service.SoraModelFamily
- modelCacheTime time.Time
- modelCacheUpstream bool // 是否来自上游(决定 TTL)
-}
-
-// NewSoraClientHandler 创建 Sora 客户端 Handler。
-func NewSoraClientHandler(
- genService *service.SoraGenerationService,
- quotaService *service.SoraQuotaService,
- s3Storage *service.SoraS3Storage,
- soraGatewayService *service.SoraGatewayService,
- gatewayService *service.GatewayService,
- mediaStorage *service.SoraMediaStorage,
- apiKeyService *service.APIKeyService,
-) *SoraClientHandler {
- return &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- s3Storage: s3Storage,
- soraGatewayService: soraGatewayService,
- gatewayService: gatewayService,
- mediaStorage: mediaStorage,
- apiKeyService: apiKeyService,
- }
-}
-
-// GenerateRequest 生成请求。
-type GenerateRequest struct {
- Model string `json:"model" binding:"required"`
- Prompt string `json:"prompt" binding:"required"`
- MediaType string `json:"media_type"` // video / image,默认 video
- VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
- ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
- APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
-}
-
-// Generate 异步生成 — 创建 pending 记录后立即返回。
-// POST /api/v1/sora/generate
-func (h *SoraClientHandler) Generate(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- var req GenerateRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
- return
- }
-
- if req.MediaType == "" {
- req.MediaType = "video"
- }
- req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
-
- // 并发数检查(最多 3 个)
- activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- if activeCount >= 3 {
- response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
- return
- }
-
- // 配额检查(粗略检查,实际文件大小在上传后才知道)
- if h.quotaService != nil {
- if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- response.Error(c, http.StatusForbidden, err.Error())
- return
- }
- }
-
- // 获取 API Key ID 和 Group ID
- var apiKeyID *int64
- var groupID *int64
-
- if req.APIKeyID != nil && h.apiKeyService != nil {
- // 前端传递了 api_key_id,需要校验
- apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "API Key 不存在")
- return
- }
- if apiKey.UserID != userID {
- response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
- return
- }
- if apiKey.Status != service.StatusAPIKeyActive {
- response.Error(c, http.StatusForbidden, "API Key 不可用")
- return
- }
- apiKeyID = &apiKey.ID
- groupID = apiKey.GroupID
- } else if id, ok := c.Get("api_key_id"); ok {
- // 兼容 API Key 认证路径(/sora/v1/ 网关路由)
- if v, ok := id.(int64); ok {
- apiKeyID = &v
- }
- }
-
- gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
- if err != nil {
- if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
- response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
- return
- }
- response.ErrorFrom(c, err)
- return
- }
-
- // 启动后台异步生成 goroutine
- go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
-
- response.Success(c, gin.H{
- "generation_id": gen.ID,
- "status": gen.Status,
- })
-}
-
-// processGeneration 后台异步执行 Sora 生成任务。
-// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
-func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
- defer cancel()
-
- // 标记为生成中
- if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
- if errors.Is(err, service.ErrSoraGenerationStateConflict) {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
- return
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
- return
- }
-
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- mediaType,
- videoCount,
- strings.TrimSpace(imageInput) != "",
- len(strings.TrimSpace(prompt)),
- )
-
- // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
- if groupID == nil {
- ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
- }
-
- if h.gatewayService == nil {
- _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
- return
- }
-
- // 选择 Sora 账号
- account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
- if err != nil {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- err,
- )
- _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
- return
- }
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- account.ID,
- account.Name,
- account.Platform,
- account.Type,
- )
-
- // 构建 chat completions 请求体(非流式)
- body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
-
- if h.soraGatewayService == nil {
- _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
- return
- }
-
- // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
- recorder := httptest.NewRecorder()
- mockGinCtx, _ := gin.CreateTestContext(recorder)
- mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
-
- // 调用 Forward(非流式)
- result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
- if err != nil {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
- genID,
- account.ID,
- model,
- recorder.Code,
- trimForLog(recorder.Body.String(), 400),
- err,
- )
- // 检查是否已取消
- gen, _ := h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- return
- }
- _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
- return
- }
-
- // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
- mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
- if mediaURL == "" {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
- genID,
- account.ID,
- model,
- recorder.Code,
- trimForLog(recorder.Body.String(), 400),
- )
- _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
- return
- }
-
- // 检查任务是否已被取消
- gen, _ := h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
- return
- }
-
- // 三层降级存储:S3 → 本地 → 上游临时 URL
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
-
- usageAdded := false
- if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
- if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
- return
- }
- usageAdded = true
- }
-
- // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
- gen, _ = h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
- }
- return
- }
-
- // 标记完成
- if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
- if errors.Is(err, service.ErrSoraGenerationStateConflict) {
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
- }
- return
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
- return
- }
-
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
-}
-
-// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
-func (h *SoraClientHandler) storeMediaWithDegradation(
- ctx context.Context, userID int64, mediaType string,
- mediaURL string, mediaURLs []string,
-) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
- urls := mediaURLs
- if len(urls) == 0 {
- urls = []string{mediaURL}
- }
-
- // 第一层:尝试 S3
- if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
- keys := make([]string, 0, len(urls))
- var totalSize int64
- allOK := true
- for _, u := range urls {
- key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
- allOK = false
- // 清理已上传的文件
- if len(keys) > 0 {
- _ = h.s3Storage.DeleteObjects(ctx, keys)
- }
- break
- }
- keys = append(keys, key)
- totalSize += size
- }
- if allOK && len(keys) > 0 {
- accessURLs := make([]string, 0, len(keys))
- for _, key := range keys {
- accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
- _ = h.s3Storage.DeleteObjects(ctx, keys)
- allOK = false
- break
- }
- accessURLs = append(accessURLs, accessURL)
- }
- if allOK && len(accessURLs) > 0 {
- return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
- }
- }
- }
-
- // 第二层:尝试本地存储
- if h.mediaStorage != nil && h.mediaStorage.Enabled() {
- storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
- if err == nil && len(storedPaths) > 0 {
- firstPath := storedPaths[0]
- totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
- if sizeErr != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
- }
- return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
- }
-
- // 第三层:保留上游临时 URL
- return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
-}
-
-// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
-func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
- body := map[string]any{
- "model": model,
- "messages": []map[string]string{
- {"role": "user", "content": prompt},
- },
- "stream": false,
- }
- if imageInput != "" {
- body["image_input"] = imageInput
- }
- if videoCount > 1 {
- body["video_count"] = videoCount
- }
- b, _ := json.Marshal(body)
- return b
-}
-
-func normalizeVideoCount(mediaType string, videoCount int) int {
- if mediaType != "video" {
- return 1
- }
- if videoCount <= 0 {
- return 1
- }
- if videoCount > 3 {
- return 3
- }
- return videoCount
-}
-
-// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
-// OAuth 路径:ForwardResult.MediaURL 已填充。
-// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
-func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
- // 优先从 ForwardResult 获取(OAuth 路径)
- if result != nil && result.MediaURL != "" {
- // 尝试从响应体获取完整 URL 列表
- if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
- return urls[0], urls
- }
- return result.MediaURL, []string{result.MediaURL}
- }
-
- // 从响应体解析(APIKey 路径)
- if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
- return urls[0], urls
- }
-
- return "", nil
-}
-
-// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
-func parseMediaURLsFromBody(body []byte) []string {
- if len(body) == 0 {
- return nil
- }
- var resp map[string]any
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil
- }
-
- // 优先 media_urls(多图数组)
- if rawURLs, ok := resp["media_urls"]; ok {
- if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
- urls := make([]string, 0, len(arr))
- for _, item := range arr {
- if s, ok := item.(string); ok && s != "" {
- urls = append(urls, s)
- }
- }
- if len(urls) > 0 {
- return urls
- }
- }
- }
-
- // 回退到 media_url(单个 URL)
- if url, ok := resp["media_url"].(string); ok && url != "" {
- return []string{url}
- }
-
- return nil
-}
-
-// ListGenerations 查询生成记录列表。
-// GET /api/v1/sora/generations
-func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
-
- params := service.SoraGenerationListParams{
- UserID: userID,
- Status: c.Query("status"),
- StorageType: c.Query("storage_type"),
- MediaType: c.Query("media_type"),
- Page: page,
- PageSize: pageSize,
- }
-
- gens, total, err := h.genService.List(c.Request.Context(), params)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- // 为 S3 记录动态生成预签名 URL
- for _, gen := range gens {
- _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
- }
-
- response.Success(c, gin.H{
- "data": gens,
- "total": total,
- "page": page,
- })
-}
-
-// GetGeneration 查询生成记录详情。
-// GET /api/v1/sora/generations/:id
-func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
- response.Success(c, gen)
-}
-
-// DeleteGeneration 删除生成记录。
-// DELETE /api/v1/sora/generations/:id
-func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
- if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
- paths := gen.MediaURLs
- if len(paths) == 0 && gen.MediaURL != "" {
- paths = []string{gen.MediaURL}
- }
- if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
- }
- }
-
- if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- response.Success(c, gin.H{"message": "已删除"})
-}
-
-// GetQuota 查询用户存储配额。
-// GET /api/v1/sora/quota
-func (h *SoraClientHandler) GetQuota(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- if h.quotaService == nil {
- response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
- return
- }
-
- quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, quota)
-}
-
-// CancelGeneration 取消生成任务。
-// POST /api/v1/sora/generations/:id/cancel
-func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- // 权限校验
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
- _ = gen
-
- if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
- if errors.Is(err, service.ErrSoraGenerationNotActive) {
- response.Error(c, http.StatusConflict, "任务已结束,无法取消")
- return
- }
- response.Error(c, http.StatusBadRequest, err.Error())
- return
- }
-
- response.Success(c, gin.H{"message": "已取消"})
-}
-
-// SaveToStorage 手动保存 upstream 记录到 S3。
-// POST /api/v1/sora/generations/:id/save
-func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- if gen.StorageType != service.SoraStorageTypeUpstream {
- response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
- return
- }
- if gen.MediaURL == "" {
- response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
- return
- }
-
- if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
- response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
- return
- }
-
- sourceURLs := gen.MediaURLs
- if len(sourceURLs) == 0 && gen.MediaURL != "" {
- sourceURLs = []string{gen.MediaURL}
- }
- if len(sourceURLs) == 0 {
- response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
- return
- }
-
- uploadedKeys := make([]string, 0, len(sourceURLs))
- accessURLs := make([]string, 0, len(sourceURLs))
- var totalSize int64
-
- for _, sourceURL := range sourceURLs {
- objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
- if uploadErr != nil {
- if len(uploadedKeys) > 0 {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- }
- var upstreamErr *service.UpstreamDownloadError
- if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
- response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
- return
- }
- response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
- return
- }
- accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
- if err != nil {
- uploadedKeys = append(uploadedKeys, objectKey)
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
- return
- }
- uploadedKeys = append(uploadedKeys, objectKey)
- accessURLs = append(accessURLs, accessURL)
- totalSize += fileSize
- }
-
- usageAdded := false
- if totalSize > 0 && h.quotaService != nil {
- if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
- return
- }
- usageAdded = true
- }
-
- if err := h.genService.UpdateStorageForCompleted(
- c.Request.Context(),
- id,
- accessURLs[0],
- accessURLs,
- service.SoraStorageTypeS3,
- uploadedKeys,
- totalSize,
- ); err != nil {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
- }
- response.ErrorFrom(c, err)
- return
- }
-
- response.Success(c, gin.H{
- "message": "已保存到 S3",
- "object_key": uploadedKeys[0],
- "object_keys": uploadedKeys,
- })
-}
-
-// GetStorageStatus 返回存储状态。
-// GET /api/v1/sora/storage-status
-func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
- s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
- s3Healthy := false
- if s3Enabled {
- s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
- }
- localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
- response.Success(c, gin.H{
- "s3_enabled": s3Enabled,
- "s3_healthy": s3Healthy,
- "local_enabled": localEnabled,
- })
-}
-
-func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
- switch storageType {
- case service.SoraStorageTypeS3:
- if h.s3Storage != nil && len(s3Keys) > 0 {
- if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
- }
- }
- case service.SoraStorageTypeLocal:
- if h.mediaStorage != nil && len(localPaths) > 0 {
- if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
- }
- }
- }
-}
-
-// getUserIDFromContext 从 gin 上下文中提取用户 ID。
-func getUserIDFromContext(c *gin.Context) int64 {
- if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
- return subject.UserID
- }
-
- if id, ok := c.Get("user_id"); ok {
- switch v := id.(type) {
- case int64:
- return v
- case float64:
- return int64(v)
- case string:
- n, _ := strconv.ParseInt(v, 10, 64)
- return n
- }
- }
- // 尝试从 JWT claims 获取
- if id, ok := c.Get("userID"); ok {
- if v, ok := id.(int64); ok {
- return v
- }
- }
- return 0
-}
-
-func groupIDForLog(groupID *int64) int64 {
- if groupID == nil {
- return 0
- }
- return *groupID
-}
-
-func trimForLog(raw string, maxLen int) string {
- trimmed := strings.TrimSpace(raw)
- if maxLen <= 0 || len(trimmed) <= maxLen {
- return trimmed
- }
- return trimmed[:maxLen] + "...(truncated)"
-}
-
-// GetModels 获取可用 Sora 模型家族列表。
-// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
-// GET /api/v1/sora/models
-func (h *SoraClientHandler) GetModels(c *gin.Context) {
- families := h.getModelFamilies(c.Request.Context())
- response.Success(c, families)
-}
-
-// getModelFamilies 获取模型家族列表(带缓存)。
-func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
- // 读锁检查缓存
- h.modelCacheMu.RLock()
- ttl := modelCacheTTL
- if !h.modelCacheUpstream {
- ttl = modelCacheFailedTTL
- }
- if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
- families := h.cachedFamilies
- h.modelCacheMu.RUnlock()
- return families
- }
- h.modelCacheMu.RUnlock()
-
- // 写锁更新缓存
- h.modelCacheMu.Lock()
- defer h.modelCacheMu.Unlock()
-
- // double-check
- ttl = modelCacheTTL
- if !h.modelCacheUpstream {
- ttl = modelCacheFailedTTL
- }
- if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
- return h.cachedFamilies
- }
-
- // 尝试从上游获取
- families, err := h.fetchUpstreamModels(ctx)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
- families = service.BuildSoraModelFamilies()
- h.cachedFamilies = families
- h.modelCacheTime = time.Now()
- h.modelCacheUpstream = false
- return families
- }
-
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
- h.cachedFamilies = families
- h.modelCacheTime = time.Now()
- h.modelCacheUpstream = true
- return families
-}
-
-// fetchUpstreamModels 从上游 Sora API 获取模型列表。
-func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
- if h.gatewayService == nil {
- return nil, fmt.Errorf("gatewayService 未初始化")
- }
-
- // 设置 ForcePlatform 用于 Sora 账号选择
- ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
-
- // 选择一个 Sora 账号
- account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
- if err != nil {
- return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
- }
-
- // 仅支持 API Key 类型账号
- if account.Type != service.AccountTypeAPIKey {
- return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
- }
-
- apiKey := account.GetCredential("api_key")
- if apiKey == "" {
- return nil, fmt.Errorf("账号缺少 api_key")
- }
-
- baseURL := account.GetBaseURL()
- if baseURL == "" {
- return nil, fmt.Errorf("账号缺少 base_url")
- }
-
- // 构建上游模型列表请求
- modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
-
- reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
- if err != nil {
- return nil, fmt.Errorf("创建请求失败: %w", err)
- }
- req.Header.Set("Authorization", "Bearer "+apiKey)
-
- client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Do(req)
- if err != nil {
- return nil, fmt.Errorf("请求上游失败: %w", err)
- }
- defer func() {
- _ = resp.Body.Close()
- }()
-
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
- }
-
- body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
- if err != nil {
- return nil, fmt.Errorf("读取响应失败: %w", err)
- }
-
- // 解析 OpenAI 格式的模型列表
- var modelsResp struct {
- Data []struct {
- ID string `json:"id"`
- } `json:"data"`
- }
- if err := json.Unmarshal(body, &modelsResp); err != nil {
- return nil, fmt.Errorf("解析响应失败: %w", err)
- }
-
- if len(modelsResp.Data) == 0 {
- return nil, fmt.Errorf("上游返回空模型列表")
- }
-
- // 提取模型 ID
- modelIDs := make([]string, 0, len(modelsResp.Data))
- for _, m := range modelsResp.Data {
- modelIDs = append(modelIDs, m.ID)
- }
-
- // 转换为模型家族
- families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
- if len(families) == 0 {
- return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
- }
-
- return families, nil
-}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
deleted file mode 100644
index 5705578660..0000000000
--- a/backend/internal/handler/sora_client_handler_test.go
+++ /dev/null
@@ -1,3179 +0,0 @@
-//go:build unit
-
-package handler
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "os"
- "strings"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
-)
-
-func init() {
- gin.SetMode(gin.TestMode)
-}
-
-// ==================== Stub: SoraGenerationRepository ====================
-
-var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
-
-type stubSoraGenRepo struct {
- gens map[int64]*service.SoraGeneration
- nextID int64
- createErr error
- getErr error
- updateErr error
- deleteErr error
- listErr error
- countErr error
- countValue int64
-
- // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
- updateCallCount *int32
- updateFailAfterN int32
-
- // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
- getByIDCallCount int32
- getByIDOverrideAfterN int32 // 0 = 不覆盖
- getByIDOverrideStatus string
-}
-
-func newStubSoraGenRepo() *stubSoraGenRepo {
- return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
-}
-
-func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
- if r.createErr != nil {
- return r.createErr
- }
- gen.ID = r.nextID
- r.nextID++
- r.gens[gen.ID] = gen
- return nil
-}
-func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
- if r.getErr != nil {
- return nil, r.getErr
- }
- gen, ok := r.gens[id]
- if !ok {
- return nil, fmt.Errorf("not found")
- }
- // 条件性状态覆盖:模拟外部取消等场景
- if r.getByIDOverrideAfterN > 0 {
- n := atomic.AddInt32(&r.getByIDCallCount, 1)
- if n > r.getByIDOverrideAfterN {
- cp := *gen
- cp.Status = r.getByIDOverrideStatus
- return &cp, nil
- }
- }
- return gen, nil
-}
-func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
- // 条件性失败:前 N 次成功,之后失败
- if r.updateCallCount != nil {
- n := atomic.AddInt32(r.updateCallCount, 1)
- if n > r.updateFailAfterN {
- return fmt.Errorf("conditional update error (call #%d)", n)
- }
- }
- if r.updateErr != nil {
- return r.updateErr
- }
- r.gens[gen.ID] = gen
- return nil
-}
-func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
- if r.deleteErr != nil {
- return r.deleteErr
- }
- delete(r.gens, id)
- return nil
-}
-func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
- if r.listErr != nil {
- return nil, 0, r.listErr
- }
- var result []*service.SoraGeneration
- for _, gen := range r.gens {
- if gen.UserID != params.UserID {
- continue
- }
- result = append(result, gen)
- }
- return result, int64(len(result)), nil
-}
-func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
- if r.countErr != nil {
- return 0, r.countErr
- }
- return r.countValue, nil
-}
-
-// ==================== 辅助函数 ====================
-
-func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
- genService := service.NewSoraGenerationService(repo, nil, nil)
- return &SoraClientHandler{genService: genService}
-}
-
-func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- if body != "" {
- c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
- c.Request.Header.Set("Content-Type", "application/json")
- } else {
- c.Request = httptest.NewRequest(method, path, nil)
- }
- if userID > 0 {
- c.Set("user_id", userID)
- }
- return c, rec
-}
-
-func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
- t.Helper()
- var resp map[string]any
- require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
- return resp
-}
-
-// ==================== 纯函数测试: buildAsyncRequestBody ====================
-
-func TestBuildAsyncRequestBody(t *testing.T) {
- body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "sora2-landscape-10s", parsed["model"])
- require.Equal(t, false, parsed["stream"])
-
- msgs := parsed["messages"].([]any)
- require.Len(t, msgs, 1)
- msg := msgs[0].(map[string]any)
- require.Equal(t, "user", msg["role"])
- require.Equal(t, "一只猫在跳舞", msg["content"])
-}
-
-func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
- body := buildAsyncRequestBody("gpt-image", "", "", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "gpt-image", parsed["model"])
- msgs := parsed["messages"].([]any)
- msg := msgs[0].(map[string]any)
- require.Equal(t, "", msg["content"])
-}
-
-func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
- body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
-}
-
-func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
- body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, float64(3), parsed["video_count"])
-}
-
-func TestNormalizeVideoCount(t *testing.T) {
- require.Equal(t, 1, normalizeVideoCount("video", 0))
- require.Equal(t, 2, normalizeVideoCount("video", 2))
- require.Equal(t, 3, normalizeVideoCount("video", 5))
- require.Equal(t, 1, normalizeVideoCount("image", 3))
-}
-
-// ==================== 纯函数测试: parseMediaURLsFromBody ====================
-
-func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
- require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody(nil))
- require.Nil(t, parseMediaURLsFromBody([]byte{}))
-}
-
-func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
-}
-
-func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
-}
-
-func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
-}
-
-func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
-}
-
-func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
- body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
- urls := parseMediaURLsFromBody([]byte(body))
- require.Len(t, urls, 2)
- require.Equal(t, "https://multi.com/a.mp4", urls[0])
-}
-
-func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
-}
-
-func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
- // media_urls 不是 string 数组
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
-}
-
-func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
-}
-
-// ==================== 纯函数测试: extractMediaURLsFromResult ====================
-
-func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
- result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Equal(t, "https://oauth.com/video.mp4", url)
- require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
-}
-
-func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
- result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
- recorder := httptest.NewRecorder()
- _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Equal(t, "https://body.com/1.mp4", url)
- require.Len(t, urls, 2)
-}
-
-func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
- recorder := httptest.NewRecorder()
- _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
- url, urls := extractMediaURLsFromResult(nil, recorder)
- require.Equal(t, "https://upstream.com/video.mp4", url)
- require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
-}
-
-func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(nil, recorder)
- require.Empty(t, url)
- require.Nil(t, urls)
-}
-
-func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
- result := &service.ForwardResult{MediaURL: ""}
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Empty(t, url)
- require.Nil(t, urls)
-}
-
-// ==================== getUserIDFromContext ====================
-
-func TestGetUserIDFromContext_Int64(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", int64(42))
- require.Equal(t, int64(42), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
- require.Equal(t, int64(777), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_Float64(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", float64(99))
- require.Equal(t, int64(99), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_String(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", "123")
- require.Equal(t, int64(123), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("userID", int64(55))
- require.Equal(t, int64(55), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_NoID(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- require.Equal(t, int64(0), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_InvalidString(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", "not-a-number")
- require.Equal(t, int64(0), getUserIDFromContext(c))
-}
-
-// ==================== Handler: Generate ====================
-
-func TestGenerate_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
- h.Generate(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGenerate_BadRequest_MissingModel(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_TooManyRequests(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.countValue = 3
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-func TestGenerate_CountError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.countErr = fmt.Errorf("db error")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestGenerate_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.NotZero(t, data["generation_id"])
- require.Equal(t, "pending", data["status"])
-}
-
-func TestGenerate_DefaultMediaType(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "video", repo.gens[1].MediaType)
-}
-
-func TestGenerate_ImageMediaType(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "image", repo.gens[1].MediaType)
-}
-
-func TestGenerate_CreatePendingError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.createErr = fmt.Errorf("create failed")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestGenerate_APIKeyInContext(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- c.Set("api_key_id", int64(42))
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_NoAPIKeyInContext(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_ConcurrencyBoundary(t *testing.T) {
- // activeCount == 2 应该允许
- repo := newStubSoraGenRepo()
- repo.countValue = 2
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== Handler: ListGenerations ====================
-
-func TestListGenerations_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
- h.ListGenerations(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestListGenerations_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
- repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
- repo.nextID = 3
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- items := data["data"].([]any)
- require.Len(t, items, 2)
- require.Equal(t, float64(2), data["total"])
-}
-
-func TestListGenerations_ListError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.listErr = fmt.Errorf("db error")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestListGenerations_DefaultPagination(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- // 不传分页参数,应默认 page=1 page_size=20
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, float64(1), data["page"])
-}
-
-// ==================== Handler: GetGeneration ====================
-
-func TestGetGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGetGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGetGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestGetGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestGetGeneration_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, float64(1), data["id"])
-}
-
-// ==================== Handler: DeleteGeneration ====================
-
-func TestDeleteGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestDeleteGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestDeleteGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestDeleteGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestDeleteGeneration_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- _, exists := repo.gens[1]
- require.False(t, exists)
-}
-
-// ==================== Handler: CancelGeneration ====================
-
-func TestCancelGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestCancelGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestCancelGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestCancelGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestCancelGeneration_Pending(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-func TestCancelGeneration_Generating(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-func TestCancelGeneration_Completed(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-func TestCancelGeneration_Failed(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-func TestCancelGeneration_Cancelled(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-// ==================== Handler: GetQuota ====================
-
-func TestGetQuota_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
- h.GetQuota(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGetQuota_NilQuotaService(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
- h.GetQuota(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, "unlimited", data["source"])
-}
-
-// ==================== Handler: GetModels ====================
-
-func TestGetModels(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
- h.GetModels(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].([]any)
- require.Len(t, data, 4)
- // 验证类型分布
- videoCount, imageCount := 0, 0
- for _, item := range data {
- m := item.(map[string]any)
- if m["type"] == "video" {
- videoCount++
- } else if m["type"] == "image" {
- imageCount++
- }
- }
- require.Equal(t, 3, videoCount)
- require.Equal(t, 1, imageCount)
-}
-
-// ==================== Handler: GetStorageStatus ====================
-
-func TestGetStorageStatus_NilS3(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, false, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
- require.Equal(t, false, data["local_enabled"])
-}
-
-func TestGetStorageStatus_LocalEnabled(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, false, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
- require.Equal(t, true, data["local_enabled"])
-}
-
-// ==================== Handler: SaveToStorage ====================
-
-func TestSaveToStorage_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestSaveToStorage_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestSaveToStorage_NotUpstream(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_S3Nil(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusServiceUnavailable, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
-}
-
-func TestSaveToStorage_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-// ==================== storeMediaWithDegradation — nil guard 路径 ====================
-
-func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
- h := &SoraClientHandler{}
- url, urls, storageType, keys, size := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://upstream.com/v.mp4", url)
- require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
- require.Nil(t, keys)
- require.Equal(t, int64(0), size)
-}
-
-func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
- h := &SoraClientHandler{}
- url, urls, storageType, keys, size := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://a.com/1.mp4", url)
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
- require.Nil(t, keys)
- require.Equal(t, int64(0), size)
-}
-
-func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
- h := &SoraClientHandler{}
- url, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://upstream.com/v.mp4", url)
-}
-
-// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
-
-var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
-
-type stubUserRepoForHandler struct {
- users map[int64]*service.User
- updateErr error
-}
-
-func newStubUserRepoForHandler() *stubUserRepoForHandler {
- return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
-}
-
-func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
- if u, ok := r.users[id]; ok {
- return u, nil
- }
- return nil, fmt.Errorf("user not found")
-}
-func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
- if r.updateErr != nil {
- return r.updateErr
- }
- r.users[user.ID] = user
- return nil
-}
-func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
-func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
- return nil, nil
-}
-func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
- return nil, nil
-}
-func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
-func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
-func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
-func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
- return false, nil
-}
-func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
-func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
-
-// ==================== NewSoraClientHandler ====================
-
-func TestNewSoraClientHandler(t *testing.T) {
- h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
- require.NotNil(t, h)
-}
-
-func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
- h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
- require.NotNil(t, h)
- require.Nil(t, h.apiKeyService)
-}
-
-// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
-
-var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
-
-type stubAPIKeyRepoForHandler struct {
- keys map[int64]*service.APIKey
- getErr error
-}
-
-func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
- return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
-}
-
-func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
- if r.getErr != nil {
- return nil, r.getErr
- }
- if k, ok := r.keys[id]; ok {
- return k, nil
- }
- return nil, fmt.Errorf("api key not found: %d", id)
-}
-func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
-func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
- return "", 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
-func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
- return false, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) {
- var updated int64
- for id, key := range r.keys {
- if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID {
- continue
- }
- clone := *key
- gid := newGroupID
- clone.GroupID = &gid
- r.keys[id] = &clone
- updated++
- }
- return updated, nil
-}
-func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
- return nil, nil
-}
-
-// newTestAPIKeyService 创建测试用的 APIKeyService
-func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
- return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
-}
-
-// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
-
-func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
- // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- groupID := int64(5)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: &groupID,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.NotZero(t, data["generation_id"])
-
- // 验证 api_key_id 已关联到生成记录
- gen := repo.gens[1]
- require.NotNil(t, gen.APIKeyID)
- require.Equal(t, int64(42), *gen.APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
- // 前端传递不存在的 api_key_id → 400
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
-}
-
-func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
- // 前端传递别人的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 999, // 属于 user 999
- Status: service.StatusAPIKeyActive,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
-}
-
-func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
- // 前端传递已禁用的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyDisabled,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
-}
-
-func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
- // 前端传递配额耗尽的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyQuotaExhausted,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
- // 前端传递已过期的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyExpired,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
- // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- h := &SoraClientHandler{genService: genService} // apiKeyService = nil
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
- // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: nil, // 无分组
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
- // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
- // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- groupID := int64(10)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: &groupID,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 应使用 body 中的 api_key_id=42,而不是 context 中的 99
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
- // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- c.Set("api_key_id", int64(99))
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 应使用 context 中的 api_key_id=99
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
- // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
- // api_key_id=0 不存在 → 400
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
-
-func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
- // groupID 不为 nil → 不设置 ForcePlatform
- // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- gid := int64(5)
- h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
- // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
- // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-// ==================== GenerateRequest JSON 解析 ====================
-
-func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
- // 验证 api_key_id 在 JSON 中正确解析为 *int64
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
- require.NoError(t, err)
- require.NotNil(t, req.APIKeyID)
- require.Equal(t, int64(42), *req.APIKeyID)
-}
-
-func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
- // 不传 api_key_id → 解析后为 nil
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
- require.NoError(t, err)
- require.Nil(t, req.APIKeyID)
-}
-
-func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
- // api_key_id: null → 解析后为 nil
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
- require.NoError(t, err)
- require.Nil(t, req.APIKeyID)
-}
-
-func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
- // 全字段解析
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{
- "model":"sora2-landscape-10s",
- "prompt":"test prompt",
- "media_type":"video",
- "video_count":2,
- "image_input":"data:image/png;base64,abc",
- "api_key_id":100
- }`), &req)
- require.NoError(t, err)
- require.Equal(t, "sora2-landscape-10s", req.Model)
- require.Equal(t, "test prompt", req.Prompt)
- require.Equal(t, "video", req.MediaType)
- require.Equal(t, 2, req.VideoCount)
- require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
- require.NotNil(t, req.APIKeyID)
- require.Equal(t, int64(100), *req.APIKeyID)
-}
-
-func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
- // api_key_id 为 nil 时 JSON 序列化应省略
- req := GenerateRequest{Model: "sora2", Prompt: "test"}
- b, err := json.Marshal(req)
- require.NoError(t, err)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(b, &parsed))
- _, hasAPIKeyID := parsed["api_key_id"]
- require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
-}
-
-func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
- // api_key_id 不为 nil 时 JSON 序列化应包含
- id := int64(42)
- req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
- b, err := json.Marshal(req)
- require.NoError(t, err)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(b, &parsed))
- require.Equal(t, float64(42), parsed["api_key_id"])
-}
-
-// ==================== GetQuota: 有配额服务 ====================
-
-func TestGetQuota_WithQuotaService_Success(t *testing.T) {
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10 * 1024 * 1024,
- SoraStorageUsedBytes: 3 * 1024 * 1024,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
- h.GetQuota(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, "user", data["source"])
- require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
- require.Equal(t, float64(3*1024*1024), data["used_bytes"])
-}
-
-func TestGetQuota_WithQuotaService_Error(t *testing.T) {
- // 用户不存在时 GetQuota 返回错误
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
- h.GetQuota(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== Generate: 配额检查 ====================
-
-func TestGenerate_QuotaCheckFailed(t *testing.T) {
- // 配额超限时返回 429
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 1024,
- SoraStorageUsedBytes: 1025, // 已超限
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-func TestGenerate_QuotaCheckPassed(t *testing.T) {
- // 配额充足时允许生成
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
-
-var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
-
-type stubSettingRepoForHandler struct {
- values map[string]string
-}
-
-func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
- if values == nil {
- values = make(map[string]string)
- }
- return &stubSettingRepoForHandler{values: values}
-}
-
-func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
- if v, ok := r.values[key]; ok {
- return &service.Setting{Key: key, Value: v}, nil
- }
- return nil, service.ErrSettingNotFound
-}
-func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
- if v, ok := r.values[key]; ok {
- return v, nil
- }
- return "", service.ErrSettingNotFound
-}
-func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
- r.values[key] = value
- return nil
-}
-func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
- result := make(map[string]string)
- for _, k := range keys {
- if v, ok := r.values[k]; ok {
- result[k] = v
- }
- }
- return result, nil
-}
-func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
- for k, v := range settings {
- r.values[k] = v
- }
- return nil
-}
-func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
- return r.values, nil
-}
-func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
- delete(r.values, key)
- return nil
-}
-
-// ==================== S3 / MediaStorage 辅助函数 ====================
-
-// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
-func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
- settingRepo := newStubSettingRepoForHandler(map[string]string{
- "sora_s3_enabled": "true",
- "sora_s3_endpoint": endpoint,
- "sora_s3_region": "us-east-1",
- "sora_s3_bucket": "test-bucket",
- "sora_s3_access_key_id": "AKIATEST",
- "sora_s3_secret_access_key": "test-secret",
- "sora_s3_prefix": "sora",
- "sora_s3_force_path_style": "true",
- })
- settingService := service.NewSettingService(settingRepo, &config.Config{})
- return service.NewSoraS3Storage(settingService)
-}
-
-// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
-func newFakeSourceServer() *httptest.Server {
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "video/mp4")
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("fake video data for test"))
- }))
-}
-
-// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
-// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
-func newFakeS3Server(mode string) *httptest.Server {
- var counter atomic.Int32
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = io.Copy(io.Discard, r.Body)
- _ = r.Body.Close()
-
- switch mode {
- case "ok":
- w.Header().Set("ETag", `"test-etag"`)
- w.WriteHeader(http.StatusOK)
- case "fail":
- w.WriteHeader(http.StatusForbidden)
- _, _ = w.Write([]byte(`AccessDenied`))
- case "fail-second":
- n := counter.Add(1)
- if n <= 1 {
- w.Header().Set("ETag", `"test-etag"`)
- w.WriteHeader(http.StatusOK)
- } else {
- w.WriteHeader(http.StatusForbidden)
- _, _ = w.Write([]byte(`AccessDenied`))
- }
- }
- }))
-}
-
-// ==================== processGeneration 直接调用测试 ====================
-
-func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- repo.updateErr = fmt.Errorf("db error")
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- // 直接调用(非 goroutine),MarkGenerating 失败 → 早退
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
- // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
- // 因此 ErrorMessage 为空(证明未调用 MarkFailed)
- require.Equal(t, "generating", repo.gens[1].Status)
- require.Empty(t, repo.gens[1].ErrorMessage)
-}
-
-func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
- // gatewayService 未设置 → MarkFailed
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-// ==================== storeMediaWithDegradation: S3 路径 ====================
-
-func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeS3, storageType)
- require.Len(t, s3Keys, 1)
- require.NotEmpty(t, s3Keys[0])
- require.Len(t, storedURLs, 1)
- require.Equal(t, storedURL, storedURLs[0])
- require.Contains(t, storedURL, fakeS3.URL)
- require.Contains(t, storedURL, "/test-bucket/")
- require.Greater(t, fileSize, int64(0))
-}
-
-func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
- )
- require.Equal(t, service.SoraStorageTypeS3, storageType)
- require.Len(t, s3Keys, 2)
- require.Len(t, storedURLs, 2)
- require.Equal(t, storedURL, storedURLs[0])
- require.Contains(t, storedURLs[0], fakeS3.URL)
- require.Contains(t, storedURLs[1], fakeS3.URL)
- require.Greater(t, fileSize, int64(0))
-}
-
-func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
- // 上游返回 404 → 下载失败 → S3 上传不会开始
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
- badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- }))
- defer badSource.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
-}
-
-func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- // S3 失败,降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Nil(t, s3Keys)
-}
-
-func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail-second")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
- )
- // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Nil(t, s3Keys)
-}
-
-// ==================== storeMediaWithDegradation: 本地存储路径 ====================
-
-func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
- // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: "/dev/null/invalid_dir",
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
- )
- // 本地存储失败,降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
-}
-
-func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- DownloadTimeoutSeconds: 5,
- MaxDownloadBytes: 10 * 1024 * 1024,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeLocal, storageType)
- require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
-}
-
-func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- DownloadTimeoutSeconds: 5,
- MaxDownloadBytes: 10 * 1024 * 1024,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{
- s3Storage: s3Storage,
- mediaStorage: mediaStorage,
- }
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- // S3 失败 → 本地存储成功
- require.Equal(t, service.SoraStorageTypeLocal, storageType)
-}
-
-// ==================== SaveToStorage: S3 路径 ====================
-
-func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "S3")
-}
-
-func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
- expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusForbidden)
- }))
- defer expiredServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: expiredServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusGone, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "过期")
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Contains(t, data["message"], "S3")
- require.NotEmpty(t, data["object_key"])
- // 验证记录已更新为 S3 存储
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v1.mp4",
- MediaURLs: []string{
- sourceServer.URL + "/v1.mp4",
- sourceServer.URL + "/v2.mp4",
- },
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Len(t, data["object_keys"].([]any), 2)
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
- require.Len(t, repo.gens[1].S3ObjectKeys, 2)
- require.Len(t, repo.gens[1].MediaURLs, 2)
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 100 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 验证配额已累加
- require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
-}
-
-func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
- repo.updateErr = fmt.Errorf("db error")
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== GetStorageStatus: S3 路径 ====================
-
-func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
- // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, true, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
-}
-
-func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, true, data["s3_enabled"])
- require.Equal(t, true, data["s3_healthy"])
-}
-
-// ==================== Stub: AccountRepository (用于 GatewayService) ====================
-
-var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
-
-type stubAccountRepoForHandler struct {
- accounts []service.Account
-}
-
-func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
-func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
- for i := range r.accounts {
- if r.accounts[i].ID == id {
- return &r.accounts[i], nil
- }
- }
- return nil, fmt.Errorf("account not found")
-}
-func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
- return false, nil
-}
-func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
-func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
-func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
- return 0, nil
-}
-func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
-func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
- return 0, nil
-}
-
-func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
- return nil
-}
-
-func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
- return nil
-}
-
-// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
-
-var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
-
-type stubSoraClientForHandler struct {
- videoStatus *service.SoraVideoTaskStatus
-}
-
-func (s *stubSoraClientForHandler) Enabled() bool { return true }
-func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
- return "task-image", nil
-}
-func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
- return "task-video", nil
-}
-func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
- return "task-video", nil
-}
-func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
- return s.videoStatus, nil
-}
-
-// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
-
-// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
-func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
- return service.NewGatewayService(
- accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
- nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
- nil, nil,
- )
-}
-
-// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
-func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
-}
-
-// ==================== processGeneration: 更多路径测试 ====================
-
-func TestProcessGeneration_SelectAccountError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
-}
-
-func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- // 提供可用账号使 SelectAccountForModel 成功
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // soraGatewayService 为 nil
- h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
-}
-
-func TestProcessGeneration_ForwardError(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // SoraClient 返回视频任务失败
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "failed",
- ErrorMsg: "content policy violation",
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
-}
-
-func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
- // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
- repo.getByIDOverrideAfterN = 1
- repo.getByIDOverrideStatus = "cancelled"
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
- require.Equal(t, "generating", repo.gens[1].Status)
-}
-
-func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // SoraClient 返回 completed 但无 URL
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: nil, // 无 URL
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
-}
-
-func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
- // 第 2 次返回 "cancelled" 状态,模拟外部取消
- repo.getByIDOverrideAfterN = 1
- repo.getByIDOverrideStatus = "cancelled"
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
- require.Equal(t, "generating", repo.gens[1].Status)
-}
-
-func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- // 无 S3 和本地存储,降级到 upstream
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "completed", repo.gens[1].Status)
- require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
- require.NotEmpty(t, repo.gens[1].MediaURL)
-}
-
-func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{sourceServer.URL + "/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- s3Storage := newS3StorageForHandler(fakeS3.URL)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- s3Storage: s3Storage,
- quotaService: quotaService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "completed", repo.gens[1].Status)
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
- require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
- require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
- // 验证配额已累加
- require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
-}
-
-func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
- repo.updateCallCount = new(int32)
- repo.updateFailAfterN = 1
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
- // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
- // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
- require.Equal(t, "completed", repo.gens[1].Status)
-}
-
-// ==================== cleanupStoredMedia 直接测试 ====================
-
-func TestCleanupStoredMedia_S3Path(t *testing.T) {
- // S3 清理路径:s3Storage 为 nil 时不 panic
- h := &SoraClientHandler{}
- // 不应 panic
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
-}
-
-func TestCleanupStoredMedia_LocalPath(t *testing.T) {
- // 本地清理路径:mediaStorage 为 nil 时不 panic
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
-}
-
-func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
- // upstream 类型不清理
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
-}
-
-func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
- // 空 keys 不触发清理
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
-}
-
-// ==================== DeleteGeneration: 本地存储清理路径 ====================
-
-func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "video/test.mp4",
- MediaURLs: []string{"video/test.mp4"},
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- _, exists := repo.gens[1]
- require.False(t, exists)
-}
-
-func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
- // MediaURLs 为空,使用 MediaURL 作为清理路径
- tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "video/test.mp4",
- MediaURLs: nil, // 空
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
- // 非本地存储类型 → 跳过清理
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeUpstream,
- MediaURL: "https://upstream.com/v.mp4",
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestDeleteGeneration_DeleteError(t *testing.T) {
- // repo.Delete 出错
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
- repo.deleteErr = fmt.Errorf("delete failed")
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-// ==================== fetchUpstreamModels 测试 ====================
-
-func TestFetchUpstreamModels_NilGateway(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- h := &SoraClientHandler{}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "gatewayService 未初始化")
-}
-
-func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "选择 Sora 账号失败")
-}
-
-func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "不支持模型同步")
-}
-
-func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"base_url": "https://sora.test"}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "api_key")
-}
-
-func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
- // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test"}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
-}
-
-func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "状态码 500")
-}
-
-func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("not json"))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "解析响应失败")
-}
-
-func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "空模型列表")
-}
-
-func TestFetchUpstreamModels_Success(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // 验证请求头
- require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
- require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- families, err := h.fetchUpstreamModels(context.Background())
- require.NoError(t, err)
- require.NotEmpty(t, families)
-}
-
-func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "未能从上游模型列表中识别")
-}
-
-// ==================== getModelFamilies 缓存测试 ====================
-
-func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
- // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
- h := &SoraClientHandler{}
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
-
- // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
- families2 := h.getModelFamilies(context.Background())
- require.Equal(t, families, families2)
- require.False(t, h.modelCacheUpstream)
-}
-
-func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
- t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
-
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
- require.True(t, h.modelCacheUpstream)
-
- // 第二次调用命中缓存
- families2 := h.getModelFamilies(context.Background())
- require.Equal(t, families, families2)
-}
-
-func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
- // 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
- h := &SoraClientHandler{
- cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
- modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
- modelCacheUpstream: false,
- }
- // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
- // 缓存已刷新,不再是 "old"
- found := false
- for _, f := range families {
- if f.ID == "old" {
- found = true
- }
- }
- require.False(t, found, "过期缓存应被刷新")
-}
-
-// ==================== processGeneration: groupID 与 ForcePlatform ====================
-
-func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
- // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 空账号列表 → SelectAccountForModel 失败
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
-
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
-}
-
-// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
-
-func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
- // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
-
- body := `{"model":"sora2-landscape-10s","prompt":"test"}`
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-// ==================== Generate: CreatePending 并发限制错误 ====================
-
-// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
-type stubSoraGenRepoWithAtomicCreate struct {
- stubSoraGenRepo
- limitErr error
-}
-
-func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
- if r.limitErr != nil {
- return r.limitErr
- }
- return r.stubSoraGenRepo.Create(context.Background(), gen)
-}
-
-func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
- repo := &stubSoraGenRepoWithAtomicCreate{
- stubSoraGenRepo: *newStubSoraGenRepo(),
- limitErr: service.ErrSoraGenerationConcurrencyLimit,
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
-
- body := `{"model":"sora2-landscape-10s","prompt":"test"}`
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "3")
-}
-
-// ==================== SaveToStorage: 配额超限 ====================
-
-func TestSaveToStorage_QuotaExceeded(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户配额已满
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10,
- SoraStorageUsedBytes: 10,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
-
-func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== SaveToStorage: MediaURLs 全为空 ====================
-
-func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: "",
- MediaURLs: []string{},
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "已过期")
-}
-
-// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
-
-func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail-second")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v1.mp4",
- MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
-
-func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- repo.updateErr = fmt.Errorf("db error")
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 100 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
-
-func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
-}
-
-func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
-}
-
-func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
-}
-
-// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
-
-func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "nonexistent/video.mp4",
- MediaURLs: []string{"nonexistent/video.mp4"},
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== CancelGeneration: 任务已结束冲突 ====================
-
-func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
deleted file mode 100644
index d1e7e00fe5..0000000000
--- a/backend/internal/handler/sora_gateway_handler.go
+++ /dev/null
@@ -1,697 +0,0 @@
-package handler
-
-import (
- "context"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "os"
- "path"
- "path/filepath"
- "strconv"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
- "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
- "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
-
- "github.com/gin-gonic/gin"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "go.uber.org/zap"
-)
-
-// SoraGatewayHandler handles Sora chat completions requests
-//
-// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
-type SoraGatewayHandler struct {
- gatewayService *service.GatewayService
- soraGatewayService *service.SoraGatewayService
- billingCacheService *service.BillingCacheService
- usageRecordWorkerPool *service.UsageRecordWorkerPool
- concurrencyHelper *ConcurrencyHelper
- maxAccountSwitches int
- streamMode string
- soraTLSEnabled bool
- soraMediaSigningKey string
- soraMediaRoot string
-}
-
-// NewSoraGatewayHandler creates a new SoraGatewayHandler
-func NewSoraGatewayHandler(
- gatewayService *service.GatewayService,
- soraGatewayService *service.SoraGatewayService,
- concurrencyService *service.ConcurrencyService,
- billingCacheService *service.BillingCacheService,
- usageRecordWorkerPool *service.UsageRecordWorkerPool,
- cfg *config.Config,
-) *SoraGatewayHandler {
- pingInterval := time.Duration(0)
- maxAccountSwitches := 3
- streamMode := "force"
- soraTLSEnabled := true
- signKey := ""
- mediaRoot := "/app/data/sora"
- if cfg != nil {
- pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
- if cfg.Gateway.MaxAccountSwitches > 0 {
- maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
- }
- if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
- streamMode = mode
- }
- soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
- signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
- if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
- mediaRoot = root
- }
- }
- return &SoraGatewayHandler{
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- billingCacheService: billingCacheService,
- usageRecordWorkerPool: usageRecordWorkerPool,
- concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
- maxAccountSwitches: maxAccountSwitches,
- streamMode: strings.ToLower(streamMode),
- soraTLSEnabled: soraTLSEnabled,
- soraMediaSigningKey: signKey,
- soraMediaRoot: mediaRoot,
- }
-}
-
-// ChatCompletions handles Sora /v1/chat/completions endpoint
-func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
- apiKey, ok := middleware2.GetAPIKeyFromContext(c)
- if !ok {
- h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
- return
- }
-
- subject, ok := middleware2.GetAuthSubjectFromContext(c)
- if !ok {
- h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
- return
- }
- reqLog := requestLogger(
- c,
- "handler.sora_gateway.chat_completions",
- zap.Int64("user_id", subject.UserID),
- zap.Int64("api_key_id", apiKey.ID),
- zap.Any("group_id", apiKey.GroupID),
- )
-
- body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
- if err != nil {
- if maxErr, ok := extractMaxBytesError(err); ok {
- h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
- return
- }
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
- return
- }
- if len(body) == 0 {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
- return
- }
-
- setOpsRequestContext(c, "", false, body)
-
- // 校验请求体 JSON 合法性
- if !gjson.ValidBytes(body) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
- return
- }
-
- // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
- modelResult := gjson.GetBytes(body, "model")
- if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
- return
- }
- reqModel := modelResult.String()
-
- msgsResult := gjson.GetBytes(body, "messages")
- if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
- return
- }
-
- clientStream := gjson.GetBytes(body, "stream").Bool()
- reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
- if !clientStream {
- if h.streamMode == "error" {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
- return
- }
- var err error
- body, err = sjson.SetBytes(body, "stream", true)
- if err != nil {
- h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
- return
- }
- }
-
- setOpsRequestContext(c, reqModel, clientStream, body)
- setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
-
- platform := ""
- if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
- platform = forced
- } else if apiKey.Group != nil {
- platform = apiKey.Group.Platform
- }
- if platform != service.PlatformSora {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
- return
- }
-
- streamStarted := false
- subscription, _ := middleware2.GetSubscriptionFromContext(c)
-
- maxWait := service.CalculateMaxWait(subject.Concurrency)
- canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
- waitCounted := false
- if err != nil {
- reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
- } else if !canWait {
- reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
- h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
- return
- }
- if err == nil && canWait {
- waitCounted = true
- }
- defer func() {
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- }
- }()
-
- userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
- if err != nil {
- reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
- h.handleConcurrencyError(c, err, "user", streamStarted)
- return
- }
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- waitCounted = false
- }
- userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
- if userReleaseFunc != nil {
- defer userReleaseFunc()
- }
-
- if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
- h.handleStreamingAwareError(c, status, code, message, streamStarted)
- return
- }
-
- sessionHash := generateOpenAISessionHash(c, body)
-
- maxAccountSwitches := h.maxAccountSwitches
- switchCount := 0
- failedAccountIDs := make(map[int64]struct{})
- lastFailoverStatus := 0
- var lastFailoverBody []byte
- var lastFailoverHeaders http.Header
-
- for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0))
- if err != nil {
- reqLog.Warn("sora.account_select_failed",
- zap.Error(err),
- zap.Int("excluded_account_count", len(failedAccountIDs)),
- )
- if len(failedAccountIDs) == 0 {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
- return
- }
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int("last_upstream_status", lastFailoverStatus),
- }
- if rayID != "" {
- fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("last_upstream_content_type", contentType))
- }
- reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
- h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
- return
- }
- account := selection.Account
- setOpsSelectedAccount(c, account.ID, account.Platform)
- proxyBound := account.ProxyID != nil
- proxyID := int64(0)
- if account.ProxyID != nil {
- proxyID = *account.ProxyID
- }
- tlsFingerprintEnabled := h.soraTLSEnabled
-
- accountReleaseFunc := selection.ReleaseFunc
- if !selection.Acquired {
- if selection.WaitPlan == nil {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
- return
- }
- accountWaitCounted := false
- canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
- if err != nil {
- reqLog.Warn("sora.account_wait_counter_increment_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- } else if !canWait {
- reqLog.Info("sora.account_wait_queue_full",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
- )
- h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
- return
- }
- if err == nil && canWait {
- accountWaitCounted = true
- }
- defer func() {
- if accountWaitCounted {
- h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
- }
- }()
-
- accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
- c,
- account.ID,
- selection.WaitPlan.MaxConcurrency,
- selection.WaitPlan.Timeout,
- clientStream,
- &streamStarted,
- )
- if err != nil {
- reqLog.Warn("sora.account_slot_acquire_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- h.handleConcurrencyError(c, err, "account", streamStarted)
- return
- }
- if accountWaitCounted {
- h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
- accountWaitCounted = false
- }
- }
- accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
-
- result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
- if accountReleaseFunc != nil {
- accountReleaseFunc()
- }
- if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- failedAccountIDs[account.ID] = struct{}{}
- if switchCount >= maxAccountSwitches {
- lastFailoverStatus = failoverErr.StatusCode
- lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
- lastFailoverBody = failoverErr.ResponseBody
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- }
- if rayID != "" {
- fields = append(fields, zap.String("upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("upstream_content_type", contentType))
- }
- reqLog.Warn("sora.upstream_failover_exhausted", fields...)
- h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
- return
- }
- lastFailoverStatus = failoverErr.StatusCode
- lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
- lastFailoverBody = failoverErr.ResponseBody
- switchCount++
- upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.String("upstream_error_code", upstreamErrCode),
- zap.String("upstream_error_message", upstreamErrMsg),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- }
- if rayID != "" {
- fields = append(fields, zap.String("upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("upstream_content_type", contentType))
- }
- reqLog.Warn("sora.upstream_failover_switching", fields...)
- continue
- }
- reqLog.Error("sora.forward_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- return
- }
-
- userAgent := c.GetHeader("User-Agent")
- clientIP := ip.GetClientIP(c)
- requestPayloadHash := service.HashUsageRequestPayload(body)
- inboundEndpoint := GetInboundEndpoint(c)
- upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
-
- // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
- h.submitUsageRecordTask(func(ctx context.Context) {
- if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- InboundEndpoint: inboundEndpoint,
- UpstreamEndpoint: upstreamEndpoint,
- UserAgent: userAgent,
- IPAddress: clientIP,
- RequestPayloadHash: requestPayloadHash,
- }); err != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Int64("user_id", subject.UserID),
- zap.Int64("api_key_id", apiKey.ID),
- zap.Any("group_id", apiKey.GroupID),
- zap.String("model", reqModel),
- zap.Int64("account_id", account.ID),
- ).Error("sora.record_usage_failed", zap.Error(err))
- }
- })
- reqLog.Debug("sora.request_completed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("switch_count", switchCount),
- )
- return
- }
-}
-
-func generateOpenAISessionHash(c *gin.Context, body []byte) string {
- if c == nil {
- return ""
- }
- sessionID := strings.TrimSpace(c.GetHeader("session_id"))
- if sessionID == "" {
- sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
- }
- if sessionID == "" && len(body) > 0 {
- sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
- }
- if sessionID == "" {
- return ""
- }
- hash := sha256.Sum256([]byte(sessionID))
- return hex.EncodeToString(hash[:])
-}
-
-func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Any("panic", recovered),
- ).Error("sora.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
-}
-
-func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
- h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
- fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
-}
-
-func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
- upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
- service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
-
- status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
- h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
-}
-
-func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
- if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
- baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
- return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
- }
-
- upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
- if strings.EqualFold(upstreamCode, "cf_shield_429") {
- baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
- return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
- }
- if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
- switch statusCode {
- case 401, 403, 404, 500, 502, 503, 504:
- return http.StatusBadGateway, "upstream_error", upstreamMessage
- case 429:
- return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
- }
- }
-
- switch statusCode {
- case 401:
- return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
- case 403:
- return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
- case 404:
- if strings.EqualFold(upstreamCode, "unsupported_country_code") {
- return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
- }
- return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
- case 429:
- return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
- case 529:
- return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
- case 500, 502, 503, 504:
- return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
- default:
- return http.StatusBadGateway, "upstream_error", "Upstream request failed"
- }
-}
-
-func cloneHTTPHeaders(headers http.Header) http.Header {
- if headers == nil {
- return nil
- }
- return headers.Clone()
-}
-
-func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
- if headers != nil {
- mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
- contentType = strings.TrimSpace(headers.Get("content-type"))
- if contentType == "" {
- contentType = strings.TrimSpace(headers.Get("Content-Type"))
- }
- }
- rayID = soraerror.ExtractCloudflareRayID(headers, body)
- return rayID, mitigated, contentType
-}
-
-func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
- return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
-}
-
-func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
- message = strings.TrimSpace(message)
- if message == "" {
- return false
- }
- if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
- lower := strings.ToLower(message)
- if strings.Contains(lower, "
Just a moment...`)
-
- h := &SoraGatewayHandler{}
- h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
-
- lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
- require.Len(t, lines, 2)
- jsonStr := strings.TrimPrefix(lines[1], "data: ")
-
- var parsed map[string]any
- require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
-
- errorObj, ok := parsed["error"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "upstream_error", errorObj["type"])
- msg, _ := errorObj["message"].(string)
- require.Contains(t, msg, "Cloudflare challenge")
- require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
-}
-
-func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
- gin.SetMode(gin.TestMode)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
-
- headers := http.Header{}
- headers.Set("cf-ray", "9d03b68c086027a1-SEA")
- body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
-
- h := &SoraGatewayHandler{}
- h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
-
- lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
- require.Len(t, lines, 2)
- jsonStr := strings.TrimPrefix(lines[1], "data: ")
-
- var parsed map[string]any
- require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
-
- errorObj, ok := parsed["error"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "rate_limit_error", errorObj["type"])
- msg, _ := errorObj["message"].(string)
- require.Contains(t, msg, "Cloudflare shield")
- require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
-}
-
-func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
- headers := http.Header{}
- headers.Set("cf-mitigated", "challenge")
- headers.Set("content-type", "text/html")
- body := []byte(``)
-
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
- require.Equal(t, "9cff2d62d83bb98d", rayID)
- require.Equal(t, "challenge", mitigated)
- require.Equal(t, "text/html", contentType)
-}
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index c7c48e14bc..5c9458158a 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
- pool := newUsageRecordTestPool(t)
- h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
-
- done := make(chan struct{})
- h.submitUsageRecordTask(func(ctx context.Context) {
- close(done)
- })
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("task not executed")
- }
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
- h := &SoraGatewayHandler{}
- var called atomic.Bool
-
- h.submitUsageRecordTask(func(ctx context.Context) {
- if _, ok := ctx.Deadline(); !ok {
- t.Fatal("expected deadline in fallback context")
- }
- called.Store(true)
- })
-
- require.True(t, called.Load())
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
- h := &SoraGatewayHandler{}
- require.NotPanics(t, func() {
- h.submitUsageRecordTask(nil)
- })
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
- h := &SoraGatewayHandler{}
- var called atomic.Bool
-
- require.NotPanics(t, func() {
- h.submitUsageRecordTask(func(ctx context.Context) {
- panic("usage task panic")
- })
- })
-
- h.submitUsageRecordTask(func(ctx context.Context) {
- called.Store(true)
- })
- require.True(t, called.Load(), "panic 后后续任务应仍可执行")
-}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index c917f24a0d..d9622594f2 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -86,8 +86,6 @@ func ProvideHandlers(
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
- soraGatewayHandler *SoraGatewayHandler,
- soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
@@ -104,8 +102,6 @@ func ProvideHandlers(
Admin: adminHandlers,
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
- SoraGateway: soraGatewayHandler,
- SoraClient: soraClientHandler,
Setting: settingHandler,
Totp: totpHandler,
}
@@ -123,7 +119,6 @@ var ProviderSet = wire.NewSet(
NewAnnouncementHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
- NewSoraGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index 6b8521bdcd..618b6adb60 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -17,8 +17,6 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
- // OAuth Client ID for Sora mobile flow (aligned with sora2api)
- SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
@@ -39,8 +37,6 @@ const (
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
- // OAuthPlatformSora uses Sora OAuth client.
- OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
-// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
-// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
- switch strings.ToLower(strings.TrimSpace(platform)) {
- case OAuthPlatformSora:
- return ClientID, false
- default:
- return ClientID, true
- }
+ return ClientID, true
}
// TokenRequest represents the token exchange request body
diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go
index 2970addff5..56b42fc9fb 100644
--- a/backend/internal/pkg/openai/oauth_test.go
+++ b/backend/internal/pkg/openai/oauth_test.go
@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
-
-// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
-// 但不启用 codex_cli_simplified_flow。
-func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
- authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
- parsed, err := url.Parse(authURL)
- if err != nil {
- t.Fatalf("Parse URL failed: %v", err)
- }
- q := parsed.Query()
- if got := q.Get("client_id"); got != ClientID {
- t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
- }
- if got := q.Get("codex_cli_simplified_flow"); got != "" {
- t.Fatalf("codex flow should be empty for sora, got=%q", got)
- }
- if got := q.Get("id_token_add_organizations"); got != "true" {
- t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
- }
-}
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index d45e8a1297..94bfb09d58 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -1692,20 +1692,13 @@ func itoa(v int) string {
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
-// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
-// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
-//
// FindByExtraField finds accounts by key-value pairs in the extra field.
-// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
-//
-// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
- dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index ade0d46486..b3b12e8113 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
- group.FieldSoraImagePrice360,
- group.FieldSoraImagePrice540,
- group.FieldSoraVideoPricePerRequest,
- group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
@@ -608,22 +604,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
- SoraStorageUsedBytes: u.SoraStorageUsedBytes,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
}
@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
- SoraImagePrice360: g.SoraImagePrice360,
- SoraImagePrice540: g.SoraImagePrice540,
- SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
- SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 3cfd649bce..a075b586c4 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
- SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
- SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
- SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
- SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
- SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
- SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
- SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
- SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
- SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
- SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index 44fa291bed..c1901d71aa 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
-// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
-func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
- var seenClientIDs []string
- s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if err := r.ParseForm(); err != nil {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- clientID := r.PostForm.Get("client_id")
- seenClientIDs = append(seenClientIDs, clientID)
- if clientID == openai.SoraClientID {
- w.Header().Set("Content-Type", "application/json")
- _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
-
- resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
- require.NoError(s.T(), err, "RefreshTokenWithClientID")
- require.Equal(s.T(), "at-sora", resp.AccessToken)
- require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
-}
-
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
- wantClientID := openai.SoraClientID
+ wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go
deleted file mode 100644
index ad2ae638fe..0000000000
--- a/backend/internal/repository/sora_account_repo.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package repository
-
-import (
- "context"
- "database/sql"
- "errors"
-
- "github.com/Wei-Shaw/sub2api/internal/service"
-)
-
-// soraAccountRepository 实现 service.SoraAccountRepository 接口。
-// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
-//
-// 设计说明:
-// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
-// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
-// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
-type soraAccountRepository struct {
- sql *sql.DB
-}
-
-// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
-func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
- return &soraAccountRepository{sql: sqlDB}
-}
-
-// Upsert 创建或更新 Sora 账号扩展信息
-// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
-func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
- accessToken, accessOK := updates["access_token"].(string)
- refreshToken, refreshOK := updates["refresh_token"].(string)
- sessionToken, sessionOK := updates["session_token"].(string)
-
- if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
- if !sessionOK {
- return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
- }
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_accounts
- SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
- updated_at = NOW()
- WHERE account_id = $1
- `, accountID, sessionToken)
- if err != nil {
- return err
- }
- rows, err := result.RowsAffected()
- if err != nil {
- return err
- }
- if rows == 0 {
- return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
- }
- return nil
- }
-
- _, err := r.sql.ExecContext(ctx, `
- INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
- VALUES ($1, $2, $3, $4, NOW(), NOW())
- ON CONFLICT (account_id) DO UPDATE SET
- access_token = EXCLUDED.access_token,
- refresh_token = EXCLUDED.refresh_token,
- session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
- updated_at = NOW()
- `, accountID, accessToken, refreshToken, sessionToken)
- return err
-}
-
-// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
-func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
- rows, err := r.sql.QueryContext(ctx, `
- SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
- FROM sora_accounts
- WHERE account_id = $1
- `, accountID)
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- if !rows.Next() {
- return nil, nil // 记录不存在
- }
-
- var sa service.SoraAccount
- if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
- return nil, err
- }
- return &sa, nil
-}
-
-// Delete 删除 Sora 账号扩展信息
-func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
- _, err := r.sql.ExecContext(ctx, `
- DELETE FROM sora_accounts WHERE account_id = $1
- `, accountID)
- return err
-}
diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go
deleted file mode 100644
index aaf3cb2f54..0000000000
--- a/backend/internal/repository/sora_generation_repo.go
+++ /dev/null
@@ -1,419 +0,0 @@
-package repository
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/service"
-)
-
-// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
-// 使用原生 SQL 操作 sora_generations 表。
-type soraGenerationRepository struct {
- sql *sql.DB
-}
-
-// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
-func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
- return &soraGenerationRepository{sql: sqlDB}
-}
-
-func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
-
- err := r.sql.QueryRowContext(ctx, `
- INSERT INTO sora_generations (
- user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
- RETURNING id, created_at
- `,
- gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
- gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
- ).Scan(&gen.ID, &gen.CreatedAt)
- return err
-}
-
-// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
-func (r *soraGenerationRepository) CreatePendingWithLimit(
- ctx context.Context,
- gen *service.SoraGeneration,
- activeStatuses []string,
- maxActive int64,
-) error {
- if gen == nil {
- return fmt.Errorf("generation is nil")
- }
- if maxActive <= 0 {
- return r.Create(ctx, gen)
- }
- if len(activeStatuses) == 0 {
- activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
- }
-
- tx, err := r.sql.BeginTx(ctx, nil)
- if err != nil {
- return err
- }
- defer func() { _ = tx.Rollback() }()
-
- // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
- if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
- return err
- }
-
- placeholders := make([]string, len(activeStatuses))
- args := make([]any, 0, 1+len(activeStatuses))
- args = append(args, gen.UserID)
- for i, s := range activeStatuses {
- placeholders[i] = fmt.Sprintf("$%d", i+2)
- args = append(args, s)
- }
- countQuery := fmt.Sprintf(
- `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
- strings.Join(placeholders, ","),
- )
- var activeCount int64
- if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
- return err
- }
- if activeCount >= maxActive {
- return service.ErrSoraGenerationConcurrencyLimit
- }
-
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
- if err := tx.QueryRowContext(ctx, `
- INSERT INTO sora_generations (
- user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
- RETURNING id, created_at
- `,
- gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
- gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
- ).Scan(&gen.ID, &gen.CreatedAt); err != nil {
- return err
- }
-
- return tx.Commit()
-}
-
-func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
- gen := &service.SoraGeneration{}
- var mediaURLsJSON, s3KeysJSON []byte
- var completedAt sql.NullTime
- var apiKeyID sql.NullInt64
-
- err := r.sql.QueryRowContext(ctx, `
- SELECT id, user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message,
- created_at, completed_at
- FROM sora_generations WHERE id = $1
- `, id).Scan(
- &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
- &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
- &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
- &gen.CreatedAt, &completedAt,
- )
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, fmt.Errorf("生成记录不存在")
- }
- return nil, err
- }
-
- if apiKeyID.Valid {
- gen.APIKeyID = &apiKeyID.Int64
- }
- if completedAt.Valid {
- gen.CompletedAt = &completedAt.Time
- }
- _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
- _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
- return gen, nil
-}
-
-func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
-
- var completedAt *time.Time
- if gen.CompletedAt != nil {
- completedAt = gen.CompletedAt
- }
-
- _, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations SET
- status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
- storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
- error_message = $9, completed_at = $10
- WHERE id = $1
- `,
- gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
- gen.ErrorMessage, completedAt,
- )
- return err
-}
-
-// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
-func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2, upstream_task_id = $3
- WHERE id = $1 AND status = $4
- `,
- id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
-func (r *soraGenerationRepository) UpdateCompletedIfActive(
- ctx context.Context,
- id int64,
- mediaURL string,
- mediaURLs []string,
- storageType string,
- s3Keys []string,
- fileSizeBytes int64,
- completedAt time.Time,
-) (bool, error) {
- mediaURLsJSON, _ := json.Marshal(mediaURLs)
- s3KeysJSON, _ := json.Marshal(s3Keys)
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2,
- media_url = $3,
- media_urls = $4,
- file_size_bytes = $5,
- storage_type = $6,
- s3_object_keys = $7,
- error_message = '',
- completed_at = $8
- WHERE id = $1 AND status IN ($9, $10)
- `,
- id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
- storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
-func (r *soraGenerationRepository) UpdateFailedIfActive(
- ctx context.Context,
- id int64,
- errMsg string,
- completedAt time.Time,
-) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2,
- error_message = $3,
- completed_at = $4
- WHERE id = $1 AND status IN ($5, $6)
- `,
- id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
-func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2, completed_at = $3
- WHERE id = $1 AND status IN ($4, $5)
- `,
- id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
-func (r *soraGenerationRepository) UpdateStorageIfCompleted(
- ctx context.Context,
- id int64,
- mediaURL string,
- mediaURLs []string,
- storageType string,
- s3Keys []string,
- fileSizeBytes int64,
-) (bool, error) {
- mediaURLsJSON, _ := json.Marshal(mediaURLs)
- s3KeysJSON, _ := json.Marshal(s3Keys)
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET media_url = $2,
- media_urls = $3,
- file_size_bytes = $4,
- storage_type = $5,
- s3_object_keys = $6
- WHERE id = $1 AND status = $7
- `,
- id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
- _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
- return err
-}
-
-func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
- // 构建 WHERE 条件
- conditions := []string{"user_id = $1"}
- args := []any{params.UserID}
- argIdx := 2
-
- if params.Status != "" {
- // 支持逗号分隔的多状态
- statuses := strings.Split(params.Status, ",")
- placeholders := make([]string, len(statuses))
- for i, s := range statuses {
- placeholders[i] = fmt.Sprintf("$%d", argIdx)
- args = append(args, strings.TrimSpace(s))
- argIdx++
- }
- conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
- }
- if params.StorageType != "" {
- storageTypes := strings.Split(params.StorageType, ",")
- placeholders := make([]string, len(storageTypes))
- for i, s := range storageTypes {
- placeholders[i] = fmt.Sprintf("$%d", argIdx)
- args = append(args, strings.TrimSpace(s))
- argIdx++
- }
- conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
- }
- if params.MediaType != "" {
- conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
- args = append(args, params.MediaType)
- argIdx++
- }
-
- whereClause := "WHERE " + strings.Join(conditions, " AND ")
-
- // 计数
- var total int64
- countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
- if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
- return nil, 0, err
- }
-
- // 分页查询
- offset := (params.Page - 1) * params.PageSize
- listQuery := fmt.Sprintf(`
- SELECT id, user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message,
- created_at, completed_at
- FROM sora_generations %s
- ORDER BY created_at DESC
- LIMIT $%d OFFSET $%d
- `, whereClause, argIdx, argIdx+1)
- args = append(args, params.PageSize, offset)
-
- rows, err := r.sql.QueryContext(ctx, listQuery, args...)
- if err != nil {
- return nil, 0, err
- }
- defer func() {
- _ = rows.Close()
- }()
-
- var results []*service.SoraGeneration
- for rows.Next() {
- gen := &service.SoraGeneration{}
- var mediaURLsJSON, s3KeysJSON []byte
- var completedAt sql.NullTime
- var apiKeyID sql.NullInt64
-
- if err := rows.Scan(
- &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
- &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
- &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
- &gen.CreatedAt, &completedAt,
- ); err != nil {
- return nil, 0, err
- }
-
- if apiKeyID.Valid {
- gen.APIKeyID = &apiKeyID.Int64
- }
- if completedAt.Valid {
- gen.CompletedAt = &completedAt.Time
- }
- _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
- _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
- results = append(results, gen)
- }
-
- return results, total, rows.Err()
-}
-
-func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
- if len(statuses) == 0 {
- return 0, nil
- }
-
- placeholders := make([]string, len(statuses))
- args := []any{userID}
- for i, s := range statuses {
- placeholders[i] = fmt.Sprintf("$%d", i+2)
- args = append(args, s)
- }
-
- var count int64
- query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
- err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
- return count, err
-}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 66d0b4ec19..d7bcd0944d 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -73,7 +73,6 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address
"integer", // image_count
"text", // image_size
- "text", // media_type
"text", // service_tier
"text", // reasoning_effort
"text", // inbound_endpoint
@@ -352,7 +351,6 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -369,7 +367,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -790,7 +788,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -803,7 +800,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
- args := make([]any, 0, len(keys)*47)
+ args := make([]any, 0, len(keys)*46)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -867,7 +864,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -915,7 +911,6 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1003,7 +998,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1016,7 +1010,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
- args := make([]any, 0, len(preparedList)*46)
+ args := make([]any, 0, len(preparedList)*45)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1077,7 +1071,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1125,7 +1118,6 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1181,7 +1173,6 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1198,7 +1189,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1225,7 +1216,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
- mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
@@ -1286,7 +1276,6 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress,
log.ImageCount,
imageSize,
- mediaType,
serviceTier,
reasoningEffort,
inboundEndpoint,
@@ -4051,7 +4040,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
- mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
@@ -4101,7 +4089,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
- &mediaType,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
@@ -4179,9 +4166,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
- if mediaType.Valid {
- log.MediaType = &mediaType.String
- }
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
index 77f695e378..ce0c5f0040 100644
--- a/backend/internal/repository/usage_log_repo_request_type_test.go
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -76,7 +76,6 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
- sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
@@ -155,7 +154,6 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
- sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
@@ -471,7 +469,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
@@ -519,7 +516,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
@@ -567,7 +563,6 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 575754e03b..06c79113e2 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
- SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
-// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
-func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
- if deltaBytes <= 0 {
- user, err := r.GetByID(ctx, userID)
- if err != nil {
- return 0, err
- }
- return user.SoraStorageUsedBytes, nil
- }
- var newUsed int64
- err := scanSingleRow(ctx, r.sql, `
- UPDATE users
- SET sora_storage_used_bytes = sora_storage_used_bytes + $2
- WHERE id = $1
- AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
- RETURNING sora_storage_used_bytes
- `, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
- if err == nil {
- return newUsed, nil
- }
- if errors.Is(err, sql.ErrNoRows) {
- // 区分用户不存在和配额冲突
- exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
- if existsErr != nil {
- return 0, existsErr
- }
- if !exists {
- return 0, service.ErrUserNotFound
- }
- return 0, service.ErrSoraStorageQuotaExceeded
- }
- return 0, err
-}
-
-// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
-func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
- if deltaBytes <= 0 {
- user, err := r.GetByID(ctx, userID)
- if err != nil {
- return 0, err
- }
- return user.SoraStorageUsedBytes, nil
- }
- var newUsed int64
- err := scanSingleRow(ctx, r.sql, `
- UPDATE users
- SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
- WHERE id = $1
- RETURNING sora_storage_used_bytes
- `, []any{userID, deltaBytes}, &newUsed)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return 0, service.ErrUserNotFound
- }
- return 0, err
- }
- return newUsed, nil
-}
-
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 4548c02882..657e3ed66c 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
- NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 450c312265..d412ea34d6 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -204,11 +204,6 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
- "sora_image_price_360": null,
- "sora_image_price_540": null,
- "sora_storage_quota_bytes": 0,
- "sora_video_price_per_request": null,
- "sora_video_price_per_request_hd": null,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
@@ -532,7 +527,6 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
- "sora_client_enabled": false,
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false,
@@ -653,11 +647,11 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
- adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
+ adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index d9ec951e77..73210bfcbb 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -94,7 +94,6 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
- strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/responses")
}
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index 997015317f..d60a142c87 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -109,7 +109,6 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
- routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 76f4c4b4d6..b921da9511 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
- // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
- registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
@@ -321,19 +319,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
-func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
- sora := admin.Group("/sora")
- {
- sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
- sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
- sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
- sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
- sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
- sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
- sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
- }
-}
-
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
@@ -422,15 +407,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
- // Sora S3 存储配置
- adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
- adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
- adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
- adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
- adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
- adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
- adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
- adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
}
}
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 072cfdee37..cbf9829326 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
- soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
- if soraMaxBodySize <= 0 {
- soraMaxBodySize = cfg.Gateway.MaxBodySize
- }
- soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
@@ -163,28 +158,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
- // Sora 专用路由(强制使用 sora 平台)
- soraV1 := r.Group("/sora/v1")
- soraV1.Use(soraBodyLimit)
- soraV1.Use(clientRequestID)
- soraV1.Use(opsErrorLogger)
- soraV1.Use(endpointNorm)
- soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
- soraV1.Use(gin.HandlerFunc(apiKeyAuth))
- soraV1.Use(requireGroupAnthropic)
- {
- soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
- soraV1.GET("/models", h.Gateway.Models)
- }
-
- // Sora 媒体代理(可选 API Key 验证)
- if cfg.Gateway.SoraMediaRequireAPIKey {
- r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
- } else {
- r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
- }
- // Sora 媒体代理(签名 URL,无需 API Key)
- r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go
index 00edd31b84..4d65a626f9 100644
--- a/backend/internal/server/routes/gateway_test.go
+++ b/backend/internal/server/routes/gateway_test.go
@@ -22,7 +22,6 @@ func newGatewayRoutesTestRouter() *gin.Engine {
&handler.Handlers{
Gateway: &handler.GatewayHandler{},
OpenAIGateway: &handler.OpenAIGatewayHandler{},
- SoraGateway: &handler.SoraGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
c.Next()
diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go
deleted file mode 100644
index 13fceb812c..0000000000
--- a/backend/internal/server/routes/sora_client.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package routes
-
-import (
- "github.com/Wei-Shaw/sub2api/internal/handler"
- "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
-
- "github.com/gin-gonic/gin"
-)
-
-// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
-func RegisterSoraClientRoutes(
- v1 *gin.RouterGroup,
- h *handler.Handlers,
- jwtAuth middleware.JWTAuthMiddleware,
- settingService *service.SettingService,
-) {
- if h.SoraClient == nil {
- return
- }
-
- authenticated := v1.Group("/sora")
- authenticated.Use(gin.HandlerFunc(jwtAuth))
- authenticated.Use(middleware.BackendModeUserGuard(settingService))
- {
- authenticated.POST("/generate", h.SoraClient.Generate)
- authenticated.GET("/generations", h.SoraClient.ListGenerations)
- authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
- authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
- authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
- authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
- authenticated.GET("/quota", h.SoraClient.GetQuota)
- authenticated.GET("/models", h.SoraClient.GetModels)
- authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
- }
-}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 328790a87f..3189a7290f 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -28,8 +28,7 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
- // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
- // 用于查找通过 linked_openai_account_id 关联的 Sora 账号
+ // FindByExtraField 根据 extra 字段中的键值对查找账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 8218c2db0e..55865945c6 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -13,18 +13,14 @@ import (
"log"
"net/http"
"net/http/httptest"
- "net/url"
"regexp"
"strings"
- "sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
- soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
- soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
- soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
- soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
- soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
@@ -71,13 +62,8 @@ type AccountTestService struct {
httpUpstream HTTPUpstream
cfg *config.Config
tlsFPProfileService *TLSFingerprintProfileService
- soraTestGuardMu sync.Mutex
- soraTestLastRun map[int64]time.Time
- soraTestCooldown time.Duration
}
-const defaultSoraTestCooldown = 10 * time.Second
-
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -94,8 +80,6 @@ func NewAccountTestService(
httpUpstream: httpUpstream,
cfg: cfg,
tlsFPProfileService: tlsFPProfileService,
- soraTestLastRun: make(map[int64]time.Time),
- soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -197,10 +181,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
- if account.Platform == PlatformSora {
- return s.testSoraAccountConnection(c, account)
- }
-
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -634,698 +614,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
-type soraProbeStep struct {
- Name string `json:"name"`
- Status string `json:"status"`
- HTTPStatus int `json:"http_status,omitempty"`
- ErrorCode string `json:"error_code,omitempty"`
- Message string `json:"message,omitempty"`
-}
-
-type soraProbeSummary struct {
- Status string `json:"status"`
- Steps []soraProbeStep `json:"steps"`
-}
-
-type soraProbeRecorder struct {
- steps []soraProbeStep
-}
-
-func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
- r.steps = append(r.steps, soraProbeStep{
- Name: name,
- Status: status,
- HTTPStatus: httpStatus,
- ErrorCode: strings.TrimSpace(errorCode),
- Message: strings.TrimSpace(message),
- })
-}
-
-func (r *soraProbeRecorder) finalize() soraProbeSummary {
- meSuccess := false
- partial := false
- for _, step := range r.steps {
- if step.Name == "me" {
- meSuccess = strings.EqualFold(step.Status, "success")
- continue
- }
- if strings.EqualFold(step.Status, "failed") {
- partial = true
- }
- }
-
- status := "success"
- if !meSuccess {
- status = "failed"
- } else if partial {
- status = "partial_success"
- }
-
- return soraProbeSummary{
- Status: status,
- Steps: append([]soraProbeStep(nil), r.steps...),
- }
-}
-
-func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
- if rec == nil {
- return
- }
- summary := rec.finalize()
- code := ""
- for _, step := range summary.Steps {
- if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
- code = step.ErrorCode
- break
- }
- }
- s.sendEvent(c, TestEvent{
- Type: "sora_test_result",
- Status: summary.Status,
- Code: code,
- Data: summary,
- })
-}
-
-func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
- if accountID <= 0 {
- return 0, true
- }
- s.soraTestGuardMu.Lock()
- defer s.soraTestGuardMu.Unlock()
-
- if s.soraTestLastRun == nil {
- s.soraTestLastRun = make(map[int64]time.Time)
- }
- cooldown := s.soraTestCooldown
- if cooldown <= 0 {
- cooldown = defaultSoraTestCooldown
- }
-
- now := time.Now()
- if lastRun, ok := s.soraTestLastRun[accountID]; ok {
- elapsed := now.Sub(lastRun)
- if elapsed < cooldown {
- return cooldown - elapsed, false
- }
- }
- s.soraTestLastRun[accountID] = now
- return 0, true
-}
-
-func ceilSeconds(d time.Duration) int {
- if d <= 0 {
- return 1
- }
- sec := int(d / time.Second)
- if d%time.Second != 0 {
- sec++
- }
- if sec < 1 {
- sec = 1
- }
- return sec
-}
-
-// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
-// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
-func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
- ctx := c.Request.Context()
-
- apiKey := account.GetCredential("api_key")
- if apiKey == "" {
- return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
- }
-
- baseURL := account.GetBaseURL()
- if baseURL == "" {
- return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
- }
-
- // 验证 base_url 格式
- normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
- }
- upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
-
- // 设置 SSE 头
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.Flush()
-
- if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
- msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
- return s.sendErrorAndEnd(c, msg)
- }
-
- s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
-
- // 构建轻量级 prompt-enhance 请求作为连通性测试
- testPayload := map[string]any{
- "model": "prompt-enhance-short-10s",
- "messages": []map[string]string{{"role": "user", "content": "test"}},
- "stream": false,
- }
- payloadBytes, _ := json.Marshal(testPayload)
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
- if err != nil {
- return s.sendErrorAndEnd(c, "构建测试请求失败")
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+apiKey)
-
- // 获取代理 URL
- proxyURL := ""
- if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
- }
-
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
- }
- defer func() { _ = resp.Body.Close() }()
-
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
-
- if resp.StatusCode == http.StatusOK {
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
-
- if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
- }
-
- // 其他错误但能连通(如 400 参数错误)也算连通性测试通过
- if resp.StatusCode == http.StatusBadRequest {
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
-
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
-}
-
-// testSoraAccountConnection 测试 Sora 账号的连接
-// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
-// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
-func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
- // apikey 类型走独立测试流程
- if account.Type == AccountTypeAPIKey {
- return s.testSoraAPIKeyAccountConnection(c, account)
- }
-
- ctx := c.Request.Context()
- recorder := &soraProbeRecorder{}
-
- authToken := account.GetCredential("access_token")
- if authToken == "" {
- recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "No access token available")
- }
-
- // Set SSE headers
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.Flush()
-
- if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
- msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
- recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, msg)
- }
-
- // Send test_start event
- s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
-
- req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
- if err != nil {
- recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Failed to create request")
- }
-
- // 使用 Sora 客户端标准请求头
- req.Header.Set("Authorization", "Bearer "+authToken)
- req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Accept-Language", "en-US,en;q=0.9")
- req.Header.Set("Origin", "https://sora.chatgpt.com")
- req.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- // Get proxy URL
- proxyURL := ""
- if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
- }
- soraTLSProfile := s.resolveSoraTLSProfile()
-
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
- if err != nil {
- recorder.addStep("me", "failed", 0, "network_error", err.Error())
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, _ := io.ReadAll(resp.Body)
-
- if resp.StatusCode != http.StatusOK {
- if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
- recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
- s.emitSoraProbeSummary(c, recorder)
- s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
- return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
- switch {
- case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
- recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
- case strings.EqualFold(upstreamCode, "unsupported_country_code"):
- recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
- case strings.TrimSpace(upstreamMessage) != "":
- recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
- default:
- recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
- }
- }
- recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
-
- // 解析 /me 响应,提取用户信息
- var meResp map[string]any
- if err := json.Unmarshal(body, &meResp); err != nil {
- // 能收到 200 就说明 token 有效
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
- } else {
- // 尝试提取用户名或邮箱信息
- info := "Sora connection OK"
- if name, ok := meResp["name"].(string); ok && name != "" {
- info = fmt.Sprintf("Sora connection OK - User: %s", name)
- } else if email, ok := meResp["email"].(string); ok && email != "" {
- info = fmt.Sprintf("Sora connection OK - Email: %s", email)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: info})
- }
-
- // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
- subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
- if err == nil {
- subReq.Header.Set("Authorization", "Bearer "+authToken)
- subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- subReq.Header.Set("Accept", "application/json")
- subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
- subReq.Header.Set("Origin", "https://sora.chatgpt.com")
- subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
- if subErr != nil {
- recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
- } else {
- subBody, _ := io.ReadAll(subResp.Body)
- _ = subResp.Body.Close()
- if subResp.StatusCode == http.StatusOK {
- recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
- if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
- }
- } else {
- if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
- recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
- s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
- } else {
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
- recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
- }
- }
- }
- }
-
- // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
- s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
-
- s.emitSoraProbeSummary(c, recorder)
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
-}
-
-func (s *AccountTestService) testSora2Capabilities(
- c *gin.Context,
- ctx context.Context,
- account *Account,
- authToken string,
- proxyURL string,
- tlsProfile *tlsfingerprint.Profile,
- recorder *soraProbeRecorder,
-) {
- inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraInviteMineURL,
- proxyURL,
- tlsProfile,
- )
- if err != nil {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
- return
- }
-
- if inviteStatus == http.StatusUnauthorized {
- bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraBootstrapURL,
- proxyURL,
- tlsProfile,
- )
- if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
- if recorder != nil {
- recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
- inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraInviteMineURL,
- proxyURL,
- tlsProfile,
- )
- if err != nil {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
- return
- }
- } else if recorder != nil {
- code := ""
- msg := ""
- if bootstrapErr != nil {
- code = "network_error"
- msg = bootstrapErr.Error()
- }
- recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
- }
- }
-
- if inviteStatus != http.StatusOK {
- if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
- }
- s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
- return
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
- return
- }
- if recorder != nil {
- recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
- }
-
- if summary := parseSoraInviteSummary(inviteBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
- }
-
- remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraRemainingURL,
- proxyURL,
- tlsProfile,
- )
- if remainingErr != nil {
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
- return
- }
- if remainingStatus != http.StatusOK {
- if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
- }
- s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
- return
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
- return
- }
- if recorder != nil {
- recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
- }
- if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
- }
-}
-
-func (s *AccountTestService) fetchSoraTestEndpoint(
- ctx context.Context,
- account *Account,
- authToken string,
- url string,
- proxyURL string,
- tlsProfile *tlsfingerprint.Profile,
-) (int, http.Header, []byte, error) {
- req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- return 0, nil, nil, err
- }
- req.Header.Set("Authorization", "Bearer "+authToken)
- req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Accept-Language", "en-US,en;q=0.9")
- req.Header.Set("Origin", "https://sora.chatgpt.com")
- req.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
- if err != nil {
- return 0, nil, nil, err
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return resp.StatusCode, resp.Header, nil, readErr
- }
- return resp.StatusCode, resp.Header, body, nil
-}
-
-func parseSoraSubscriptionSummary(body []byte) string {
- var subResp struct {
- Data []struct {
- Plan struct {
- ID string `json:"id"`
- Title string `json:"title"`
- } `json:"plan"`
- EndTS string `json:"end_ts"`
- } `json:"data"`
- }
- if err := json.Unmarshal(body, &subResp); err != nil {
- return ""
- }
- if len(subResp.Data) == 0 {
- return ""
- }
-
- first := subResp.Data[0]
- parts := make([]string, 0, 3)
- if first.Plan.Title != "" {
- parts = append(parts, first.Plan.Title)
- }
- if first.Plan.ID != "" {
- parts = append(parts, first.Plan.ID)
- }
- if first.EndTS != "" {
- parts = append(parts, "end="+first.EndTS)
- }
- if len(parts) == 0 {
- return ""
- }
- return "Subscription: " + strings.Join(parts, " | ")
-}
-
-func parseSoraInviteSummary(body []byte) string {
- var inviteResp struct {
- InviteCode string `json:"invite_code"`
- RedeemedCount int64 `json:"redeemed_count"`
- TotalCount int64 `json:"total_count"`
- }
- if err := json.Unmarshal(body, &inviteResp); err != nil {
- return ""
- }
-
- parts := []string{"Sora2: supported"}
- if inviteResp.InviteCode != "" {
- parts = append(parts, "invite="+inviteResp.InviteCode)
- }
- if inviteResp.TotalCount > 0 {
- parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
- }
- return strings.Join(parts, " | ")
-}
-
-func parseSoraRemainingSummary(body []byte) string {
- var remainingResp struct {
- RateLimitAndCreditBalance struct {
- EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
- RateLimitReached bool `json:"rate_limit_reached"`
- AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
- } `json:"rate_limit_and_credit_balance"`
- }
- if err := json.Unmarshal(body, &remainingResp); err != nil {
- return ""
- }
- info := remainingResp.RateLimitAndCreditBalance
- parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
- if info.RateLimitReached {
- parts = append(parts, "rate_limited=true")
- }
- if info.AccessResetsInSeconds > 0 {
- parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
- }
- return strings.Join(parts, " | ")
-}
-
-func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
- if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
- // Sora TLS fingerprint enabled — use built-in default profile
- return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
- }
- return nil // disabled
-}
-
-func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
- return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
-}
-
-func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
- return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
-}
-
-func extractCloudflareRayID(headers http.Header, body []byte) string {
- return soraerror.ExtractCloudflareRayID(headers, body)
-}
-
-func extractSoraEgressIPHint(headers http.Header) string {
- if headers == nil {
- return "unknown"
- }
- candidates := []string{
- "x-openai-public-ip",
- "x-envoy-external-address",
- "cf-connecting-ip",
- "x-forwarded-for",
- }
- for _, key := range candidates {
- if value := strings.TrimSpace(headers.Get(key)); value != "" {
- return value
- }
- }
- return "unknown"
-}
-
-func sanitizeProxyURLForLog(raw string) string {
- raw = strings.TrimSpace(raw)
- if raw == "" {
- return ""
- }
- u, err := url.Parse(raw)
- if err != nil {
- return ""
- }
- if u.User != nil {
- u.User = nil
- }
- return u.String()
-}
-
-func endpointPathForLog(endpoint string) string {
- parsed, err := url.Parse(strings.TrimSpace(endpoint))
- if err != nil || parsed.Path == "" {
- return endpoint
- }
- return parsed.Path
-}
-
-func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
- accountID := int64(0)
- platform := ""
- proxyID := "none"
- if account != nil {
- accountID = account.ID
- platform = account.Platform
- if account.ProxyID != nil {
- proxyID = fmt.Sprintf("%d", *account.ProxyID)
- }
- }
- cfRay := extractCloudflareRayID(headers, body)
- if cfRay == "" {
- cfRay = "unknown"
- }
- log.Printf(
- "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
- accountID,
- platform,
- endpoint,
- endpointPathForLog(endpoint),
- proxyID,
- sanitizeProxyURLForLog(proxyURL),
- cfRay,
- extractSoraEgressIPHint(headers),
- )
-}
-
-func truncateSoraErrorBody(body []byte, max int) string {
- return soraerror.TruncateBody(body, max)
-}
-
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go
index 5ba04c69b7..f38264a215 100644
--- a/backend/internal/service/account_test_service_gemini_test.go
+++ b/backend/internal/service/account_test_service_gemini_test.go
@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
- ctx, recorder := newSoraTestContext()
+ ctx, recorder := newTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index efa6f7da78..5125db5ba5 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -4,16 +4,61 @@ package service
import (
"context"
+ "fmt"
"io"
"net/http"
+ "net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
+// --- shared test helpers ---
+
+type queuedHTTPUpstream struct {
+ responses []*http.Response
+ requests []*http.Request
+ tlsFlags []bool
+}
+
+func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, fmt.Errorf("unexpected Do call")
+}
+
+func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
+ u.requests = append(u.requests, req)
+ u.tlsFlags = append(u.tlsFlags, profile != nil)
+ if len(u.responses) == 0 {
+ return nil, fmt.Errorf("no mocked response")
+ }
+ resp := u.responses[0]
+ u.responses = u.responses[1:]
+ return resp, nil
+}
+
+func newJSONResponse(status int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: status,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+// --- test functions ---
+
+func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+ return c, rec
+}
+
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
@@ -34,7 +79,7 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
- ctx, recorder := newSoraTestContext()
+ ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
@@ -68,7 +113,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
- ctx, _ := newSoraTestContext()
+ ctx, _ := newTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go
deleted file mode 100644
index 52f832a9c7..0000000000
--- a/backend/internal/service/account_test_service_sora_test.go
+++ /dev/null
@@ -1,320 +0,0 @@
-package service
-
-import (
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
-)
-
-type queuedHTTPUpstream struct {
- responses []*http.Response
- requests []*http.Request
- tlsFlags []bool
-}
-
-func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
- return nil, fmt.Errorf("unexpected Do call")
-}
-
-func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
- u.requests = append(u.requests, req)
- u.tlsFlags = append(u.tlsFlags, profile != nil)
- if len(u.responses) == 0 {
- return nil, fmt.Errorf("no mocked response")
- }
- resp := u.responses[0]
- u.responses = u.responses[1:]
- return resp, nil
-}
-
-func newJSONResponse(status int, body string) *http.Response {
- return &http.Response{
- StatusCode: status,
- Header: make(http.Header),
- Body: io.NopCloser(strings.NewReader(body)),
- }
-}
-
-func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
- resp := newJSONResponse(status, body)
- resp.Header.Set(key, value)
- return resp
-}
-
-func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
- gin.SetMode(gin.TestMode)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
- return c, rec
-}
-
-func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
- newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
- newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
- newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
- },
- }
- svc := &AccountTestService{
- httpUpstream: upstream,
- cfg: &config.Config{
- Gateway: config.GatewayConfig{
- TLSFingerprint: config.TLSFingerprintConfig{
- Enabled: true,
- },
- },
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- DisableTLSFingerprint: false,
- },
- },
- },
- }
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- require.Len(t, upstream.requests, 4)
- require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
- require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
- require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
- require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
- require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
- require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
- require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
-
- body := rec.Body.String()
- require.Contains(t, body, `"type":"test_start"`)
- require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
- require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
- require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
- require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"success"`)
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
- newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
- newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
- newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- require.Len(t, upstream.requests, 4)
- body := rec.Body.String()
- require.Contains(t, body, "Sora connection OK - User: demo-user")
- require.Contains(t, body, "Subscription check returned 403")
- require.Contains(t, body, "Sora2 invite check returned 401")
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"partial_success"`)
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "Cloudflare challenge")
- require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
- body := rec.Body.String()
- require.Contains(t, body, `"type":"error"`)
- require.Contains(t, body, "Cloudflare challenge")
- require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
-}
-
-func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "Cloudflare challenge")
- require.Contains(t, err.Error(), "HTTP 429")
- body := rec.Body.String()
- require.Contains(t, body, "Cloudflare challenge")
-}
-
-func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "token_invalidated")
- body := rec.Body.String()
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"failed"`)
- require.Contains(t, body, "token_invalidated")
- require.NotContains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
- },
- }
- svc := &AccountTestService{
- httpUpstream: upstream,
- soraTestCooldown: time.Hour,
- }
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c1, _ := newSoraTestContext()
- err := svc.testSoraAccountConnection(c1, account)
- require.NoError(t, err)
-
- c2, rec2 := newSoraTestContext()
- err = svc.testSoraAccountConnection(c2, account)
- require.Error(t, err)
- require.Contains(t, err.Error(), "测试过于频繁")
- body := rec2.Body.String()
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"code":"test_rate_limited"`)
- require.Contains(t, body, `"status":"failed"`)
- require.NotContains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
- newJSONResponse(http.StatusForbidden, `Just a moment...`),
- newJSONResponse(http.StatusForbidden, `Just a moment...`),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- body := rec.Body.String()
- require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
- require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
- require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestSanitizeProxyURLForLog(t *testing.T) {
- require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
- require.Equal(t, "", sanitizeProxyURLForLog(""))
- require.Equal(t, "", sanitizeProxyURLForLog("://invalid"))
-}
-
-func TestExtractSoraEgressIPHint(t *testing.T) {
- h := make(http.Header)
- h.Set("x-openai-public-ip", "203.0.113.10")
- require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
-
- h2 := make(http.Header)
- h2.Set("x-envoy-external-address", "198.51.100.9")
- require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
-
- require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
- require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
-}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index b6d7d634df..8032f8717a 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -15,7 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
+ "github.com/Wei-Shaw/sub2api/internal/util/httputil"
)
// AdminService interface defines admin management operations
@@ -104,14 +104,13 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
- Email string
- Password string
- Username string
- Notes string
- Balance float64
- Concurrency int
- AllowedGroups []int64
- SoraStorageQuotaBytes int64
+ Email string
+ Password string
+ Username string
+ Notes string
+ Balance float64
+ Concurrency int
+ AllowedGroups []int64
}
type UpdateUserInput struct {
@@ -125,8 +124,7 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64
- SoraStorageQuotaBytes *int64
+ GroupRates map[int64]*float64
}
type CreateGroupInput struct {
@@ -140,16 +138,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- // Sora 按次计费配置
- SoraImagePrice360 *float64
- SoraImagePrice540 *float64
- SoraVideoPricePerRequest *float64
- SoraVideoPricePerRequestHD *float64
- ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -158,8 +151,6 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
- // Sora 存储配额
- SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
@@ -181,16 +172,11 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- // Sora 按次计费配置
- SoraImagePrice360 *float64
- SoraImagePrice540 *float64
- SoraVideoPricePerRequest *float64
- SoraVideoPricePerRequestHD *float64
- ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -199,8 +185,6 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
- // Sora 存储配额
- SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
@@ -426,14 +410,6 @@ var proxyQualityTargets = []proxyQualityTarget{
http.StatusOK: {},
},
},
- {
- Target: "sora",
- URL: "https://sora.chatgpt.com/backend/me",
- Method: http.MethodGet,
- AllowedStatuses: map[int]struct{}{
- http.StatusUnauthorized: {},
- },
- },
}
const (
@@ -448,7 +424,6 @@ type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
- soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
@@ -473,7 +448,6 @@ func NewAdminService(
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
- soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
@@ -492,7 +466,6 @@ func NewAdminService(
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
- soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
@@ -574,15 +547,14 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
- Email: input.Email,
- Username: input.Username,
- Notes: input.Notes,
- Role: RoleUser, // Always create as regular user, never admin
- Balance: input.Balance,
- Concurrency: input.Concurrency,
- Status: StatusActive,
- AllowedGroups: input.AllowedGroups,
- SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
+ Email: input.Email,
+ Username: input.Username,
+ Notes: input.Notes,
+ Role: RoleUser, // Always create as regular user, never admin
+ Balance: input.Balance,
+ Concurrency: input.Concurrency,
+ Status: StatusActive,
+ AllowedGroups: input.AllowedGroups,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -654,10 +626,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
- if input.SoraStorageQuotaBytes != nil {
- user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
- }
-
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -860,10 +828,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
- soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
- soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
- soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
- soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组
if input.FallbackGroupID != nil {
@@ -934,17 +898,12 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
- SoraImagePrice360: soraImagePrice360,
- SoraImagePrice540: soraImagePrice540,
- SoraVideoPricePerRequest: soraVideoPrice,
- SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
- SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
@@ -1115,21 +1074,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
- if input.SoraImagePrice360 != nil {
- group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
- }
- if input.SoraImagePrice540 != nil {
- group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
- }
- if input.SoraVideoPricePerRequest != nil {
- group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
- }
- if input.SoraVideoPricePerRequestHD != nil {
- group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
- }
- if input.SoraStorageQuotaBytes != nil {
- group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
- }
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1566,18 +1510,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
- // Sora apikey 账号的 base_url 必填校验
- if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
- baseURL, _ := input.Credentials["base_url"].(string)
- baseURL = strings.TrimSpace(baseURL)
- if baseURL == "" {
- return nil, errors.New("sora apikey 账号必须设置 base_url")
- }
- if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
- return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
- }
- }
-
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1623,18 +1555,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err
}
- // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
- if account.Platform == PlatformSora && s.soraAccountRepo != nil {
- soraUpdates := map[string]any{
- "access_token": account.GetCredential("access_token"),
- "refresh_token": account.GetCredential("refresh_token"),
- }
- if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
- // 只记录警告日志,不阻塞账号创建
- logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
- }
- }
-
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
@@ -1763,18 +1683,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
- // Sora apikey 账号的 base_url 必填校验
- if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
- baseURL, _ := account.Credentials["base_url"].(string)
- baseURL = strings.TrimSpace(baseURL)
- if baseURL == "" {
- return nil, errors.New("sora apikey 账号必须设置 base_url")
- }
- if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
- return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
- }
- }
-
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
@@ -2377,10 +2285,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
body = body[:proxyQualityMaxBodyBytes]
}
- if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
+ // Cloudflare challenge 检测
+ if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge"
- item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
- item.Message = "Sora 命中 Cloudflare challenge"
+ item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
+ item.Message = "命中 Cloudflare challenge"
return item
}
diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go
index 5a43cd9c29..d3b3f61bc6 100644
--- a/backend/internal/service/admin_service_proxy_quality_test.go
+++ b/backend/internal/service/admin_service_proxy_quality_test.go
@@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
require.Contains(t, result.Summary, "挑战 1 项")
}
-func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
+func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
@@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
defer server.Close()
target := proxyQualityTarget{
- Target: "sora",
+ Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go
index ecaffcbcb7..e3b60a2790 100644
--- a/backend/internal/service/antigravity_smart_retry_test.go
+++ b/backend/internal/service/antigravity_smart_retry_test.go
@@ -5,13 +5,12 @@ package service
import (
"bytes"
"context"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
+ "github.com/stretchr/testify/require"
"io"
"net/http"
"strings"
"testing"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
@@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
m.responseBodies[respIdx] = bodyBytes
}
- // 用缓存的 body 字节重建新的 reader
- var body io.ReadCloser
+ // 用缓存的 body 重建 reader(支持重试场景多次读取)
+ cloned := *resp
if m.responseBodies[respIdx] != nil {
- body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
+ cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
}
-
- return &http.Response{
- StatusCode: resp.StatusCode,
- Header: resp.Header.Clone(),
- Body: body,
- }, respErr
+ return &cloned, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index e8ad5c9c32..ad6ba0e930 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -49,10 +49,6 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index f727ab10f3..64a70e8cf4 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -234,10 +234,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
- SoraImagePrice360: apiKey.Group.SoraImagePrice360,
- SoraImagePrice540: apiKey.Group.SoraImagePrice540,
- SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
@@ -293,10 +289,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
- SoraImagePrice360: snapshot.Group.SoraImagePrice360,
- SoraImagePrice540: snapshot.Group.SoraImagePrice540,
- SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 2fe13686a7..763abadbfc 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -808,14 +808,6 @@ type ImagePriceConfig struct {
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
}
-// SoraPriceConfig Sora 按次计费配置
-type SoraPriceConfig struct {
- ImagePrice360 *float64
- ImagePrice540 *float64
- VideoPricePerRequest *float64
- VideoPricePerRequestHD *float64
-}
-
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
@@ -846,65 +838,6 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
}
}
-// CalculateSoraImageCost 计算 Sora 图片按次费用
-func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
- if imageCount <= 0 {
- return &CostBreakdown{}
- }
-
- unitPrice := 0.0
- if groupConfig != nil {
- switch imageSize {
- case "540":
- if groupConfig.ImagePrice540 != nil {
- unitPrice = *groupConfig.ImagePrice540
- }
- default:
- if groupConfig.ImagePrice360 != nil {
- unitPrice = *groupConfig.ImagePrice360
- }
- }
- }
-
- totalCost := unitPrice * float64(imageCount)
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
- }
- actualCost := totalCost * rateMultiplier
-
- return &CostBreakdown{
- TotalCost: totalCost,
- ActualCost: actualCost,
- }
-}
-
-// CalculateSoraVideoCost 计算 Sora 视频按次费用
-func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
- unitPrice := 0.0
- if groupConfig != nil {
- modelLower := strings.ToLower(model)
- if strings.Contains(modelLower, "sora2pro-hd") {
- if groupConfig.VideoPricePerRequestHD != nil {
- unitPrice = *groupConfig.VideoPricePerRequestHD
- }
- }
- if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
- unitPrice = *groupConfig.VideoPricePerRequest
- }
- }
-
- totalCost := unitPrice
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
- }
- actualCost := totalCost * rateMultiplier
-
- return &CostBreakdown{
- TotalCost: totalCost,
- ActualCost: actualCost,
- }
-}
-
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 1094342285..dd58502c01 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -363,28 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
-func TestCalculateSoraVideoCost(t *testing.T) {
- svc := newTestBillingService()
-
- price := 0.5
- cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
- cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
-
- require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
-}
-
-func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
- svc := newTestBillingService()
-
- hdPrice := 1.0
- normalPrice := 0.5
- cfg := &SoraPriceConfig{
- VideoPricePerRequest: &normalPrice,
- VideoPricePerRequestHD: &hdPrice,
- }
- cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
- require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
-}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
@@ -464,33 +442,6 @@ func TestForceUpdatePricing_NilService(t *testing.T) {
require.Contains(t, err.Error(), "not initialized")
}
-func TestCalculateSoraImageCost(t *testing.T) {
- svc := newTestBillingService()
-
- price360 := 0.05
- price540 := 0.08
- cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
-
- cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
- require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
-
- cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
- require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
- require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
-}
-
-func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
- svc := newTestBillingService()
- cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
- require.Equal(t, 0.0, cost.TotalCost)
-}
-
-func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
- svc := newTestBillingService()
- cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
- require.Equal(t, 0.0, cost.TotalCost)
-}
-
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
// 使用空的 fallback prices 让 GetModelPricing 失败
svc := &BillingService{
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index ecac0db0cf..52df52d656 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -24,7 +24,6 @@ const (
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
- PlatformSora = domain.PlatformSora
)
// Account type constants
@@ -107,7 +106,6 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
- SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
@@ -199,27 +197,6 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
- // =========================
- // Sora S3 存储配置
- // =========================
-
- SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
- SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
- SettingKeySoraS3Region = "sora_s3_region" // S3 区域
- SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
- SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
- SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
- SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
- SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
- SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
- SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
-
- // =========================
- // Sora 用户存储配额
- // =========================
-
- SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
-
// =========================
// Claude Code Version Check
// =========================
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index a95b62b133..16bfcd8e5a 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -60,13 +60,6 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
-// MediaType 媒体类型常量
-const (
- MediaTypeImage = "image"
- MediaTypeVideo = "video"
- MediaTypePrompt = "prompt"
-)
-
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -511,9 +504,6 @@ type ForwardResult struct {
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
- // Sora 媒体字段
- MediaType string // image / video / prompt
- MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -1971,9 +1961,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
- if platform == PlatformSora {
- return s.listSoraSchedulableAccounts(ctx, groupID)
- }
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
@@ -2070,53 +2057,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
-func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
- const useMixed = false
-
- var accounts []Account
- var err error
- if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
- accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
- } else if groupID != nil {
- accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
- } else {
- accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
- }
- if err != nil {
- slog.Debug("account_scheduling_list_failed",
- "group_id", derefGroupID(groupID),
- "platform", PlatformSora,
- "error", err)
- return nil, useMixed, err
- }
-
- filtered := make([]Account, 0, len(accounts))
- for _, acc := range accounts {
- if acc.Platform != PlatformSora {
- continue
- }
- if !s.isSoraAccountSchedulable(&acc) {
- continue
- }
- filtered = append(filtered, acc)
- }
- slog.Debug("account_scheduling_list_sora",
- "group_id", derefGroupID(groupID),
- "platform", PlatformSora,
- "raw_count", len(accounts),
- "filtered_count", len(filtered))
- for _, acc := range filtered {
- slog.Debug("account_scheduling_account_detail",
- "account_id", acc.ID,
- "name", acc.Name,
- "platform", acc.Platform,
- "type", acc.Type,
- "status", acc.Status,
- "tls_fingerprint", acc.IsTLSFingerprintEnabled())
- }
- return filtered, useMixed, nil
-}
-
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
@@ -2141,33 +2081,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
-func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
- return s.soraUnschedulableReason(account) == ""
-}
-
-func (s *GatewayService) soraUnschedulableReason(account *Account) string {
- if account == nil {
- return "account_nil"
- }
- if account.Status != StatusActive {
- return fmt.Sprintf("status=%s", account.Status)
- }
- if !account.Schedulable {
- return "schedulable=false"
- }
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
- }
- return ""
-}
-
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
- if account.Platform == PlatformSora {
- return s.isSoraAccountSchedulable(account)
- }
return account.IsSchedulable()
}
@@ -2175,12 +2092,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
- if account.Platform == PlatformSora {
- if !s.isSoraAccountSchedulable(account) {
- return false
- }
- return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
- }
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
@@ -3357,9 +3268,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
- if platform == PlatformSora {
- s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
- }
return stats
}
@@ -3416,11 +3324,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"}
}
if !s.isAccountSchedulableForSelection(acc) {
- detail := "generic_unschedulable"
- if acc.Platform == PlatformSora {
- detail = s.soraUnschedulableReason(acc)
- }
- return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
+ return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{
@@ -3444,57 +3348,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
-func (s *GatewayService) logSoraSelectionFailureDetails(
- ctx context.Context,
- groupID *int64,
- sessionHash string,
- requestedModel string,
- accounts []Account,
- excludedIDs map[int64]struct{},
- allowMixedScheduling bool,
-) {
- const maxLines = 30
- logged := 0
-
- for i := range accounts {
- if logged >= maxLines {
- break
- }
- acc := &accounts[i]
- diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
- if diagnosis.Category == "eligible" {
- continue
- }
- detail := diagnosis.Detail
- if detail == "" {
- detail = "-"
- }
- logger.LegacyPrintf(
- "service.gateway",
- "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
- derefGroupID(groupID),
- requestedModel,
- shortSessionHash(sessionHash),
- acc.ID,
- acc.Platform,
- diagnosis.Category,
- detail,
- )
- logged++
- }
- if len(accounts) > maxLines {
- logger.LegacyPrintf(
- "service.gateway",
- "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
- derefGroupID(groupID),
- requestedModel,
- shortSessionHash(sessionHash),
- len(accounts),
- logged,
- )
- }
-}
-
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
@@ -3573,9 +3426,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
- if account.Platform == PlatformSora {
- return s.isSoraModelSupportedByAccount(account, requestedModel)
- }
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
@@ -3588,143 +3438,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
-func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
- if account == nil {
- return false
- }
- if strings.TrimSpace(requestedModel) == "" {
- return true
- }
-
- // 先走原始精确/通配符匹配。
- mapping := account.GetModelMapping()
- if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
- return true
- }
-
- aliases := buildSoraModelAliases(requestedModel)
- if len(aliases) == 0 {
- return false
- }
-
- hasSoraSelector := false
- for pattern := range mapping {
- if !isSoraModelSelector(pattern) {
- continue
- }
- hasSoraSelector = true
- if matchPatternAnyAlias(pattern, aliases) {
- return true
- }
- }
-
- // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
- // 此时不应误拦截 Sora 模型请求。
- if !hasSoraSelector {
- return true
- }
-
- return false
-}
-
-func matchPatternAnyAlias(pattern string, aliases []string) bool {
- normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
- if normalizedPattern == "" {
- return false
- }
- for _, alias := range aliases {
- if matchWildcard(normalizedPattern, alias) {
- return true
- }
- }
- return false
-}
-
-func isSoraModelSelector(pattern string) bool {
- p := strings.ToLower(strings.TrimSpace(pattern))
- if p == "" {
- return false
- }
-
- switch {
- case strings.HasPrefix(p, "sora"),
- strings.HasPrefix(p, "gpt-image"),
- strings.HasPrefix(p, "prompt-enhance"),
- strings.HasPrefix(p, "sy_"):
- return true
- }
-
- return p == "video" || p == "image"
-}
-
-func buildSoraModelAliases(requestedModel string) []string {
- modelID := strings.ToLower(strings.TrimSpace(requestedModel))
- if modelID == "" {
- return nil
- }
-
- aliases := make([]string, 0, 8)
- addAlias := func(value string) {
- v := strings.ToLower(strings.TrimSpace(value))
- if v == "" {
- return
- }
- for _, existing := range aliases {
- if existing == v {
- return
- }
- }
- aliases = append(aliases, v)
- }
-
- addAlias(modelID)
- cfg, ok := GetSoraModelConfig(modelID)
- if ok {
- addAlias(cfg.Model)
- switch cfg.Type {
- case "video":
- addAlias("video")
- addAlias("sora")
- addAlias(soraVideoFamilyAlias(modelID))
- case "image":
- addAlias("image")
- addAlias("gpt-image")
- case "prompt_enhance":
- addAlias("prompt-enhance")
- }
- return aliases
- }
-
- switch {
- case strings.HasPrefix(modelID, "sora"):
- addAlias("video")
- addAlias("sora")
- addAlias(soraVideoFamilyAlias(modelID))
- case strings.HasPrefix(modelID, "gpt-image"):
- addAlias("image")
- addAlias("gpt-image")
- case strings.HasPrefix(modelID, "prompt-enhance"):
- addAlias("prompt-enhance")
- default:
- return nil
- }
-
- return aliases
-}
-
-func soraVideoFamilyAlias(modelID string) string {
- switch {
- case strings.HasPrefix(modelID, "sora2pro-hd"):
- return "sora2pro-hd"
- case strings.HasPrefix(modelID, "sora2pro"):
- return "sora2pro"
- case strings.HasPrefix(modelID, "sora2"):
- return "sora2"
- default:
- return ""
- }
-}
-
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -7592,9 +7305,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
- if usageLog.MediaType != nil {
- cmd.MediaType = *usageLog.MediaType
- }
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
@@ -7750,8 +7460,6 @@ type recordUsageOpts struct {
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
- // - Sora 媒体类型分支(image/video/prompt)
- // - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
@@ -7842,7 +7550,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
-// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
@@ -7944,16 +7651,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
- // Sora 媒体类型分支(仅 Claude 路径启用)
- if opts.EnableClaudePath {
- if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
- return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
- }
- if result.MediaType == MediaTypePrompt {
- return &CostBreakdown{}
- }
- }
-
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
@@ -7963,28 +7660,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
-// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
-func (s *GatewayService) calculateSoraMediaCost(
- result *ForwardResult,
- apiKey *APIKey,
- billingModel string,
- multiplier float64,
-) *CostBreakdown {
- var soraConfig *SoraPriceConfig
- if apiKey.Group != nil {
- soraConfig = &SoraPriceConfig{
- ImagePrice360: apiKey.Group.SoraImagePrice360,
- ImagePrice540: apiKey.Group.SoraImagePrice540,
- VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
- VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
- }
- }
- if result.MediaType == MediaTypeImage {
- return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
- }
- return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
-}
-
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -8133,13 +7808,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
- BillingMode: resolveBillingMode(opts, result, cost),
+ BillingMode: resolveBillingMode(result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
- MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
@@ -8163,13 +7837,7 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
-// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
-func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
- isSoraMedia := opts.EnableClaudePath &&
- (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
- if isSoraMedia {
- return nil
- }
+func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
var mode string
switch {
case cost != nil && cost.BillingMode != "":
@@ -8182,13 +7850,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return &mode
}
-func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
- if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
- return &result.MediaType
- }
- return nil
-}
-
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go
index 743d70bbbf..ac8c6df67c 100644
--- a/backend/internal/service/gateway_service_selection_failure_stats_test.go
+++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go
@@ -9,35 +9,35 @@ import (
func TestCollectSelectionFailureStats(t *testing.T) {
svc := &GatewayService{}
- model := "sora2-landscape-10s"
+ model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
accounts := []Account{
// excluded
{
ID: 1,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
// unschedulable
{
ID: 2,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
},
// platform filtered
{
ID: 3,
- Platform: PlatformOpenAI,
+ Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
},
// model unsupported
{
ID: 4,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{
@@ -49,7 +49,7 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// model rate limited
{
ID: 5,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
@@ -63,14 +63,14 @@ func TestCollectSelectionFailureStats(t *testing.T) {
// eligible
{
ID: 6,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
},
}
excluded := map[int64]struct{}{1: {}}
- stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
+ stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformOpenAI, excluded, false)
if stats.Total != 6 {
t.Fatalf("total=%d want=6", stats.Total)
@@ -95,31 +95,31 @@ func TestCollectSelectionFailureStats(t *testing.T) {
}
}
-func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
+func TestDiagnoseSelectionFailure_UnschedulableDetail(t *testing.T) {
svc := &GatewayService{}
acc := &Account{
ID: 7,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: false,
}
- diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
+ diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "gpt-5.4", PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "unschedulable" {
t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
}
- if diagnosis.Detail != "schedulable=false" {
- t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
+ if diagnosis.Detail != "generic_unschedulable" {
+ t.Fatalf("detail=%s want=generic_unschedulable", diagnosis.Detail)
}
}
-func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
+func TestDiagnoseSelectionFailure_ModelRateLimitedDetail(t *testing.T) {
svc := &GatewayService{}
- model := "sora2-landscape-10s"
+ model := "gpt-5.4"
resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
acc := &Account{
ID: 8,
- Platform: PlatformSora,
+ Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
@@ -131,7 +131,7 @@ func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
},
}
- diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
+ diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformOpenAI, map[int64]struct{}{}, false)
if diagnosis.Category != "model_rate_limited" {
t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
}
diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go
deleted file mode 100644
index 8ee2a960d1..0000000000
--- a/backend/internal/service/gateway_service_sora_model_support_test.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package service
-
-import "testing"
-
-func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
- svc := &GatewayService{}
- account := &Account{
- Platform: PlatformSora,
- Credentials: map[string]any{},
- }
-
- if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
- t.Fatalf("expected sora model to be supported when model_mapping is empty")
- }
-}
-
-func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
- svc := &GatewayService{}
- account := &Account{
- Platform: PlatformSora,
- Credentials: map[string]any{
- "model_mapping": map[string]any{
- "gpt-4o": "gpt-4o",
- },
- },
- }
-
- if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
- t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
- }
-}
-
-func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
- svc := &GatewayService{}
- account := &Account{
- Platform: PlatformSora,
- Credentials: map[string]any{
- "model_mapping": map[string]any{
- "sora2": "sora2",
- },
- },
- }
-
- if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
- t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
- }
-}
-
-func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
- svc := &GatewayService{}
- account := &Account{
- Platform: PlatformSora,
- Credentials: map[string]any{
- "model_mapping": map[string]any{
- "sy_8": "sy_8",
- },
- },
- }
-
- if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
- t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
- }
-}
-
-func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
- svc := &GatewayService{}
- account := &Account{
- Platform: PlatformSora,
- Credentials: map[string]any{
- "model_mapping": map[string]any{
- "gpt-image": "gpt-image",
- },
- },
- }
-
- if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
- t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
- }
-}
diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go
deleted file mode 100644
index 5178e68e40..0000000000
--- a/backend/internal/service/gateway_service_sora_scheduling_test.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package service
-
-import (
- "context"
- "testing"
- "time"
-)
-
-func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
- svc := &GatewayService{}
- now := time.Now()
- past := now.Add(-1 * time.Minute)
- future := now.Add(5 * time.Minute)
-
- acc := &Account{
- Platform: PlatformSora,
- Status: StatusActive,
- Schedulable: true,
- AutoPauseOnExpired: true,
- ExpiresAt: &past,
- OverloadUntil: &future,
- RateLimitResetAt: &future,
- }
-
- if !svc.isAccountSchedulableForSelection(acc) {
- t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
- }
-}
-
-func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
- svc := &GatewayService{}
- future := time.Now().Add(5 * time.Minute)
-
- acc := &Account{
- Platform: PlatformAnthropic,
- Status: StatusActive,
- Schedulable: true,
- RateLimitResetAt: &future,
- }
-
- if svc.isAccountSchedulableForSelection(acc) {
- t.Fatalf("expected non-sora account to keep generic schedulable checks")
- }
-}
-
-func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
- svc := &GatewayService{}
- model := "sora2-landscape-10s"
- resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
- globalResetAt := time.Now().Add(2 * time.Minute)
-
- acc := &Account{
- Platform: PlatformSora,
- Status: StatusActive,
- Schedulable: true,
- RateLimitResetAt: &globalResetAt,
- Extra: map[string]any{
- "model_rate_limits": map[string]any{
- model: map[string]any{
- "rate_limit_reset_at": resetAt,
- },
- },
- },
- }
-
- if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
- t.Fatalf("expected sora account to be blocked by model scope rate limit")
- }
-}
-
-func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
- svc := &GatewayService{}
- future := time.Now().Add(3 * time.Minute)
-
- accounts := []Account{
- {
- ID: 1,
- Platform: PlatformSora,
- Status: StatusActive,
- Schedulable: true,
- RateLimitResetAt: &future,
- },
- }
-
- stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
- if stats.Unschedulable != 0 || stats.Eligible != 1 {
- t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
- }
-}
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index e0f81a39a4..d59af9e1c0 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -26,15 +26,6 @@ type Group struct {
ImagePrice2K *float64
ImagePrice4K *float64
- // Sora 按次计费配置(阶段 1)
- SoraImagePrice360 *float64
- SoraImagePrice540 *float64
- SoraVideoPricePerRequest *float64
- SoraVideoPricePerRequestHD *float64
-
- // Sora 存储配额
- SoraStorageQuotaBytes int64
-
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
@@ -112,18 +103,6 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
}
}
-// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540)
-func (g *Group) GetSoraImagePrice(imageSize string) *float64 {
- switch imageSize {
- case "360":
- return g.SoraImagePrice360
- case "540":
- return g.SoraImagePrice540
- default:
- return g.SoraImagePrice360
- }
-}
-
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
if group == nil {
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index f575cd625e..dc094d43ce 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -3,30 +3,15 @@ package service
import (
"context"
"crypto/subtle"
- "encoding/json"
- "io"
"log/slog"
"net/http"
- "regexp"
- "sort"
- "strconv"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
-var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
-
-var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
-
-type soraSessionChunk struct {
- index int
- value string
-}
-
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -225,7 +210,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
-// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
+// RefreshTokenWithClientID refreshes an OpenAI OAuth token with optional client_id.
func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil {
@@ -298,215 +283,10 @@ func (s *OpenAIOAuthService) enrichTokenInfo(ctx context.Context, tokenInfo *Ope
tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL)
}
-// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
-func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
- sessionToken = normalizeSoraSessionTokenInput(sessionToken)
- if strings.TrimSpace(sessionToken) == "" {
- return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
- }
-
- proxyURL, err := s.resolveProxyURL(ctx, proxyID)
- if err != nil {
- return nil, err
- }
-
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
- }
- req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Origin", "https://sora.chatgpt.com")
- req.Header.Set("Referer", "https://sora.chatgpt.com/")
- req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
-
- client, err := httpclient.GetClient(httpclient.Options{
- ProxyURL: proxyURL,
- Timeout: 120 * time.Second,
- })
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
- }
- resp, err := client.Do(req)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
- if resp.StatusCode != http.StatusOK {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
- }
-
- var sessionResp struct {
- AccessToken string `json:"accessToken"`
- Expires string `json:"expires"`
- User struct {
- Email string `json:"email"`
- Name string `json:"name"`
- } `json:"user"`
- }
- if err := json.Unmarshal(body, &sessionResp); err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
- }
- if strings.TrimSpace(sessionResp.AccessToken) == "" {
- return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
- }
-
- expiresAt := time.Now().Add(time.Hour).Unix()
- if strings.TrimSpace(sessionResp.Expires) != "" {
- if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
- expiresAt = parsed.Unix()
- }
- }
- expiresIn := expiresAt - time.Now().Unix()
- if expiresIn < 0 {
- expiresIn = 0
- }
-
- return &OpenAITokenInfo{
- AccessToken: strings.TrimSpace(sessionResp.AccessToken),
- ExpiresIn: expiresIn,
- ExpiresAt: expiresAt,
- ClientID: openai.SoraClientID,
- Email: strings.TrimSpace(sessionResp.User.Email),
- }, nil
-}
-
-func normalizeSoraSessionTokenInput(raw string) string {
- trimmed := strings.TrimSpace(raw)
- if trimmed == "" {
- return ""
- }
-
- matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
- if len(matches) == 0 {
- return sanitizeSessionToken(trimmed)
- }
-
- chunkMatches := make([]soraSessionChunk, 0, len(matches))
- singleValues := make([]string, 0, len(matches))
-
- for _, match := range matches {
- if len(match) < 3 {
- continue
- }
-
- value := sanitizeSessionToken(match[2])
- if value == "" {
- continue
- }
-
- if strings.TrimSpace(match[1]) == "" {
- singleValues = append(singleValues, value)
- continue
- }
-
- idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
- if err != nil || idx < 0 {
- continue
- }
- chunkMatches = append(chunkMatches, soraSessionChunk{
- index: idx,
- value: value,
- })
- }
-
- if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
- return merged
- }
-
- if len(singleValues) > 0 {
- return singleValues[len(singleValues)-1]
- }
-
- return ""
-}
-
-func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
- if len(chunks) == 0 {
- return ""
- }
-
- byIndex := make(map[int]string, len(chunks))
- for _, chunk := range chunks {
- byIndex[chunk.index] = chunk.value
- }
-
- if _, ok := byIndex[0]; !ok {
- return ""
- }
- if requireComplete {
- for idx := 0; idx <= requiredMaxIndex; idx++ {
- if _, ok := byIndex[idx]; !ok {
- return ""
- }
- }
- }
-
- orderedIndexes := make([]int, 0, len(byIndex))
- for idx := range byIndex {
- orderedIndexes = append(orderedIndexes, idx)
- }
- sort.Ints(orderedIndexes)
-
- var builder strings.Builder
- for _, idx := range orderedIndexes {
- if _, err := builder.WriteString(byIndex[idx]); err != nil {
- return ""
- }
- }
- return sanitizeSessionToken(builder.String())
-}
-
-func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
- if len(chunks) == 0 {
- return ""
- }
-
- requiredMaxIndex := 0
- for _, chunk := range chunks {
- if chunk.index > requiredMaxIndex {
- requiredMaxIndex = chunk.index
- }
- }
-
- groupStarts := make([]int, 0, len(chunks))
- for idx, chunk := range chunks {
- if chunk.index == 0 {
- groupStarts = append(groupStarts, idx)
- }
- }
-
- if len(groupStarts) == 0 {
- return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
- }
-
- for i := len(groupStarts) - 1; i >= 0; i-- {
- start := groupStarts[i]
- end := len(chunks)
- if i+1 < len(groupStarts) {
- end = groupStarts[i+1]
- }
- if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
- return merged
- }
- }
-
- return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
-}
-
-func sanitizeSessionToken(raw string) string {
- token := strings.TrimSpace(raw)
- token = strings.Trim(token, "\"'`")
- token = strings.TrimSuffix(token, ";")
- return strings.TrimSpace(token)
-}
-
-// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
+// RefreshAccountToken refreshes token for an OpenAI OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
- if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
- return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
+ if account.Platform != PlatformOpenAI {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
}
if account.Type != AccountTypeOAuth {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
@@ -594,25 +374,6 @@ func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop()
}
-func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
- if proxyID == nil {
- return "", nil
- }
- proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
- if err != nil {
- return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
- }
- if proxy == nil {
- return "", nil
- }
- return proxy.URL(), nil
-}
-
func normalizeOpenAIOAuthPlatform(platform string) string {
- switch strings.ToLower(strings.TrimSpace(platform)) {
- case PlatformSora:
- return openai.OAuthPlatformSora
- default:
- return openai.OAuthPlatformOpenAI
- }
+ return openai.OAuthPlatformOpenAI
}
diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go
index 5f26903db1..f3b507ca02 100644
--- a/backend/internal/service/openai_oauth_service_auth_url_test.go
+++ b/backend/internal/service/openai_oauth_service_auth_url_test.go
@@ -43,25 +43,3 @@ func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
require.True(t, ok)
require.Equal(t, openai.ClientID, session.ClientID)
}
-
-// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
-// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
-func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
- defer svc.Stop()
-
- result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
- require.NoError(t, err)
- require.NotEmpty(t, result.AuthURL)
- require.NotEmpty(t, result.SessionID)
-
- parsed, err := url.Parse(result.AuthURL)
- require.NoError(t, err)
- q := parsed.Query()
- require.Equal(t, openai.ClientID, q.Get("client_id"))
- require.Empty(t, q.Get("codex_cli_simplified_flow"))
-
- session, ok := svc.sessionStore.Get(result.SessionID)
- require.True(t, ok)
- require.Equal(t, openai.ClientID, session.ClientID)
-}
diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go
deleted file mode 100644
index 08da85571c..0000000000
--- a/backend/internal/service/openai_oauth_service_sora_session_test.go
+++ /dev/null
@@ -1,173 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
- "github.com/stretchr/testify/require"
-)
-
-type openaiOAuthClientNoopStub struct{}
-
-func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
- return nil, errors.New("not implemented")
-}
-
-func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
- return nil, errors.New("not implemented")
-}
-
-func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
- return nil, errors.New("not implemented")
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
- require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
- require.NoError(t, err)
- require.NotNil(t, info)
- require.Equal(t, "at-token", info.AccessToken)
- require.Equal(t, "demo@example.com", info.Email)
- require.Greater(t, info.ExpiresAt, int64(0))
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
- require.Error(t, err)
- require.Contains(t, err.Error(), "missing access token")
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
- require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
- info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
- require.NoError(t, err)
- require.Equal(t, "at-token", info.AccessToken)
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
- require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- raw := strings.Join([]string{
- "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
- "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
- }, "\n")
- info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
- require.NoError(t, err)
- require.Equal(t, "at-token", info.AccessToken)
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
- require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- raw := strings.Join([]string{
- "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
- "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
- "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
- "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
- }, "\n")
- info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
- require.NoError(t, err)
- require.Equal(t, "at-token", info.AccessToken)
-}
-
-func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Equal(t, http.MethodGet, r.Method)
- require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
- w.Header().Set("Content-Type", "application/json")
- _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
- }))
- defer server.Close()
-
- origin := openAISoraSessionAuthURL
- openAISoraSessionAuthURL = server.URL
- defer func() { openAISoraSessionAuthURL = origin }()
-
- svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
- defer svc.Stop()
-
- raw := strings.Join([]string{
- "set-cookie",
- "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
- "set-cookie",
- "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
- "set-cookie",
- "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
- }, "\n")
- info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
- require.NoError(t, err)
- require.Equal(t, "at-token", info.AccessToken)
-}
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
index 69477ce7fc..e438588edb 100644
--- a/backend/internal/service/openai_token_provider.go
+++ b/backend/internal/service/openai_token_provider.go
@@ -75,7 +75,7 @@ func (m *openAITokenRuntimeMetricsStore) touchNow() {
// OpenAITokenCache token cache interface.
type OpenAITokenCache = GeminiTokenCache
-// OpenAITokenProvider manages access_token for OpenAI/Sora OAuth accounts.
+// OpenAITokenProvider manages access_token for OpenAI OAuth accounts.
type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
@@ -131,8 +131,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
- if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth {
- return "", errors.New("not an openai/sora oauth account")
+ if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
+ return "", errors.New("not an openai oauth account")
}
cacheKey := OpenAITokenCacheKey(account)
@@ -158,40 +158,34 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
- // Sora accounts skip OpenAI OAuth refresh and keep existing token path.
- if account.Platform == PlatformSora {
- slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
+ result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
+ if err != nil {
+ if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
+ return "", err
+ }
+ slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
+ p.metrics.refreshFailure.Add(1)
refreshFailed = true
- } else {
- result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, openAITokenRefreshSkew)
- if err != nil {
- if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
- return "", err
+ } else if result.LockHeld {
+ if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
+ p.metrics.lockContention.Add(1)
+ p.metrics.touchNow()
+ token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
+ if waitErr != nil {
+ return "", waitErr
}
- slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
- p.metrics.refreshFailure.Add(1)
- refreshFailed = true
- } else if result.LockHeld {
- if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache {
- p.metrics.lockContention.Add(1)
- p.metrics.touchNow()
- token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
- if waitErr != nil {
- return "", waitErr
- }
- if strings.TrimSpace(token) != "" {
- slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
- return token, nil
- }
+ if strings.TrimSpace(token) != "" {
+ slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
+ return token, nil
}
- } else if result.Refreshed {
- p.metrics.refreshSuccess.Add(1)
- account = result.Account
- expiresAt = account.GetCredentialAsTime("expires_at")
- } else {
- account = result.Account
- expiresAt = account.GetCredentialAsTime("expires_at")
}
+ } else if result.Refreshed {
+ p.metrics.refreshSuccess.Add(1)
+ account = result.Account
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ } else {
+ account = result.Account
+ expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go
index 1cd923672a..e81fb46560 100644
--- a/backend/internal/service/openai_token_provider_test.go
+++ b/backend/internal/service/openai_token_provider_test.go
@@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an openai/sora oauth account")
+ require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
@@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an openai/sora oauth account")
+ require.Contains(t, err.Error(), "not an openai oauth account")
require.Empty(t, token)
}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 1a24bad149..a85efabd6d 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -22,8 +22,6 @@ import (
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
- ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
- ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
"DEFAULT_SUBSCRIPTION_GROUP_INVALID",
"default subscription group must exist and be subscription type",
@@ -104,7 +102,6 @@ type SettingService struct {
defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
- onS3Update func() // Callback when Sora S3 settings are updated
version string // Application version
}
@@ -162,7 +159,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
- SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
@@ -208,7 +204,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
- SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
@@ -222,11 +217,6 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
-// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
-func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
- s.onS3Update = callback
-}
-
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
@@ -261,7 +251,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
@@ -287,7 +276,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
- SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
@@ -482,7 +470,6 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
- updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints
@@ -830,7 +817,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
- SettingKeySoraClientEnabled: "false",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
@@ -896,7 +882,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
- SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
@@ -1583,607 +1568,3 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
-
-type soraS3ProfilesStore struct {
- ActiveProfileID string `json:"active_profile_id"`
- Items []soraS3ProfileStoreItem `json:"items"`
-}
-
-type soraS3ProfileStoreItem struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
- UpdatedAt string `json:"updated_at"`
-}
-
-// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
-func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
- profiles, err := s.ListSoraS3Profiles(ctx)
- if err != nil {
- return nil, err
- }
-
- activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
- if activeProfile == nil {
- return &SoraS3Settings{}, nil
- }
-
- return &SoraS3Settings{
- Enabled: activeProfile.Enabled,
- Endpoint: activeProfile.Endpoint,
- Region: activeProfile.Region,
- Bucket: activeProfile.Bucket,
- AccessKeyID: activeProfile.AccessKeyID,
- SecretAccessKey: activeProfile.SecretAccessKey,
- SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
- Prefix: activeProfile.Prefix,
- ForcePathStyle: activeProfile.ForcePathStyle,
- CDNURL: activeProfile.CDNURL,
- DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
- }, nil
-}
-
-// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
-func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
- if settings == nil {
- return fmt.Errorf("settings cannot be nil")
- }
-
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return err
- }
-
- now := time.Now().UTC().Format(time.RFC3339)
- activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
- if activeIndex < 0 {
- activeID := "default"
- if hasSoraS3ProfileID(store.Items, activeID) {
- activeID = fmt.Sprintf("default-%d", time.Now().Unix())
- }
- store.Items = append(store.Items, soraS3ProfileStoreItem{
- ProfileID: activeID,
- Name: "Default",
- UpdatedAt: now,
- })
- store.ActiveProfileID = activeID
- activeIndex = len(store.Items) - 1
- }
-
- active := store.Items[activeIndex]
- active.Enabled = settings.Enabled
- active.Endpoint = strings.TrimSpace(settings.Endpoint)
- active.Region = strings.TrimSpace(settings.Region)
- active.Bucket = strings.TrimSpace(settings.Bucket)
- active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
- active.Prefix = strings.TrimSpace(settings.Prefix)
- active.ForcePathStyle = settings.ForcePathStyle
- active.CDNURL = strings.TrimSpace(settings.CDNURL)
- active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
- if settings.SecretAccessKey != "" {
- active.SecretAccessKey = settings.SecretAccessKey
- }
- active.UpdatedAt = now
- store.Items[activeIndex] = active
-
- return s.persistSoraS3ProfilesStore(ctx, store)
-}
-
-// ListSoraS3Profiles 获取 Sora S3 多配置列表
-func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return nil, err
- }
- return convertSoraS3ProfilesStore(store), nil
-}
-
-// CreateSoraS3Profile 创建 Sora S3 配置
-func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
- if profile == nil {
- return nil, fmt.Errorf("profile cannot be nil")
- }
-
- profileID := strings.TrimSpace(profile.ProfileID)
- if profileID == "" {
- return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
- }
- name := strings.TrimSpace(profile.Name)
- if name == "" {
- return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
- }
-
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return nil, err
- }
- if hasSoraS3ProfileID(store.Items, profileID) {
- return nil, ErrSoraS3ProfileExists
- }
-
- now := time.Now().UTC().Format(time.RFC3339)
- store.Items = append(store.Items, soraS3ProfileStoreItem{
- ProfileID: profileID,
- Name: name,
- Enabled: profile.Enabled,
- Endpoint: strings.TrimSpace(profile.Endpoint),
- Region: strings.TrimSpace(profile.Region),
- Bucket: strings.TrimSpace(profile.Bucket),
- AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
- SecretAccessKey: profile.SecretAccessKey,
- Prefix: strings.TrimSpace(profile.Prefix),
- ForcePathStyle: profile.ForcePathStyle,
- CDNURL: strings.TrimSpace(profile.CDNURL),
- DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
- UpdatedAt: now,
- })
-
- if setActive || store.ActiveProfileID == "" {
- store.ActiveProfileID = profileID
- }
-
- if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
- return nil, err
- }
-
- profiles := convertSoraS3ProfilesStore(store)
- created := findSoraS3ProfileByID(profiles.Items, profileID)
- if created == nil {
- return nil, ErrSoraS3ProfileNotFound
- }
- return created, nil
-}
-
-// UpdateSoraS3Profile 更新 Sora S3 配置
-func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
- if profile == nil {
- return nil, fmt.Errorf("profile cannot be nil")
- }
-
- targetID := strings.TrimSpace(profileID)
- if targetID == "" {
- return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
- }
-
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return nil, err
- }
-
- targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
- if targetIndex < 0 {
- return nil, ErrSoraS3ProfileNotFound
- }
-
- target := store.Items[targetIndex]
- name := strings.TrimSpace(profile.Name)
- if name == "" {
- return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
- }
- target.Name = name
- target.Enabled = profile.Enabled
- target.Endpoint = strings.TrimSpace(profile.Endpoint)
- target.Region = strings.TrimSpace(profile.Region)
- target.Bucket = strings.TrimSpace(profile.Bucket)
- target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
- target.Prefix = strings.TrimSpace(profile.Prefix)
- target.ForcePathStyle = profile.ForcePathStyle
- target.CDNURL = strings.TrimSpace(profile.CDNURL)
- target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
- if profile.SecretAccessKey != "" {
- target.SecretAccessKey = profile.SecretAccessKey
- }
- target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
- store.Items[targetIndex] = target
-
- if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
- return nil, err
- }
-
- profiles := convertSoraS3ProfilesStore(store)
- updated := findSoraS3ProfileByID(profiles.Items, targetID)
- if updated == nil {
- return nil, ErrSoraS3ProfileNotFound
- }
- return updated, nil
-}
-
-// DeleteSoraS3Profile 删除 Sora S3 配置
-func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
- targetID := strings.TrimSpace(profileID)
- if targetID == "" {
- return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
- }
-
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return err
- }
-
- targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
- if targetIndex < 0 {
- return ErrSoraS3ProfileNotFound
- }
-
- store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
- if store.ActiveProfileID == targetID {
- store.ActiveProfileID = ""
- if len(store.Items) > 0 {
- store.ActiveProfileID = store.Items[0].ProfileID
- }
- }
-
- return s.persistSoraS3ProfilesStore(ctx, store)
-}
-
-// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
-func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
- targetID := strings.TrimSpace(profileID)
- if targetID == "" {
- return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
- }
-
- store, err := s.loadSoraS3ProfilesStore(ctx)
- if err != nil {
- return nil, err
- }
-
- targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
- if targetIndex < 0 {
- return nil, ErrSoraS3ProfileNotFound
- }
-
- store.ActiveProfileID = targetID
- store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
- if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
- return nil, err
- }
-
- profiles := convertSoraS3ProfilesStore(store)
- active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
- if active == nil {
- return nil, ErrSoraS3ProfileNotFound
- }
- return active, nil
-}
-
-func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
- raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
- if err == nil {
- trimmed := strings.TrimSpace(raw)
- if trimmed == "" {
- return &soraS3ProfilesStore{}, nil
- }
- var store soraS3ProfilesStore
- if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
- legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
- if legacyErr != nil {
- return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
- }
- if isEmptyLegacySoraS3Settings(legacy) {
- return &soraS3ProfilesStore{}, nil
- }
- now := time.Now().UTC().Format(time.RFC3339)
- return &soraS3ProfilesStore{
- ActiveProfileID: "default",
- Items: []soraS3ProfileStoreItem{
- {
- ProfileID: "default",
- Name: "Default",
- Enabled: legacy.Enabled,
- Endpoint: strings.TrimSpace(legacy.Endpoint),
- Region: strings.TrimSpace(legacy.Region),
- Bucket: strings.TrimSpace(legacy.Bucket),
- AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
- SecretAccessKey: legacy.SecretAccessKey,
- Prefix: strings.TrimSpace(legacy.Prefix),
- ForcePathStyle: legacy.ForcePathStyle,
- CDNURL: strings.TrimSpace(legacy.CDNURL),
- DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
- UpdatedAt: now,
- },
- },
- }, nil
- }
- normalized := normalizeSoraS3ProfilesStore(store)
- return &normalized, nil
- }
-
- if !errors.Is(err, ErrSettingNotFound) {
- return nil, fmt.Errorf("get sora s3 profiles: %w", err)
- }
-
- legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
- if legacyErr != nil {
- return nil, legacyErr
- }
- if isEmptyLegacySoraS3Settings(legacy) {
- return &soraS3ProfilesStore{}, nil
- }
-
- now := time.Now().UTC().Format(time.RFC3339)
- return &soraS3ProfilesStore{
- ActiveProfileID: "default",
- Items: []soraS3ProfileStoreItem{
- {
- ProfileID: "default",
- Name: "Default",
- Enabled: legacy.Enabled,
- Endpoint: strings.TrimSpace(legacy.Endpoint),
- Region: strings.TrimSpace(legacy.Region),
- Bucket: strings.TrimSpace(legacy.Bucket),
- AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
- SecretAccessKey: legacy.SecretAccessKey,
- Prefix: strings.TrimSpace(legacy.Prefix),
- ForcePathStyle: legacy.ForcePathStyle,
- CDNURL: strings.TrimSpace(legacy.CDNURL),
- DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
- UpdatedAt: now,
- },
- },
- }, nil
-}
-
-func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
- if store == nil {
- return fmt.Errorf("sora s3 profiles store cannot be nil")
- }
-
- normalized := normalizeSoraS3ProfilesStore(*store)
- data, err := json.Marshal(normalized)
- if err != nil {
- return fmt.Errorf("marshal sora s3 profiles: %w", err)
- }
-
- updates := map[string]string{
- SettingKeySoraS3Profiles: string(data),
- }
-
- active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
- if active == nil {
- updates[SettingKeySoraS3Enabled] = "false"
- updates[SettingKeySoraS3Endpoint] = ""
- updates[SettingKeySoraS3Region] = ""
- updates[SettingKeySoraS3Bucket] = ""
- updates[SettingKeySoraS3AccessKeyID] = ""
- updates[SettingKeySoraS3Prefix] = ""
- updates[SettingKeySoraS3ForcePathStyle] = "false"
- updates[SettingKeySoraS3CDNURL] = ""
- updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
- updates[SettingKeySoraS3SecretAccessKey] = ""
- } else {
- updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
- updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
- updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
- updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
- updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
- updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
- updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
- updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
- updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
- updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
- }
-
- if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
- return err
- }
-
- if s.onUpdate != nil {
- s.onUpdate()
- }
- if s.onS3Update != nil {
- s.onS3Update()
- }
- return nil
-}
-
-func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
- keys := []string{
- SettingKeySoraS3Enabled,
- SettingKeySoraS3Endpoint,
- SettingKeySoraS3Region,
- SettingKeySoraS3Bucket,
- SettingKeySoraS3AccessKeyID,
- SettingKeySoraS3SecretAccessKey,
- SettingKeySoraS3Prefix,
- SettingKeySoraS3ForcePathStyle,
- SettingKeySoraS3CDNURL,
- SettingKeySoraDefaultStorageQuotaBytes,
- }
-
- values, err := s.settingRepo.GetMultiple(ctx, keys)
- if err != nil {
- return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
- }
-
- result := &SoraS3Settings{
- Enabled: values[SettingKeySoraS3Enabled] == "true",
- Endpoint: values[SettingKeySoraS3Endpoint],
- Region: values[SettingKeySoraS3Region],
- Bucket: values[SettingKeySoraS3Bucket],
- AccessKeyID: values[SettingKeySoraS3AccessKeyID],
- SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
- SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
- Prefix: values[SettingKeySoraS3Prefix],
- ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
- CDNURL: values[SettingKeySoraS3CDNURL],
- }
- if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
- result.DefaultStorageQuotaBytes = v
- }
- return result, nil
-}
-
-func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
- seen := make(map[string]struct{}, len(store.Items))
- normalized := soraS3ProfilesStore{
- ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
- Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
- }
- now := time.Now().UTC().Format(time.RFC3339)
-
- for idx := range store.Items {
- item := store.Items[idx]
- item.ProfileID = strings.TrimSpace(item.ProfileID)
- if item.ProfileID == "" {
- item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
- }
- if _, exists := seen[item.ProfileID]; exists {
- continue
- }
- seen[item.ProfileID] = struct{}{}
-
- item.Name = strings.TrimSpace(item.Name)
- if item.Name == "" {
- item.Name = item.ProfileID
- }
- item.Endpoint = strings.TrimSpace(item.Endpoint)
- item.Region = strings.TrimSpace(item.Region)
- item.Bucket = strings.TrimSpace(item.Bucket)
- item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
- item.Prefix = strings.TrimSpace(item.Prefix)
- item.CDNURL = strings.TrimSpace(item.CDNURL)
- item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
- item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
- if item.UpdatedAt == "" {
- item.UpdatedAt = now
- }
- normalized.Items = append(normalized.Items, item)
- }
-
- if len(normalized.Items) == 0 {
- normalized.ActiveProfileID = ""
- return normalized
- }
-
- if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
- return normalized
- }
-
- normalized.ActiveProfileID = normalized.Items[0].ProfileID
- return normalized
-}
-
-func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
- if store == nil {
- return &SoraS3ProfileList{}
- }
- items := make([]SoraS3Profile, 0, len(store.Items))
- for idx := range store.Items {
- item := store.Items[idx]
- items = append(items, SoraS3Profile{
- ProfileID: item.ProfileID,
- Name: item.Name,
- IsActive: item.ProfileID == store.ActiveProfileID,
- Enabled: item.Enabled,
- Endpoint: item.Endpoint,
- Region: item.Region,
- Bucket: item.Bucket,
- AccessKeyID: item.AccessKeyID,
- SecretAccessKey: item.SecretAccessKey,
- SecretAccessKeyConfigured: item.SecretAccessKey != "",
- Prefix: item.Prefix,
- ForcePathStyle: item.ForcePathStyle,
- CDNURL: item.CDNURL,
- DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
- UpdatedAt: item.UpdatedAt,
- })
- }
- return &SoraS3ProfileList{
- ActiveProfileID: store.ActiveProfileID,
- Items: items,
- }
-}
-
-func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
- for idx := range items {
- if items[idx].ProfileID == activeProfileID {
- return &items[idx]
- }
- }
- if len(items) == 0 {
- return nil
- }
- return &items[0]
-}
-
-func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
- for idx := range items {
- if items[idx].ProfileID == profileID {
- return &items[idx]
- }
- }
- return nil
-}
-
-func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
- for idx := range items {
- if items[idx].ProfileID == activeProfileID {
- return &items[idx]
- }
- }
- if len(items) == 0 {
- return nil
- }
- return &items[0]
-}
-
-func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
- for idx := range items {
- if items[idx].ProfileID == profileID {
- return idx
- }
- }
- return -1
-}
-
-func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
- return findSoraS3ProfileIndex(items, profileID) >= 0
-}
-
-func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
- if settings == nil {
- return true
- }
- if settings.Enabled {
- return false
- }
- if strings.TrimSpace(settings.Endpoint) != "" {
- return false
- }
- if strings.TrimSpace(settings.Region) != "" {
- return false
- }
- if strings.TrimSpace(settings.Bucket) != "" {
- return false
- }
- if strings.TrimSpace(settings.AccessKeyID) != "" {
- return false
- }
- if settings.SecretAccessKey != "" {
- return false
- }
- if strings.TrimSpace(settings.Prefix) != "" {
- return false
- }
- if strings.TrimSpace(settings.CDNURL) != "" {
- return false
- }
- return settings.DefaultStorageQuotaBytes == 0
-}
-
-func maxInt64(value int64, min int64) int64 {
- if value < min {
- return min
- }
- return value
-}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 411939bb62..4b64267fec 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -41,7 +41,6 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
- SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
@@ -107,7 +106,6 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
- SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
@@ -116,46 +114,6 @@ type PublicSettings struct {
Version string
}
-// SoraS3Settings Sora S3 存储配置
-type SoraS3Settings struct {
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-// SoraS3Profile Sora S3 多配置项(服务内部模型)
-type SoraS3Profile struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- IsActive bool `json:"is_active"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
- UpdatedAt string `json:"updated_at"`
-}
-
-// SoraS3ProfileList Sora S3 多配置列表
-type SoraS3ProfileList struct {
- ActiveProfileID string `json:"active_profile_id"`
- Items []SoraS3Profile `json:"items"`
-}
-
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go
deleted file mode 100644
index eccc1acff7..0000000000
--- a/backend/internal/service/sora_account_service.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package service
-
-import "context"
-
-// SoraAccountRepository Sora 账号扩展表仓储接口
-// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。
-//
-// 设计说明:
-// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本
-// - Sora gateway 优先读取此表的字段以获得更好的查询性能
-// - 主表 accounts 通过 credentials JSON 字段也存储相同信息
-// - Token 刷新时需要同时更新两个表以保持数据一致性
-type SoraAccountRepository interface {
- // Upsert 创建或更新 Sora 账号扩展信息
- // accountID: 关联的 accounts.id
- // updates: 要更新的字段,支持 access_token、refresh_token、session_token
- //
- // 如果记录不存在则创建,存在则更新。
- // 用于:
- // 1. 创建 Sora 账号时初始化扩展表
- // 2. Token 刷新时同步更新扩展表
- Upsert(ctx context.Context, accountID int64, updates map[string]any) error
-
- // GetByAccountID 根据账号 ID 获取 Sora 扩展信息
- // 返回 nil, nil 表示记录不存在(非错误)
- GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error)
-
- // Delete 删除 Sora 账号扩展信息
- // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理
- Delete(ctx context.Context, accountID int64) error
-}
-
-// SoraAccount Sora 账号扩展信息
-// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本
-type SoraAccount struct {
- AccountID int64 // 关联的 accounts.id
- AccessToken string // OAuth access_token
- RefreshToken string // OAuth refresh_token
- SessionToken string // Session token(可选,用于 ST→AT 兜底)
-}
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
deleted file mode 100644
index 0a914d2db4..0000000000
--- a/backend/internal/service/sora_client.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package service
-
-import (
- "context"
- "fmt"
- "net/http"
-)
-
-// SoraClient 定义直连 Sora 的任务操作接口。
-type SoraClient interface {
- Enabled() bool
- UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
- CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
- CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
- CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
- UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
- GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
- DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
- UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
- FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
- SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
- DeleteCharacter(ctx context.Context, account *Account, characterID string) error
- PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
- DeletePost(ctx context.Context, account *Account, postID string) error
- GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
- EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
- GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
- GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
-}
-
-// SoraImageRequest 图片生成请求参数
-type SoraImageRequest struct {
- Prompt string
- Width int
- Height int
- MediaID string
-}
-
-// SoraVideoRequest 视频生成请求参数
-type SoraVideoRequest struct {
- Prompt string
- Orientation string
- Frames int
- Model string
- Size string
- VideoCount int
- MediaID string
- RemixTargetID string
- CameoIDs []string
-}
-
-// SoraStoryboardRequest 分镜视频生成请求参数
-type SoraStoryboardRequest struct {
- Prompt string
- Orientation string
- Frames int
- Model string
- Size string
- MediaID string
-}
-
-// SoraImageTaskStatus 图片任务状态
-type SoraImageTaskStatus struct {
- ID string
- Status string
- ProgressPct float64
- URLs []string
- ErrorMsg string
-}
-
-// SoraVideoTaskStatus 视频任务状态
-type SoraVideoTaskStatus struct {
- ID string
- Status string
- ProgressPct int
- URLs []string
- GenerationID string
- ErrorMsg string
-}
-
-// SoraCameoStatus 角色处理中间态
-type SoraCameoStatus struct {
- Status string
- StatusMessage string
- DisplayNameHint string
- UsernameHint string
- ProfileAssetURL string
- InstructionSetHint any
- InstructionSet any
-}
-
-// SoraCharacterFinalizeRequest 角色定稿请求参数
-type SoraCharacterFinalizeRequest struct {
- CameoID string
- Username string
- DisplayName string
- ProfileAssetPointer string
- InstructionSet any
-}
-
-// SoraUpstreamError 上游错误
-type SoraUpstreamError struct {
- StatusCode int
- Message string
- Headers http.Header
- Body []byte
-}
-
-func (e *SoraUpstreamError) Error() string {
- if e == nil {
- return "sora upstream error"
- }
- if e.Message != "" {
- return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message)
- }
- return fmt.Sprintf("sora upstream error: %d", e.StatusCode)
-}
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
deleted file mode 100644
index e9d325f452..0000000000
--- a/backend/internal/service/sora_gateway_service.go
+++ /dev/null
@@ -1,1559 +0,0 @@
-package service
-
-import (
- "bytes"
- "context"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log"
- "math"
- "math/rand"
- "mime"
- "net"
- "net/http"
- "net/url"
- "regexp"
- "strconv"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
- "github.com/gin-gonic/gin"
-)
-
-const soraImageInputMaxBytes = 20 << 20
-const soraImageInputMaxRedirects = 3
-const soraImageInputTimeout = 20 * time.Second
-const soraVideoInputMaxBytes = 200 << 20
-const soraVideoInputMaxRedirects = 3
-const soraVideoInputTimeout = 60 * time.Second
-
-var soraImageSizeMap = map[string]string{
- "gpt-image": "360",
- "gpt-image-landscape": "540",
- "gpt-image-portrait": "540",
-}
-
-var soraBlockedHostnames = map[string]struct{}{
- "localhost": {},
- "localhost.localdomain": {},
- "metadata.google.internal": {},
- "metadata.google.internal.": {},
-}
-
-var soraBlockedCIDRs = mustParseCIDRs([]string{
- "0.0.0.0/8",
- "10.0.0.0/8",
- "100.64.0.0/10",
- "127.0.0.0/8",
- "169.254.0.0/16",
- "172.16.0.0/12",
- "192.168.0.0/16",
- "224.0.0.0/4",
- "240.0.0.0/4",
- "::/128",
- "::1/128",
- "fc00::/7",
- "fe80::/10",
-})
-
-// SoraGatewayService handles forwarding requests to Sora upstream.
-type SoraGatewayService struct {
- soraClient SoraClient
- rateLimitService *RateLimitService
- httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
- cfg *config.Config
-}
-
-type soraWatermarkOptions struct {
- Enabled bool
- ParseMethod string
- ParseURL string
- ParseToken string
- FallbackOnFailure bool
- DeletePost bool
-}
-
-type soraCharacterOptions struct {
- SetPublic bool
- DeleteAfterGenerate bool
-}
-
-type soraCharacterFlowResult struct {
- CameoID string
- CharacterID string
- Username string
- DisplayName string
-}
-
-var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
-var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
-var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
-var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
-
-type soraPreflightChecker interface {
- PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
-}
-
-func NewSoraGatewayService(
- soraClient SoraClient,
- rateLimitService *RateLimitService,
- httpUpstream HTTPUpstream,
- cfg *config.Config,
-) *SoraGatewayService {
- return &SoraGatewayService{
- soraClient: soraClient,
- rateLimitService: rateLimitService,
- httpUpstream: httpUpstream,
- cfg: cfg,
- }
-}
-
-func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
- startTime := time.Now()
-
- // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
- if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
- if s.httpUpstream == nil {
- s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
- return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
- }
- return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
- }
-
- if s.soraClient == nil || !s.soraClient.Enabled() {
- if c != nil {
- c.JSON(http.StatusServiceUnavailable, gin.H{
- "error": gin.H{
- "type": "api_error",
- "message": "Sora 上游未配置",
- },
- })
- }
- return nil, errors.New("sora upstream not configured")
- }
-
- var reqBody map[string]any
- if err := json.Unmarshal(body, &reqBody); err != nil {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream)
- return nil, fmt.Errorf("parse request: %w", err)
- }
- reqModel, _ := reqBody["model"].(string)
- reqStream, _ := reqBody["stream"].(bool)
- if strings.TrimSpace(reqModel) == "" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
- return nil, errors.New("model is required")
- }
- originalModel := reqModel
-
- mappedModel := account.GetMappedModel(reqModel)
- var upstreamModel string
- if mappedModel != "" && mappedModel != reqModel {
- reqModel = mappedModel
- upstreamModel = mappedModel
- }
-
- modelCfg, ok := GetSoraModelConfig(reqModel)
- if !ok {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
- return nil, fmt.Errorf("unsupported model: %s", reqModel)
- }
- prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
- prompt = strings.TrimSpace(prompt)
- imageInput = strings.TrimSpace(imageInput)
- videoInput = strings.TrimSpace(videoInput)
- remixTargetID = strings.TrimSpace(remixTargetID)
-
- if videoInput != "" && modelCfg.Type != "video" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
- return nil, errors.New("video input only supports video models")
- }
- if videoInput != "" && imageInput != "" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
- return nil, errors.New("image input and video input cannot be used together")
- }
- characterOnly := videoInput != "" && prompt == ""
- if modelCfg.Type == "prompt_enhance" && prompt == "" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
- return nil, errors.New("prompt is required")
- }
- if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
- return nil, errors.New("prompt is required")
- }
-
- reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
- if cancel != nil {
- defer cancel()
- }
- if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
- if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
- return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
- }
- }
-
- if modelCfg.Type == "prompt_enhance" {
- enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
- if err != nil {
- return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
- }
- content := strings.TrimSpace(enhancedPrompt)
- if content == "" {
- content = prompt
- }
- var firstTokenMs *int
- if clientStream {
- ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
- if streamErr != nil {
- return nil, streamErr
- }
- firstTokenMs = ms
- } else if c != nil {
- c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
- }
- return &ForwardResult{
- RequestID: "",
- Model: originalModel,
- UpstreamModel: upstreamModel,
- Stream: clientStream,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- Usage: ClaudeUsage{},
- MediaType: "prompt",
- }, nil
- }
-
- characterOpts := parseSoraCharacterOptions(reqBody)
- watermarkOpts := parseSoraWatermarkOptions(reqBody)
- var characterResult *soraCharacterFlowResult
- if videoInput != "" {
- videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
- if videoErr != nil {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
- return nil, videoErr
- }
- characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
- if videoErr != nil {
- return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
- }
- if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
- characterID := strings.TrimSpace(characterResult.CharacterID)
- defer func() {
- cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
- defer cancelCleanup()
- if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
- log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
- }
- }()
- }
- if characterOnly {
- content := "角色创建成功"
- if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
- content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
- }
- var firstTokenMs *int
- if clientStream {
- ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
- if streamErr != nil {
- return nil, streamErr
- }
- firstTokenMs = ms
- } else if c != nil {
- resp := buildSoraNonStreamResponse(content, reqModel)
- if characterResult != nil {
- resp["character_id"] = characterResult.CharacterID
- resp["cameo_id"] = characterResult.CameoID
- resp["character_username"] = characterResult.Username
- resp["character_display_name"] = characterResult.DisplayName
- }
- c.JSON(http.StatusOK, resp)
- }
- return &ForwardResult{
- RequestID: "",
- Model: originalModel,
- UpstreamModel: upstreamModel,
- Stream: clientStream,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- Usage: ClaudeUsage{},
- MediaType: "prompt",
- }, nil
- }
- if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
- prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
- }
- }
-
- var imageData []byte
- imageFilename := ""
- if imageInput != "" {
- decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
- if err != nil {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
- return nil, err
- }
- imageData = decoded
- imageFilename = filename
- }
-
- mediaID := ""
- if len(imageData) > 0 {
- uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename)
- if err != nil {
- return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
- }
- mediaID = uploadID
- }
-
- taskID := ""
- var err error
- videoCount := parseSoraVideoCount(reqBody)
- switch modelCfg.Type {
- case "image":
- taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
- Prompt: prompt,
- Width: modelCfg.Width,
- Height: modelCfg.Height,
- MediaID: mediaID,
- })
- case "video":
- if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
- taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
- Prompt: formatSoraStoryboardPrompt(prompt),
- Orientation: modelCfg.Orientation,
- Frames: modelCfg.Frames,
- Model: modelCfg.Model,
- Size: modelCfg.Size,
- MediaID: mediaID,
- })
- } else {
- taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
- Prompt: prompt,
- Orientation: modelCfg.Orientation,
- Frames: modelCfg.Frames,
- Model: modelCfg.Model,
- Size: modelCfg.Size,
- VideoCount: videoCount,
- MediaID: mediaID,
- RemixTargetID: remixTargetID,
- CameoIDs: extractSoraCameoIDs(reqBody),
- })
- }
- default:
- err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
- }
- if err != nil {
- return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
- }
-
- if clientStream && c != nil {
- s.prepareSoraStream(c, taskID)
- }
-
- var mediaURLs []string
- videoGenerationID := ""
- mediaType := modelCfg.Type
- imageCount := 0
- imageSize := ""
- switch modelCfg.Type {
- case "image":
- urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream)
- if pollErr != nil {
- return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
- }
- mediaURLs = urls
- imageCount = len(urls)
- imageSize = soraImageSizeFromModel(reqModel)
- case "video":
- videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
- if pollErr != nil {
- return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
- }
- if videoStatus != nil {
- mediaURLs = videoStatus.URLs
- videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
- }
- default:
- mediaType = "prompt"
- }
-
- watermarkPostID := ""
- if modelCfg.Type == "video" && watermarkOpts.Enabled {
- watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
- if watermarkErr != nil {
- if !watermarkOpts.FallbackOnFailure {
- return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
- }
- log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
- } else if strings.TrimSpace(watermarkURL) != "" {
- mediaURLs = []string{strings.TrimSpace(watermarkURL)}
- watermarkPostID = strings.TrimSpace(postID)
- }
- }
-
- // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
- // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
- finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
- if watermarkPostID != "" && watermarkOpts.DeletePost {
- if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
- log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
- }
- }
-
- content := buildSoraContent(mediaType, finalURLs)
- var firstTokenMs *int
- if clientStream {
- ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
- if streamErr != nil {
- return nil, streamErr
- }
- firstTokenMs = ms
- } else if c != nil {
- response := buildSoraNonStreamResponse(content, reqModel)
- if len(finalURLs) > 0 {
- response["media_url"] = finalURLs[0]
- if len(finalURLs) > 1 {
- response["media_urls"] = finalURLs
- }
- }
- c.JSON(http.StatusOK, response)
- }
-
- return &ForwardResult{
- RequestID: taskID,
- Model: originalModel,
- UpstreamModel: upstreamModel,
- Stream: clientStream,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- Usage: ClaudeUsage{},
- MediaType: mediaType,
- MediaURL: firstMediaURL(finalURLs),
- ImageCount: imageCount,
- ImageSize: imageSize,
- }, nil
-}
-
-func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
- if s == nil || s.cfg == nil {
- return ctx, nil
- }
- timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds
- if stream {
- timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds
- }
- if timeoutSeconds <= 0 {
- return ctx, nil
- }
- return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
-}
-
-func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
- opts := soraWatermarkOptions{
- Enabled: parseBoolWithDefault(body, "watermark_free", false),
- ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
- ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
- ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
- FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
- DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
- }
- if opts.ParseMethod == "" {
- opts.ParseMethod = "third_party"
- }
- return opts
-}
-
-func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
- return soraCharacterOptions{
- SetPublic: parseBoolWithDefault(body, "character_set_public", true),
- DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
- }
-}
-
-func parseSoraVideoCount(body map[string]any) int {
- if body == nil {
- return 1
- }
- keys := []string{"video_count", "videos", "n_variants"}
- for _, key := range keys {
- count := parseIntWithDefault(body, key, 0)
- if count > 0 {
- return clampInt(count, 1, 3)
- }
- }
- return 1
-}
-
-func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
- if body == nil {
- return def
- }
- val, ok := body[key]
- if !ok {
- return def
- }
- switch typed := val.(type) {
- case bool:
- return typed
- case int:
- return typed != 0
- case int32:
- return typed != 0
- case int64:
- return typed != 0
- case float64:
- return typed != 0
- case string:
- typed = strings.ToLower(strings.TrimSpace(typed))
- if typed == "true" || typed == "1" || typed == "yes" {
- return true
- }
- if typed == "false" || typed == "0" || typed == "no" {
- return false
- }
- }
- return def
-}
-
-func parseStringWithDefault(body map[string]any, key, def string) string {
- if body == nil {
- return def
- }
- val, ok := body[key]
- if !ok {
- return def
- }
- if str, ok := val.(string); ok {
- return str
- }
- return def
-}
-
-func parseIntWithDefault(body map[string]any, key string, def int) int {
- if body == nil {
- return def
- }
- val, ok := body[key]
- if !ok {
- return def
- }
- switch typed := val.(type) {
- case int:
- return typed
- case int32:
- return int(typed)
- case int64:
- return int(typed)
- case float64:
- return int(typed)
- case string:
- parsed, err := strconv.Atoi(strings.TrimSpace(typed))
- if err == nil {
- return parsed
- }
- }
- return def
-}
-
-func clampInt(v, minVal, maxVal int) int {
- if v < minVal {
- return minVal
- }
- if v > maxVal {
- return maxVal
- }
- return v
-}
-
-func extractSoraCameoIDs(body map[string]any) []string {
- if body == nil {
- return nil
- }
- raw, ok := body["cameo_ids"]
- if !ok {
- return nil
- }
- switch typed := raw.(type) {
- case []string:
- out := make([]string, 0, len(typed))
- for _, item := range typed {
- item = strings.TrimSpace(item)
- if item != "" {
- out = append(out, item)
- }
- }
- return out
- case []any:
- out := make([]string, 0, len(typed))
- for _, item := range typed {
- str, ok := item.(string)
- if !ok {
- continue
- }
- str = strings.TrimSpace(str)
- if str != "" {
- out = append(out, str)
- }
- }
- return out
- default:
- return nil
- }
-}
-
-func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
- cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
- if err != nil {
- return nil, err
- }
-
- cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
- if err != nil {
- return nil, err
- }
- username := processSoraCharacterUsername(cameoStatus.UsernameHint)
- displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
- if displayName == "" {
- displayName = "Character"
- }
- profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
- if profileAssetURL == "" {
- return nil, errors.New("profile asset url not found in cameo status")
- }
-
- avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
- if err != nil {
- return nil, err
- }
- assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
- if err != nil {
- return nil, err
- }
- instructionSet := cameoStatus.InstructionSetHint
- if instructionSet == nil {
- instructionSet = cameoStatus.InstructionSet
- }
-
- characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
- CameoID: strings.TrimSpace(cameoID),
- Username: username,
- DisplayName: displayName,
- ProfileAssetPointer: assetPointer,
- InstructionSet: instructionSet,
- })
- if err != nil {
- return nil, err
- }
-
- if opts.SetPublic {
- if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
- return nil, err
- }
- }
-
- return &soraCharacterFlowResult{
- CameoID: strings.TrimSpace(cameoID),
- CharacterID: strings.TrimSpace(characterID),
- Username: strings.TrimSpace(username),
- DisplayName: displayName,
- }, nil
-}
-
-func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
- timeout := 10 * time.Minute
- interval := 5 * time.Second
- maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
- if maxAttempts < 1 {
- maxAttempts = 1
- }
-
- var lastErr error
- consecutiveErrors := 0
- for attempt := 0; attempt < maxAttempts; attempt++ {
- status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
- if err != nil {
- lastErr = err
- consecutiveErrors++
- if consecutiveErrors >= 3 {
- break
- }
- if attempt < maxAttempts-1 {
- if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
- return nil, sleepErr
- }
- }
- continue
- }
- consecutiveErrors = 0
- if status == nil {
- if attempt < maxAttempts-1 {
- if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
- return nil, sleepErr
- }
- }
- continue
- }
- currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
- statusMessage := strings.TrimSpace(status.StatusMessage)
- if currentStatus == "failed" {
- if statusMessage == "" {
- statusMessage = "character creation failed"
- }
- return nil, errors.New(statusMessage)
- }
- if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
- return status, nil
- }
- if attempt < maxAttempts-1 {
- if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
- return nil, sleepErr
- }
- }
- }
- if lastErr != nil {
- return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
- }
- return nil, errors.New("cameo processing timeout")
-}
-
-func processSoraCharacterUsername(usernameHint string) string {
- usernameHint = strings.TrimSpace(usernameHint)
- if usernameHint == "" {
- usernameHint = "character"
- }
- if strings.Contains(usernameHint, ".") {
- parts := strings.Split(usernameHint, ".")
- usernameHint = strings.TrimSpace(parts[len(parts)-1])
- }
- if usernameHint == "" {
- usernameHint = "character"
- }
- return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100)
-}
-
-func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
- generationID = strings.TrimSpace(generationID)
- if generationID == "" {
- return "", "", errors.New("generation id is required for watermark-free mode")
- }
- postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
- if err != nil {
- return "", "", err
- }
- postID = strings.TrimSpace(postID)
- if postID == "" {
- return "", "", errors.New("watermark-free publish returned empty post id")
- }
-
- switch opts.ParseMethod {
- case "custom":
- urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
- if parseErr != nil {
- return "", postID, parseErr
- }
- return strings.TrimSpace(urlVal), postID, nil
- case "", "third_party":
- return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
- default:
- return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
- }
-}
-
-func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
- switch statusCode {
- case 401, 402, 403, 404, 429, 529:
- return true
- default:
- return statusCode >= 500
- }
-}
-
-func buildSoraNonStreamResponse(content, model string) map[string]any {
- return map[string]any{
- "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
- "object": "chat.completion",
- "created": time.Now().Unix(),
- "model": model,
- "choices": []any{
- map[string]any{
- "index": 0,
- "message": map[string]any{
- "role": "assistant",
- "content": content,
- },
- "finish_reason": "stop",
- },
- },
- }
-}
-
-func soraImageSizeFromModel(model string) string {
- modelLower := strings.ToLower(model)
- if size, ok := soraImageSizeMap[modelLower]; ok {
- return size
- }
- if strings.Contains(modelLower, "landscape") || strings.Contains(modelLower, "portrait") {
- return "540"
- }
- return "360"
-}
-
-func soraProErrorMessage(model, upstreamMsg string) string {
- modelLower := strings.ToLower(model)
- if strings.Contains(modelLower, "sora2pro-hd") {
- return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号"
- }
- if strings.Contains(modelLower, "sora2pro") {
- return "当前账号无法使用 Sora Pro 模型,请更换模型或账号"
- }
- return ""
-}
-
-func firstMediaURL(urls []string) string {
- if len(urls) == 0 {
- return ""
- }
- return urls[0]
-}
-
-func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string {
- if path == "" {
- return path
- }
- prefix := "/sora/media"
- values := url.Values{}
- if rawQuery != "" {
- if parsed, err := url.ParseQuery(rawQuery); err == nil {
- values = parsed
- }
- }
-
- signKey := ""
- ttlSeconds := 0
- if s != nil && s.cfg != nil {
- signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey)
- ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds
- }
- values.Del("sig")
- values.Del("expires")
- signingQuery := values.Encode()
- if signKey != "" && ttlSeconds > 0 {
- expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix()
- signature := SignSoraMediaURL(path, signingQuery, expires, signKey)
- if signature != "" {
- values.Set("expires", strconv.FormatInt(expires, 10))
- values.Set("sig", signature)
- prefix = "/sora/media-signed"
- }
- }
-
- encoded := values.Encode()
- if encoded == "" {
- return prefix + path
- }
- return prefix + path + "?" + encoded
-}
-
-func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) {
- if c == nil {
- return
- }
- c.Header("Content-Type", "text/event-stream")
- c.Header("Cache-Control", "no-cache")
- c.Header("Connection", "keep-alive")
- c.Header("X-Accel-Buffering", "no")
- if strings.TrimSpace(requestID) != "" {
- c.Header("x-request-id", requestID)
- }
-}
-
-func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) {
- if c == nil {
- return nil, nil
- }
- writer := c.Writer
- flusher, _ := writer.(http.Flusher)
-
- chunk := map[string]any{
- "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
- "object": "chat.completion.chunk",
- "created": time.Now().Unix(),
- "model": model,
- "choices": []any{
- map[string]any{
- "index": 0,
- "delta": map[string]any{
- "content": content,
- },
- },
- },
- }
- encoded, _ := jsonMarshalRaw(chunk)
- if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil {
- return nil, err
- }
- if flusher != nil {
- flusher.Flush()
- }
- ms := int(time.Since(startTime).Milliseconds())
- finalChunk := map[string]any{
- "id": chunk["id"],
- "object": "chat.completion.chunk",
- "created": time.Now().Unix(),
- "model": model,
- "choices": []any{
- map[string]any{
- "index": 0,
- "delta": map[string]any{},
- "finish_reason": "stop",
- },
- },
- }
- finalEncoded, _ := jsonMarshalRaw(finalChunk)
- if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil {
- return &ms, err
- }
- if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil {
- return &ms, err
- }
- if flusher != nil {
- flusher.Flush()
- }
- return &ms, nil
-}
-
-func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) {
- if c == nil {
- return
- }
- if stream {
- flusher, _ := c.Writer.(http.Flusher)
- errorData := map[string]any{
- "error": map[string]string{
- "type": errType,
- "message": message,
- },
- }
- jsonBytes, err := json.Marshal(errorData)
- if err != nil {
- _ = c.Error(err)
- return
- }
- errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
- _, _ = fmt.Fprint(c.Writer, errorEvent)
- _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
- if flusher != nil {
- flusher.Flush()
- }
- return
- }
- c.JSON(status, gin.H{
- "error": gin.H{
- "type": errType,
- "message": message,
- },
- })
-}
-
-func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error {
- if err == nil {
- return nil
- }
- var upstreamErr *SoraUpstreamError
- if errors.As(err, &upstreamErr) {
- accountID := int64(0)
- if account != nil {
- accountID = account.ID
- }
- logger.LegacyPrintf(
- "service.sora",
- "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
- accountID,
- model,
- upstreamErr.StatusCode,
- strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
- strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
- strings.TrimSpace(upstreamErr.Message),
- truncateForLog(upstreamErr.Body, 1024),
- )
- if s.rateLimitService != nil && account != nil {
- s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
- }
- if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
- var responseHeaders http.Header
- if upstreamErr.Headers != nil {
- responseHeaders = upstreamErr.Headers.Clone()
- }
- return &UpstreamFailoverError{
- StatusCode: upstreamErr.StatusCode,
- ResponseBody: upstreamErr.Body,
- ResponseHeaders: responseHeaders,
- }
- }
- msg := upstreamErr.Message
- if override := soraProErrorMessage(model, msg); override != "" {
- msg = override
- }
- s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream)
- return err
- }
- if errors.Is(err, context.DeadlineExceeded) {
- s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream)
- return err
- }
- s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream)
- return err
-}
-
-func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
- interval := s.pollInterval()
- maxAttempts := s.pollMaxAttempts()
- lastPing := time.Now()
- for attempt := 0; attempt < maxAttempts; attempt++ {
- status, err := s.soraClient.GetImageTask(ctx, account, taskID)
- if err != nil {
- return nil, err
- }
- switch strings.ToLower(status.Status) {
- case "succeeded", "completed":
- return status.URLs, nil
- case "failed":
- if status.ErrorMsg != "" {
- return nil, errors.New(status.ErrorMsg)
- }
- return nil, errors.New("sora image generation failed")
- }
- if stream {
- s.maybeSendPing(c, &lastPing)
- }
- if err := sleepWithContext(ctx, interval); err != nil {
- return nil, err
- }
- }
- return nil, errors.New("sora image generation timeout")
-}
-
-func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
- interval := s.pollInterval()
- maxAttempts := s.pollMaxAttempts()
- lastPing := time.Now()
- for attempt := 0; attempt < maxAttempts; attempt++ {
- status, err := s.soraClient.GetVideoTask(ctx, account, taskID)
- if err != nil {
- return nil, err
- }
- switch strings.ToLower(status.Status) {
- case "completed", "succeeded":
- return status, nil
- case "failed":
- if status.ErrorMsg != "" {
- return nil, errors.New(status.ErrorMsg)
- }
- return nil, errors.New("sora video generation failed")
- }
- if stream {
- s.maybeSendPing(c, &lastPing)
- }
- if err := sleepWithContext(ctx, interval); err != nil {
- return nil, err
- }
- }
- return nil, errors.New("sora video generation timeout")
-}
-
-func (s *SoraGatewayService) pollInterval() time.Duration {
- if s == nil || s.cfg == nil {
- return 2 * time.Second
- }
- interval := s.cfg.Sora.Client.PollIntervalSeconds
- if interval <= 0 {
- interval = 2
- }
- return time.Duration(interval) * time.Second
-}
-
-func (s *SoraGatewayService) pollMaxAttempts() int {
- if s == nil || s.cfg == nil {
- return 600
- }
- maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
- if maxAttempts <= 0 {
- maxAttempts = 600
- }
- return maxAttempts
-}
-
-func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
- if c == nil {
- return
- }
- interval := 10 * time.Second
- if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
- interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
- }
- if time.Since(*lastPing) < interval {
- return
- }
- if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- }
- *lastPing = time.Now()
- }
-}
-
-func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
- if len(urls) == 0 {
- return urls
- }
- output := make([]string, 0, len(urls))
- for _, raw := range urls {
- raw = strings.TrimSpace(raw)
- if raw == "" {
- continue
- }
- if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
- output = append(output, raw)
- continue
- }
- pathVal := raw
- if !strings.HasPrefix(pathVal, "/") {
- pathVal = "/" + pathVal
- }
- output = append(output, s.buildSoraMediaURL(pathVal, ""))
- }
- return output
-}
-
-// jsonMarshalRaw 序列化 JSON,不转义 &、<、> 等 HTML 字符,
-// 避免 URL 中的 & 被转义为 \u0026 导致客户端无法直接使用。
-func jsonMarshalRaw(v any) ([]byte, error) {
- var buf bytes.Buffer
- enc := json.NewEncoder(&buf)
- enc.SetEscapeHTML(false)
- if err := enc.Encode(v); err != nil {
- return nil, err
- }
- // Encode 会追加换行符,去掉它
- b := buf.Bytes()
- if len(b) > 0 && b[len(b)-1] == '\n' {
- b = b[:len(b)-1]
- }
- return b, nil
-}
-
-func buildSoraContent(mediaType string, urls []string) string {
- switch mediaType {
- case "image":
- parts := make([]string, 0, len(urls))
- for _, u := range urls {
- parts = append(parts, fmt.Sprintf("", u))
- }
- return strings.Join(parts, "\n")
- case "video":
- if len(urls) == 0 {
- return ""
- }
- return fmt.Sprintf("```html\n\n```", urls[0])
- default:
- return ""
- }
-}
-
-func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
- if body == nil {
- return "", "", "", ""
- }
- if v, ok := body["remix_target_id"].(string); ok {
- remixTargetID = strings.TrimSpace(v)
- }
- if v, ok := body["image"].(string); ok {
- imageInput = v
- }
- if v, ok := body["video"].(string); ok {
- videoInput = v
- }
- if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
- prompt = v
- }
- if messages, ok := body["messages"].([]any); ok {
- builder := strings.Builder{}
- for _, raw := range messages {
- msg, ok := raw.(map[string]any)
- if !ok {
- continue
- }
- role, _ := msg["role"].(string)
- if role != "" && role != "user" {
- continue
- }
- content := msg["content"]
- text, img, vid := parseSoraMessageContent(content)
- if text != "" {
- if builder.Len() > 0 {
- _, _ = builder.WriteString("\n")
- }
- _, _ = builder.WriteString(text)
- }
- if imageInput == "" && img != "" {
- imageInput = img
- }
- if videoInput == "" && vid != "" {
- videoInput = vid
- }
- }
- if prompt == "" {
- prompt = builder.String()
- }
- }
- if remixTargetID == "" {
- remixTargetID = extractRemixTargetIDFromPrompt(prompt)
- }
- prompt = cleanRemixLinkFromPrompt(prompt)
- return prompt, imageInput, videoInput, remixTargetID
-}
-
-func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
- switch val := content.(type) {
- case string:
- return val, "", ""
- case []any:
- builder := strings.Builder{}
- for _, item := range val {
- itemMap, ok := item.(map[string]any)
- if !ok {
- continue
- }
- t, _ := itemMap["type"].(string)
- switch t {
- case "text":
- if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
- if builder.Len() > 0 {
- _, _ = builder.WriteString("\n")
- }
- _, _ = builder.WriteString(txt)
- }
- case "image_url":
- if imageInput == "" {
- if urlVal, ok := itemMap["image_url"].(map[string]any); ok {
- imageInput = fmt.Sprintf("%v", urlVal["url"])
- } else if urlStr, ok := itemMap["image_url"].(string); ok {
- imageInput = urlStr
- }
- }
- case "video_url":
- if videoInput == "" {
- if urlVal, ok := itemMap["video_url"].(map[string]any); ok {
- videoInput = fmt.Sprintf("%v", urlVal["url"])
- } else if urlStr, ok := itemMap["video_url"].(string); ok {
- videoInput = urlStr
- }
- }
- }
- }
- return builder.String(), imageInput, videoInput
- default:
- return "", "", ""
- }
-}
-
-func isSoraStoryboardPrompt(prompt string) bool {
- prompt = strings.TrimSpace(prompt)
- if prompt == "" {
- return false
- }
- return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
-}
-
-func formatSoraStoryboardPrompt(prompt string) string {
- prompt = strings.TrimSpace(prompt)
- if prompt == "" {
- return ""
- }
- matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
- if len(matches) == 0 {
- return prompt
- }
- firstBracketPos := strings.Index(prompt, "[")
- instructions := ""
- if firstBracketPos > 0 {
- instructions = strings.TrimSpace(prompt[:firstBracketPos])
- }
- shots := make([]string, 0, len(matches))
- for i, match := range matches {
- if len(match) < 3 {
- continue
- }
- duration := strings.TrimSpace(match[1])
- scene := strings.TrimSpace(match[2])
- if scene == "" {
- continue
- }
- shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
- }
- if len(shots) == 0 {
- return prompt
- }
- timeline := strings.Join(shots, "\n\n")
- if instructions == "" {
- return timeline
- }
- return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
-}
-
-func extractRemixTargetIDFromPrompt(prompt string) string {
- prompt = strings.TrimSpace(prompt)
- if prompt == "" {
- return ""
- }
- return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
-}
-
-func cleanRemixLinkFromPrompt(prompt string) string {
- prompt = strings.TrimSpace(prompt)
- if prompt == "" {
- return prompt
- }
- cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
- cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
- cleaned = strings.Join(strings.Fields(cleaned), " ")
- return strings.TrimSpace(cleaned)
-}
-
-func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
- raw := strings.TrimSpace(input)
- if raw == "" {
- return nil, "", errors.New("empty image input")
- }
- if strings.HasPrefix(raw, "data:") {
- parts := strings.SplitN(raw, ",", 2)
- if len(parts) != 2 {
- return nil, "", errors.New("invalid data url")
- }
- meta := parts[0]
- payload := parts[1]
- decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
- if err != nil {
- return nil, "", err
- }
- ext := ""
- if strings.HasPrefix(meta, "data:") {
- metaParts := strings.SplitN(meta[5:], ";", 2)
- if len(metaParts) > 0 {
- if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 {
- ext = exts[0]
- }
- }
- }
- filename := "image" + ext
- return decoded, filename, nil
- }
- if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
- return downloadSoraImageInput(ctx, raw)
- }
- decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
- if err != nil {
- return nil, "", errors.New("invalid base64 image")
- }
- return decoded, "image.png", nil
-}
-
-func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
- raw := strings.TrimSpace(input)
- if raw == "" {
- return nil, errors.New("empty video input")
- }
- if strings.HasPrefix(raw, "data:") {
- parts := strings.SplitN(raw, ",", 2)
- if len(parts) != 2 {
- return nil, errors.New("invalid video data url")
- }
- decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
- if err != nil {
- return nil, errors.New("invalid base64 video")
- }
- if len(decoded) == 0 {
- return nil, errors.New("empty video data")
- }
- return decoded, nil
- }
- if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
- return downloadSoraVideoInput(ctx, raw)
- }
- decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
- if err != nil {
- return nil, errors.New("invalid base64 video")
- }
- if len(decoded) == 0 {
- return nil, errors.New("empty video data")
- }
- return decoded, nil
-}
-
-func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
- parsed, err := validateSoraRemoteURL(rawURL)
- if err != nil {
- return nil, "", err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
- if err != nil {
- return nil, "", err
- }
- client := &http.Client{
- Timeout: soraImageInputTimeout,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- if len(via) >= soraImageInputMaxRedirects {
- return errors.New("too many redirects")
- }
- return validateSoraRemoteURLValue(req.URL)
- },
- }
- resp, err := client.Do(req)
- if err != nil {
- return nil, "", err
- }
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusOK {
- return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
- }
- data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
- if err != nil {
- return nil, "", err
- }
- ext := fileExtFromURL(parsed.String())
- if ext == "" {
- ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
- }
- filename := "image" + ext
- return data, filename, nil
-}
-
-func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
- parsed, err := validateSoraRemoteURL(rawURL)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
- if err != nil {
- return nil, err
- }
- client := &http.Client{
- Timeout: soraVideoInputTimeout,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- if len(via) >= soraVideoInputMaxRedirects {
- return errors.New("too many redirects")
- }
- return validateSoraRemoteURLValue(req.URL)
- },
- }
- resp, err := client.Do(req)
- if err != nil {
- return nil, err
- }
- defer func() { _ = resp.Body.Close() }()
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
- }
- data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
- if err != nil {
- return nil, err
- }
- if len(data) == 0 {
- return nil, errors.New("empty video content")
- }
- return data, nil
-}
-
-func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
- if maxBytes <= 0 {
- return nil, errors.New("invalid max bytes limit")
- }
- decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
- limited := io.LimitReader(decoder, maxBytes+1)
- data, err := io.ReadAll(limited)
- if err != nil {
- return nil, err
- }
- if int64(len(data)) > maxBytes {
- return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
- }
- return data, nil
-}
-
-func validateSoraRemoteURL(raw string) (*url.URL, error) {
- if strings.TrimSpace(raw) == "" {
- return nil, errors.New("empty remote url")
- }
- parsed, err := url.Parse(raw)
- if err != nil {
- return nil, fmt.Errorf("invalid remote url: %w", err)
- }
- if err := validateSoraRemoteURLValue(parsed); err != nil {
- return nil, err
- }
- return parsed, nil
-}
-
-func validateSoraRemoteURLValue(parsed *url.URL) error {
- if parsed == nil {
- return errors.New("invalid remote url")
- }
- scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
- if scheme != "http" && scheme != "https" {
- return errors.New("only http/https remote url is allowed")
- }
- if parsed.User != nil {
- return errors.New("remote url cannot contain userinfo")
- }
- host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
- if host == "" {
- return errors.New("remote url missing host")
- }
- if _, blocked := soraBlockedHostnames[host]; blocked {
- return errors.New("remote url is not allowed")
- }
- if ip := net.ParseIP(host); ip != nil {
- if isSoraBlockedIP(ip) {
- return errors.New("remote url is not allowed")
- }
- return nil
- }
- ips, err := net.LookupIP(host)
- if err != nil {
- return fmt.Errorf("resolve remote url failed: %w", err)
- }
- for _, ip := range ips {
- if isSoraBlockedIP(ip) {
- return errors.New("remote url is not allowed")
- }
- }
- return nil
-}
-
-func isSoraBlockedIP(ip net.IP) bool {
- if ip == nil {
- return true
- }
- for _, cidr := range soraBlockedCIDRs {
- if cidr.Contains(ip) {
- return true
- }
- }
- return false
-}
-
-func mustParseCIDRs(values []string) []*net.IPNet {
- out := make([]*net.IPNet, 0, len(values))
- for _, val := range values {
- _, cidr, err := net.ParseCIDR(val)
- if err != nil {
- continue
- }
- out = append(out, cidr)
- }
- return out
-}
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
deleted file mode 100644
index 2fef600c53..0000000000
--- a/backend/internal/service/sora_gateway_service_test.go
+++ /dev/null
@@ -1,564 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "context"
- "encoding/json"
- "errors"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
-)
-
-var _ SoraClient = (*stubSoraClientForPoll)(nil)
-
-type stubSoraClientForPoll struct {
- imageStatus *SoraImageTaskStatus
- videoStatus *SoraVideoTaskStatus
- imageCalls int
- videoCalls int
- enhanced string
- enhanceErr error
- storyboard bool
- videoReq SoraVideoRequest
- parseErr error
- postCalls int
- deleteCalls int
-}
-
-func (s *stubSoraClientForPoll) Enabled() bool { return true }
-func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) {
- return "task-image", nil
-}
-func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
- s.videoReq = req
- return "task-video", nil
-}
-func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
- s.storyboard = true
- return "task-video", nil
-}
-func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
- return "cameo-1", nil
-}
-func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
- return &SoraCameoStatus{
- Status: "finalized",
- StatusMessage: "Completed",
- DisplayNameHint: "Character",
- UsernameHint: "user.character",
- ProfileAssetURL: "https://example.com/avatar.webp",
- }, nil
-}
-func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
- return []byte("avatar"), nil
-}
-func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
- return "asset-pointer", nil
-}
-func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
- return "character-1", nil
-}
-func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
- return nil
-}
-func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
- return nil
-}
-func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
- s.postCalls++
- return "s_post", nil
-}
-func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
- s.deleteCalls++
- return nil
-}
-func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
- if s.parseErr != nil {
- return "", s.parseErr
- }
- return "https://example.com/no-watermark.mp4", nil
-}
-func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
- if s.enhanced != "" {
- return s.enhanced, s.enhanceErr
- }
- return "enhanced prompt", s.enhanceErr
-}
-func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
- s.imageCalls++
- return s.imageStatus, nil
-}
-func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
- s.videoCalls++
- return s.videoStatus, nil
-}
-
-func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
- client := &stubSoraClientForPoll{
- imageStatus: &SoraImageTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/a.png"},
- },
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- service := NewSoraGatewayService(client, nil, nil, cfg)
-
- urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false)
- require.NoError(t, err)
- require.Equal(t, []string{"https://example.com/a.png"}, urls)
- require.Equal(t, 1, client.imageCalls)
-}
-
-func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
- client := &stubSoraClientForPoll{
- enhanced: "cinematic prompt",
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Status: StatusActive,
- Credentials: map[string]any{
- "model_mapping": map[string]any{
- "prompt-enhance-short-10s": "prompt-enhance-short-15s",
- },
- },
- }
- body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "prompt", result.MediaType)
- require.Equal(t, "prompt-enhance-short-10s", result.Model)
- require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
-}
-
-func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
- client := &stubSoraClientForPoll{
- videoStatus: &SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/v.mp4"},
- },
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
- body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.True(t, client.storyboard)
-}
-
-func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
- client := &stubSoraClientForPoll{
- videoStatus: &SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/v.mp4"},
- },
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
- body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, 3, client.videoReq.VideoCount)
-}
-
-func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
- client := &stubSoraClientForPoll{}
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
- body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "prompt", result.MediaType)
- require.Equal(t, 0, client.videoCalls)
-}
-
-func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
- client := &stubSoraClientForPoll{
- videoStatus: &SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/original.mp4"},
- GenerationID: "gen_1",
- },
- parseErr: errors.New("parse failed"),
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
- body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
- require.Equal(t, 1, client.postCalls)
- require.Equal(t, 0, client.deleteCalls)
-}
-
-func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
- client := &stubSoraClientForPoll{
- videoStatus: &SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/original.mp4"},
- GenerationID: "gen_1",
- },
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- svc := NewSoraGatewayService(client, nil, nil, cfg)
- account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
- body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
-
- result, err := svc.Forward(context.Background(), nil, account, body, false)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
- require.Equal(t, 1, client.postCalls)
- require.Equal(t, 1, client.deleteCalls)
-}
-
-func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
- client := &stubSoraClientForPoll{
- videoStatus: &SoraVideoTaskStatus{
- Status: "failed",
- ErrorMsg: "reject",
- },
- }
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- service := NewSoraGatewayService(client, nil, nil, cfg)
-
- status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
- require.Error(t, err)
- require.Nil(t, status)
- require.Contains(t, err.Error(), "reject")
- require.Equal(t, 1, client.videoCalls)
-}
-
-func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
- cfg := &config.Config{
- Gateway: config.GatewayConfig{
- SoraMediaSigningKey: "test-key",
- SoraMediaSignedURLTTLSeconds: 600,
- },
- }
- service := NewSoraGatewayService(nil, nil, nil, cfg)
-
- url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "")
- require.Contains(t, url, "/sora/media-signed")
- require.Contains(t, url, "expires=")
- require.Contains(t, url, "sig=")
-}
-
-func TestNormalizeSoraMediaURLs_Empty(t *testing.T) {
- svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
- result := svc.normalizeSoraMediaURLs(nil)
- require.Empty(t, result)
-
- result = svc.normalizeSoraMediaURLs([]string{})
- require.Empty(t, result)
-}
-
-func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) {
- svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
- urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"}
- result := svc.normalizeSoraMediaURLs(urls)
- require.Equal(t, urls, result)
-}
-
-func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) {
- cfg := &config.Config{}
- svc := NewSoraGatewayService(nil, nil, nil, cfg)
- urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"}
- result := svc.normalizeSoraMediaURLs(urls)
- require.Len(t, result, 2)
- require.Contains(t, result[0], "/sora/media")
- require.Contains(t, result[1], "/sora/media")
-}
-
-func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) {
- svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
- urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"}
- result := svc.normalizeSoraMediaURLs(urls)
- require.Len(t, result, 2)
-}
-
-func TestBuildSoraContent_Image(t *testing.T) {
- content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"})
- require.Contains(t, content, "")
- require.Contains(t, content, "")
-}
-
-func TestBuildSoraContent_Video(t *testing.T) {
- content := buildSoraContent("video", []string{"https://a.com/v.mp4"})
- require.Contains(t, content, "