42 changes: 20 additions & 22 deletions comps/cores/mega/orchestrator.py
@@ -207,7 +207,7 @@ def process_outputs(self, prev_nodes: List, result_dict: Dict) -> Dict:
             all_outputs.update(result_dict[prev_node])
         return all_outputs
 
-    async def wrap_iterable(self, aiterable, is_first=True):
+    def wrap_iterable(self, iterable, is_first=True):
 
         with tracer.start_as_current_span("llm_generate_stream") if ENABLE_OPEA_TELEMETRY else contextlib.nullcontext():
             while True:
@@ -217,10 +217,10 @@ async def wrap_iterable(self, aiterable, is_first=True):
                     else contextlib.nullcontext()
                 ):  # else tracer.start_as_current_span(f"llm_generate_stream_next_token")
                     try:
-                        token = await anext(aiterable)
+                        token = next(iterable)
                         yield token
                         is_first = False
-                    except StopAsyncIteration:
+                    except StopIteration:
                         # Exiting the iterable loop cleanly
                         break
                     except Exception as e:
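
The conversion follows Python's synchronous iterator protocol: `next()` replaces `await anext()`, and `StopIteration` replaces `StopAsyncIteration` as the end-of-stream signal. Below is a minimal, self-contained sketch of the resulting pattern with the telemetry spans omitted; the free function is illustrative, not the actual class method.

```python
# Sketch only: synchronous token-wrapping generator, telemetry omitted.
def wrap_iterable(iterable, is_first=True):
    while True:
        try:
            token = next(iterable)  # blocking next() in place of `await anext()`
            yield token
            is_first = False
        except StopIteration:  # sync counterpart of StopAsyncIteration
            break  # iterable exhausted; exit cleanly

# Usage: any plain iterator can be wrapped the same way.
for tok in wrap_iterable(iter(["Hello", " ", "world", "\n"])):
    print(tok, end="")
```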
@@ -274,36 +274,34 @@ async def execute(
             hitted_ends = [".", "?", "!", "。", ",", "!"]
             downstream_endpoint = self.services[downstream[0]].endpoint_path
 
-            async def generate():
+            def generate():
                 token_start = req_start
                 if response:
                     # response.elapsed = time until first headers received
                     buffered_chunk_str = ""
                     is_first = True
-                    async for chunk in self.wrap_iterable(response.content.iter_chunked(None)):
+                    for chunk in self.wrap_iterable(response.iter_content(chunk_size=None)):
                         if chunk:
                             if downstream:
                                 chunk = chunk.decode("utf-8")
                                 buffered_chunk_str += self.extract_chunk_str(chunk)
                                 is_last = chunk.endswith("[DONE]\n\n")
                                 if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
-                                    async with aiohttp.ClientSession() as downstream_session:
-                                        res = await downstream_session.post(
-                                            url=downstream_endpoint,
-                                            data=json.dumps({"text": buffered_chunk_str}),
-                                            proxy=None,
-                                        )
-                                        res_json = await res.json()
-                                        if "text" in res_json:
-                                            res_txt = res_json["text"]
-                                        else:
-                                            raise Exception("Other response types not supported yet!")
-                                        buffered_chunk_str = ""  # clear
-                                        async for item in self.token_generator(
-                                            res_txt, token_start, is_first=is_first, is_last=is_last
-                                        ):
-                                            yield item
-                                        token_start = time.time()
+                                    res = requests.post(
+                                        url=downstream_endpoint,
+                                        data=json.dumps({"text": buffered_chunk_str}),
+                                        proxies={"http": None},
+                                    )
+                                    res_json = res.json()
+                                    if "text" in res_json:
+                                        res_txt = res_json["text"]
+                                    else:
+                                        raise Exception("Other response types not supported yet!")
+                                    buffered_chunk_str = ""  # clear
+                                    yield from self.token_generator(
+                                        res_txt, token_start, is_first=is_first, is_last=is_last
+                                    )
+                                    token_start = time.time()
                             else:
                                 token_start = self.metrics.token_update(token_start, is_first)
                                 yield chunk
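
In the synchronous version, `response.iter_content(chunk_size=None)` from `requests` yields chunks as they arrive (the request must have been made with `stream=True`), text is buffered until it ends in a sentence-final character or the `[DONE]` marker, and each flush is a blocking `requests.post` to the downstream service whose tokens are re-emitted with `yield from`. The following is a simplified, self-contained sketch of that relay loop, assuming a hypothetical `token_generator` and dropping the `extract_chunk_str` parsing and the metrics/telemetry bookkeeping:

```python
import json

import requests

def relay_stream(response, downstream_endpoint, token_generator):
    """Sketch of generate(): buffer chunks, flush downstream, re-emit tokens."""
    hitted_ends = [".", "?", "!", "。", ",", "!"]
    buffered_chunk_str = ""
    for chunk in response.iter_content(chunk_size=None):  # chunks as received
        text = chunk.decode("utf-8")
        buffered_chunk_str += text
        is_last = text.endswith("[DONE]\n\n")
        if (buffered_chunk_str and buffered_chunk_str[-1] in hitted_ends) or is_last:
            res = requests.post(  # blocking flush to the downstream service
                url=downstream_endpoint,
                data=json.dumps({"text": buffered_chunk_str}),
                proxies={"http": None},  # bypass any configured HTTP proxy
            )
            buffered_chunk_str = ""  # clear the buffer after each flush
            yield from token_generator(res.json()["text"], is_last=is_last)
```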
7 changes: 3 additions & 4 deletions comps/cores/mega/orchestrator_with_yaml.py
@@ -1,7 +1,6 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import asyncio
 import json
 import re
 from collections import OrderedDict
@@ -24,10 +23,10 @@ def __init__(self, yaml_file_path: str):
         if not is_valid:
             raise Exception("Invalid mega graph!")
 
-    async def execute(self, cur_node: str, inputs: Dict):
+    def execute(self, cur_node: str, inputs: Dict):
         # send the cur_node request/reply
         endpoint = self.docs["opea_micro_services"][cur_node]["endpoint"]
-        response = await asyncio.to_thread(requests.post, url=endpoint, data=json.dumps(inputs), proxies={"http": None})
+        response = requests.post(url=endpoint, data=json.dumps(inputs), proxies={"http": None})
         print(response)
         return response.json()
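
With the `asyncio.to_thread` indirection removed, `execute` is a plain blocking call: `requests.post` returns only once the microservice replies, and `proxies={"http": None}` keeps the request from being routed through any configured HTTP proxy. Reduced to its essentials, the call looks roughly like the sketch below; the endpoint URL is a placeholder, not a real service.

```python
import json
from typing import Dict

import requests

def execute(endpoint: str, inputs: Dict) -> Dict:
    # Blocks until the microservice replies; no event loop or thread hop involved.
    response = requests.post(url=endpoint, data=json.dumps(inputs), proxies={"http": None})
    return response.json()

# result = execute("http://localhost:8888/v1/example", {"text": "hi"})  # placeholder URL
```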

@@ -49,7 +48,7 @@ async def schedule(self, initial_inputs: Dict):
                 inputs = initial_inputs
             else:
                 inputs = self.process_outputs(self.predecessors(node))
-            response = await self.execute(node, inputs)
+            response = self.execute(node, inputs)
             self.result_dict[node] = response
 
     def _load_from_yaml(self):
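
Since `execute` now blocks, `schedule` walks the graph strictly one node at a time: each node's HTTP round trip completes before its successors run. Here is a standalone sketch of that control flow, with the graph represented as plain dictionaries; this is illustrative, not the actual class.

```python
from typing import Callable, Dict, List

def schedule(
    ordered_nodes: List[str],  # assumed already topologically sorted
    predecessors: Dict[str, List[str]],
    execute: Callable[[str, Dict], Dict],
    initial_inputs: Dict,
) -> Dict[str, Dict]:
    result_dict: Dict[str, Dict] = {}
    for node in ordered_nodes:
        preds = predecessors.get(node, [])
        if not preds:
            inputs = initial_inputs  # source node takes the request payload
        else:
            inputs = {}  # merge upstream outputs, mirroring process_outputs()
            for p in preds:
                inputs.update(result_dict[p])
        result_dict[node] = execute(node, inputs)  # blocking; serializes the walk
    return result_dict
```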