Skip to content

Commit 0f1afcd

Browse files
authored
fix(ai-proxy): Do not change the configured components of Azure URL (#2782)
1 parent 19d1548 commit 0f1afcd

File tree

1 file changed

+45
-15
lines changed
  • plugins/wasm-go/extensions/ai-proxy/provider

1 file changed

+45
-15
lines changed

plugins/wasm-go/extensions/ai-proxy/provider/azure.go

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,21 @@ import (
1414
"github.com/higress-group/wasm-go/pkg/wrapper"
1515
)
1616

17+
type azureServiceUrlType int
18+
1719
const (
1820
pathAzurePrefix = "/openai"
1921
pathAzureModelPlaceholder = "{model}"
2022
pathAzureWithModelPrefix = "/openai/deployments/" + pathAzureModelPlaceholder
2123
queryAzureApiVersion = "api-version"
2224
)
2325

26+
const (
27+
azureServiceUrlTypeFull azureServiceUrlType = iota
28+
azureServiceUrlTypeWithDeployment
29+
azureServiceUrlTypeDomainOnly
30+
)
31+
2432
var (
2533
azureModelIrrelevantApis = map[ApiName]bool{
2634
ApiNameModels: true,
@@ -31,7 +39,7 @@ var (
3139
ApiNameRetrieveFile: true,
3240
ApiNameRetrieveFileContent: true,
3341
}
34-
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)")
42+
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
3543
)
3644

3745
// azureProvider is the provider for Azure OpenAI service.
@@ -82,32 +90,44 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
8290

8391
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
8492
defaultModel := "placeholder"
93+
var serviceUrlType azureServiceUrlType
8594
if modelSubMatch != nil {
8695
defaultModel = modelSubMatch[1]
96+
if modelSubMatch[2] != "" {
97+
serviceUrlType = azureServiceUrlTypeFull
98+
} else {
99+
serviceUrlType = azureServiceUrlTypeWithDeployment
100+
}
87101
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
88102
} else {
103+
serviceUrlType = azureServiceUrlTypeDomainOnly
89104
log.Debugf("azureProvider: no default model found in serviceUrl")
90105
}
106+
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
91107

92108
config.setDefaultCapabilities(m.DefaultCapabilities())
93109
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
94110
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
95111
return &azureProvider{
96-
config: config,
97-
serviceUrl: serviceUrl,
98-
apiVersion: apiVersion,
99-
defaultModel: defaultModel,
100-
contextCache: createContextCache(&config),
112+
config: config,
113+
serviceUrl: serviceUrl,
114+
serviceUrlType: serviceUrlType,
115+
serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery,
116+
apiVersion: apiVersion,
117+
defaultModel: defaultModel,
118+
contextCache: createContextCache(&config),
101119
}, nil
102120
}
103121

104122
type azureProvider struct {
105123
config ProviderConfig
106124

107-
contextCache *contextCache
108-
serviceUrl *url.URL
109-
apiVersion string
110-
defaultModel string
125+
contextCache *contextCache
126+
serviceUrl *url.URL
127+
serviceUrlFullPath string
128+
serviceUrlType azureServiceUrlType
129+
apiVersion string
130+
defaultModel string
111131
}
112132

113133
func (m *azureProvider) GetProviderType() string {
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
152172
return originalPath
153173
}
154174

175+
if m.serviceUrlType == azureServiceUrlTypeFull {
176+
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
177+
return m.serviceUrlFullPath
178+
}
179+
155180
log.Debugf("azureProvider: original request path: %s", originalPath)
156181
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
157182
log.Debugf("azureProvider: path: %s", path)
158183
if strings.Contains(path, pathAzureModelPlaceholder) {
159184
log.Debugf("azureProvider: path contains placeholder: %s", path)
160-
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
161-
log.Debugf("azureProvider: model from context: %s", model)
162-
if model == "" {
185+
var model string
186+
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
163187
model = m.defaultModel
164-
log.Debugf("azureProvider: use default model: %s", model)
188+
} else {
189+
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
190+
log.Debugf("azureProvider: model from context: %s", model)
191+
if model == "" {
192+
model = m.defaultModel
193+
log.Debugf("azureProvider: use default model: %s", model)
194+
}
165195
}
166196
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
167197
log.Debugf("azureProvider: model replaced path: %s", path)
168198
}
169-
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion)
199+
path = path + "?" + m.serviceUrl.RawQuery
170200
log.Debugf("azureProvider: final path: %s", path)
171201

172202
return path

0 commit comments

Comments
 (0)