diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py
index 2d79d6414f..07cf0149a2 100644
--- a/comps/cores/mega/micro_service.py
+++ b/comps/cores/mega/micro_service.py
@@ -30,6 +30,7 @@ def __init__(
         protocol: str = "http",
         host: str = "localhost",
         port: int = 8080,
+        api_key: Optional[str] = None,
         ssl_keyfile: Optional[str] = None,
         ssl_certfile: Optional[str] = None,
         endpoint: Optional[str] = "/",
@@ -49,6 +50,7 @@ def __init__(
         self.protocol = protocol
         self.host = host
         self.port = port
+        self.api_key = api_key
         self.endpoint = endpoint
         self.input_datatype = input_datatype
         self.output_datatype = output_datatype
@@ -137,7 +139,14 @@ def _validate_env(self):
 
     @property
     def endpoint_path(self):
-        return f"{self.protocol}://{self.host}:{self.port}{self.endpoint}"
+        # When an API key is configured the service sits behind a gateway and
+        # is addressed by host + route only (scheme/port handled upstream).
+        if self.api_key:
+            return f"{self.host}{self.endpoint}"
+        else:
+            return f"{self.protocol}://{self.host}:{self.port}{self.endpoint}"
+
+    @property
+    def api_key_value(self):
+        return self.api_key
 
 
 def register_microservice(
diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py
index e7d181dfab..3f1df22bb3 100644
--- a/comps/cores/mega/orchestrator.py
+++ b/comps/cores/mega/orchestrator.py
@@ -243,6 +243,7 @@ async def execute(
         ):
             # send the cur_node request/reply
             endpoint = self.services[cur_node].endpoint_path
+            access_token = self.services[cur_node].api_key_value
             llm_parameters_dict = llm_parameters.dict()
             is_llm_vlm = self.services[cur_node].service_type in (ServiceType.LLM, ServiceType.LVM)
@@ -263,14 +264,19 @@ async def execute(
                 if ENABLE_OPEA_TELEMETRY
                 else contextlib.nullcontext()
             ):
-                response = requests.post(
-                    url=endpoint,
-                    data=json.dumps(inputs),
-                    headers={"Content-type": "application/json"},
-                    proxies={"http": None},
-                    stream=True,
-                    timeout=1000,
-                )
+                # Attach a Bearer token only when the target service requires one.
+                headers = {"Content-type": "application/json"}
+                if access_token:
+                    headers["Authorization"] = f"Bearer {access_token}"
+                response = requests.post(
+                    url=endpoint,
+                    data=json.dumps(inputs),
+                    headers=headers,
+                    proxies={"http": None},
+                    stream=True,
+                    timeout=1000,
+                )
 
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
                 assert len(downstream) == 1, "Not supported multiple stream downstreams yet!"
@@ -291,11 +297,15 @@ def generate():
                         buffered_chunk_str += self.extract_chunk_str(chunk)
                         is_last = chunk.endswith("[DONE]\n\n")
                         if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
+                            # Forward the buffered chunk downstream, with auth when configured.
+                            headers = {"Content-type": "application/json"}
+                            if access_token:
+                                headers["Authorization"] = f"Bearer {access_token}"
                             res = requests.post(
                                 url=downstream_endpoint,
                                 data=json.dumps({"text": buffered_chunk_str}),
+                                headers=headers,
                                 proxies={"http": None},
                             )
                             res_json = res.json()
                             if "text" in res_json:
                                 res_txt = res_json["text"]