@@ -14,13 +14,21 @@ import (
1414	"github.com/higress-group/wasm-go/pkg/wrapper" 
1515)
1616
17+ type  azureServiceUrlType  int 
18+ 
1719const  (
1820	pathAzurePrefix            =  "/openai" 
1921	pathAzureModelPlaceholder  =  "{model}" 
2022	pathAzureWithModelPrefix   =  "/openai/deployments/"  +  pathAzureModelPlaceholder 
2123	queryAzureApiVersion       =  "api-version" 
2224)
2325
26+ const  (
27+ 	azureServiceUrlTypeFull  azureServiceUrlType  =  iota 
28+ 	azureServiceUrlTypeWithDeployment 
29+ 	azureServiceUrlTypeDomainOnly 
30+ )
31+ 
2432var  (
2533	azureModelIrrelevantApis  =  map [ApiName ]bool {
2634		ApiNameModels :              true ,
3139		ApiNameRetrieveFile :        true ,
3240		ApiNameRetrieveFileContent : true ,
3341	}
34- 	regexAzureModelWithPath  =  regexp .MustCompile ("/openai/deployments/(.+?)(/.* |$)" )
42+ 	regexAzureModelWithPath  =  regexp .MustCompile ("/openai/deployments/(.+?)(?:/(.*) |$)" )
3543)
3644
3745// azureProvider is the provider for Azure OpenAI service. 
@@ -82,32 +90,44 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
8290
8391	modelSubMatch  :=  regexAzureModelWithPath .FindStringSubmatch (serviceUrl .Path )
8492	defaultModel  :=  "placeholder" 
93+ 	var  serviceUrlType  azureServiceUrlType 
8594	if  modelSubMatch  !=  nil  {
8695		defaultModel  =  modelSubMatch [1 ]
96+ 		if  modelSubMatch [2 ] !=  ""  {
97+ 			serviceUrlType  =  azureServiceUrlTypeFull 
98+ 		} else  {
99+ 			serviceUrlType  =  azureServiceUrlTypeWithDeployment 
100+ 		}
87101		log .Debugf ("azureProvider: found default model from serviceUrl: %s" , defaultModel )
88102	} else  {
103+ 		serviceUrlType  =  azureServiceUrlTypeDomainOnly 
89104		log .Debugf ("azureProvider: no default model found in serviceUrl" )
90105	}
106+ 	log .Debugf ("azureProvider: serviceUrlType=%d" , serviceUrlType )
91107
92108	config .setDefaultCapabilities (m .DefaultCapabilities ())
93109	apiVersion  :=  serviceUrl .Query ().Get (queryAzureApiVersion )
94110	log .Debugf ("azureProvider: using %s: %s" , queryAzureApiVersion , apiVersion )
95111	return  & azureProvider {
96- 		config :       config ,
97- 		serviceUrl :   serviceUrl ,
98- 		apiVersion :   apiVersion ,
99- 		defaultModel : defaultModel ,
100- 		contextCache : createContextCache (& config ),
112+ 		config :             config ,
113+ 		serviceUrl :         serviceUrl ,
114+ 		serviceUrlType :     serviceUrlType ,
115+ 		serviceUrlFullPath : serviceUrl .Path  +  "?"  +  serviceUrl .RawQuery ,
116+ 		apiVersion :         apiVersion ,
117+ 		defaultModel :       defaultModel ,
118+ 		contextCache :       createContextCache (& config ),
101119	}, nil 
102120}
103121
104122type  azureProvider  struct  {
105123	config  ProviderConfig 
106124
107- 	contextCache  * contextCache 
108- 	serviceUrl    * url.URL 
109- 	apiVersion    string 
110- 	defaultModel  string 
125+ 	contextCache        * contextCache 
126+ 	serviceUrl          * url.URL 
127+ 	serviceUrlFullPath  string 
128+ 	serviceUrlType      azureServiceUrlType 
129+ 	apiVersion          string 
130+ 	defaultModel        string 
111131}
112132
113133func  (m  * azureProvider ) GetProviderType () string  {
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
152172		return  originalPath 
153173	}
154174
175+ 	if  m .serviceUrlType  ==  azureServiceUrlTypeFull  {
176+ 		log .Debugf ("azureProvider: use configured path %s" , m .serviceUrlFullPath )
177+ 		return  m .serviceUrlFullPath 
178+ 	}
179+ 
155180	log .Debugf ("azureProvider: original request path: %s" , originalPath )
156181	path  :=  util .MapRequestPathByCapability (string (apiName ), originalPath , m .config .capabilities )
157182	log .Debugf ("azureProvider: path: %s" , path )
158183	if  strings .Contains (path , pathAzureModelPlaceholder ) {
159184		log .Debugf ("azureProvider: path contains placeholder: %s" , path )
160- 		model  :=  ctx .GetStringContext (ctxKeyFinalRequestModel , "" )
161- 		log .Debugf ("azureProvider: model from context: %s" , model )
162- 		if  model  ==  ""  {
185+ 		var  model  string 
186+ 		if  m .serviceUrlType  ==  azureServiceUrlTypeWithDeployment  {
163187			model  =  m .defaultModel 
164- 			log .Debugf ("azureProvider: use default model: %s" , model )
188+ 		} else  {
189+ 			model  =  ctx .GetStringContext (ctxKeyFinalRequestModel , "" )
190+ 			log .Debugf ("azureProvider: model from context: %s" , model )
191+ 			if  model  ==  ""  {
192+ 				model  =  m .defaultModel 
193+ 				log .Debugf ("azureProvider: use default model: %s" , model )
194+ 			}
165195		}
166196		path  =  strings .ReplaceAll (path , pathAzureModelPlaceholder , model )
167197		log .Debugf ("azureProvider: model replaced path: %s" , path )
168198	}
169- 	path  =  fmt . Sprintf ( "%s?%s=%s" ,  path ,  queryAzureApiVersion ,  m . apiVersion ) 
199+ 	path  =  path   +   "?"   +   m . serviceUrl . RawQuery 
170200	log .Debugf ("azureProvider: final path: %s" , path )
171201
172202	return  path 
0 commit comments