support OnlineEmbeddingModuleBase batch input #725
Conversation
Output according to input type; fix flake8; fix onlineEmbed batch-input support; fix _parse_response; fix
Please improve the test cases.
ChenJiahaoST
left a comment
Also:
- Does the locally deployed embed need to be changed together with this? If so, please update it as a whole.
- Fill in the PR description template correctly.
- For scenarios that need list input, consider adding an extra parameter batch (default 64) so an oversized request body does not cause OOM in some scenarios (see the sketch after this list).
- Add test cases covering the batch scenario.
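A minimal sketch of that chunking idea, under the assumption of a hypothetical embed_request helper standing in for the real HTTP call (the batch parameter name and its default of 64 come from the comment above, not from the PR's actual API):

from typing import List, Union

# Hypothetical sketch: chunk list input into micro-batches before
# sending, so one oversized request body cannot trigger OOM.
def embed(input: Union[List[str], str], embed_request, batch: int = 64):
    if isinstance(input, str):
        return embed_request([input])[0]
    results: List[List[float]] = []
    for start in range(0, len(input), batch):
        results.extend(embed_request(input[start:start + batch]))
    return results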
if isinstance(input, str):
    return response['data'][0]['embedding']
else:
    return [res['embedding'] for res in response['data']]
Suggest using .get, checking the list length up front, etc., to make _parse_response more robust. Beyond that, for example, verify that the number of returned vectors matches the input.
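A hedged sketch of what that defensive version could look like, written as a free function for self-containment (in the PR it would be a method; the exact error messages are assumptions):

from typing import Any, Dict, List, Union

def _parse_response(response: Dict[str, Any],
                    input: Union[List, str]) -> Union[List[List[float]], List[float]]:
    # .get instead of direct indexing, an emptiness check, and a count
    # check against the input, per the suggestions above.
    data = response.get('data', [])
    if not data:
        raise ValueError('embedding response contains no data')
    if isinstance(input, str):
        return data[0].get('embedding', [])
    if len(data) != len(input):
        raise ValueError(f'expected {len(input)} embeddings, got {len(data)}')
    return [item.get('embedding', []) for item in data]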
def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['data'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
-> List[float]
Fix this annotation to match the actual return type.
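A sketch of a signature whose annotation covers both branches (class name reused from the PR title; the body is elided):

from typing import Any, Dict, List, Union

class OnlineEmbeddingModuleBase:
    # str input returns one vector; list input returns a list of vectors.
    def _parse_response(self, response: Dict[str, Any],
                        input: Union[List, str]) -> Union[List[List[float]], List[float]]:
        ...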
def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['output']['embeddings'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Same here; while at it, improve the value-extraction logic and prefer .get to increase robustness.
def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['embeddings'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
    return json_data
def _parse_response(self, response: Dict[str, Any]) -> List[float]:
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Doubao's input limit is at most 1 text segment plus 1 image.
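A possible guard for that constraint; where the check lives and its message are assumptions:

from typing import List, Union

def check_doubao_input(input: Union[List, str]) -> None:
    # Doubao accepts at most 1 text segment plus 1 image per request,
    # so batch (list) input should be rejected before anything is sent.
    if isinstance(input, list) and len(input) > 1:
        raise ValueError('Doubao embedding accepts a single text segment '
                         f'per request, got a batch of {len(input)}')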
return json_data

def _parse_response(self, response: Dict[str, Any]) -> List[float]:
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
    return []
    if isinstance(input, str):
        return response['data'][0]['embedding']
    return data[0].get("embedding", [])
def _parse_response(self, response: Dict, input: Union[List, str]) -> Union[List[List[float]], List[float]]:
    data = response.get("data", [])
    if not data:
        return []
If data has no elements, isn't that an error? Should it raise?
Add batchsize and num_workers in the base class.
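One way the base class could own those knobs; the constructor shape and the defaults here are assumptions, not the actual class:

from typing import Optional

class OnlineEmbeddingModuleBase:
    def __init__(self, embed_url: str, api_key: Optional[str] = None,
                 batch_size: int = 64, num_workers: int = 1):
        # batch_size / num_workers live on the base class so every
        # provider subclass inherits the same batching and parallelism.
        self._embed_url = embed_url
        self._api_key = api_key
        self._batch_size = batch_size
        self._num_workers = num_workers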
else:
    raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))

def run_embed_batch(self, input: Union[List, str], data: List, proxies):
As discussed, add a dynamic batch-size reduction mechanism to run_embed_batch for when the limit is exceeded; flagembed's implementation can serve as a reference.
ChenJiahaoST
left a comment
- Make the test cases pass.
- Add test cases covering batch-size reduction and parallel/non-parallel execution.
}

def forward(self, input: Union[List, str], **kwargs) -> List[float]:
def forward(self, input: Union[List, str], **kwargs):

with requests.post(self._embed_url, json=data, headers=self._headers, proxies=proxies) as r:
    if r.status_code == 200:
        return self._parse_response(r.json())
if isinstance(data, List):
Avoid typing aliases in type checks; use if isinstance(data, list).
def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['data'][0]['embedding']
def run_embed_batch_parallel(self, input: Union[List, str], data: List, proxies, **kwargs):
I've thought it over: see whether the parallel and non-parallel paths can share extracted common logic, for example:

if workers == 1:
    with requests.Session() as sess:
        for start, chunk in chunks:
            embeds = self._send_chunk_with_autoshrink(sess, chunk, proxies, **kwargs)
            results[start:start + len(embeds)] = embeds
else:
    with ThreadPoolExecutor(max_workers=workers) as ex:
        # Map each future to its start index; as_completed yields in
        # completion order, so zipping it against chunks would misalign
        # the write-back offsets.
        futs = {ex.submit(self._send_chunk_with_autoshrink, None, chunk, proxies, **kwargs): start
                for start, chunk in chunks}
        # Collect completed results and write them back by start index.
        for fut in as_completed(futs):
            embeds = fut.result()
            results[futs[fut]:futs[fut] + len(embeds)] = embeds

_send_chunk_with_autoshrink contains the logic for shrinking within the current batch:

def _send_chunk_with_autoshrink(self, sess: requests.Session, chunk: List[str], proxies, **kwargs) -> List[List[float]]:
    # Send a single micro-batch; on a "batch too large / body too big"
    # error, split it in half and retry both halves, so the tail of the
    # batch is never silently dropped.
    pending: List[List[str]] = [list(chunk)]
    out: List[List[float]] = []
    local_sess = sess or requests.Session()
    while pending:
        cur = pending.pop(0)
        try:
            resp = self._post_once(local_sess, self._build_payload(cur, **kwargs), proxies)
            out.extend(self._parse_response(resp, cur))  # pass the current micro-batch, not the whole input
        except requests.HTTPError as e:
            code = getattr(e.response, "status_code", None)
            text = (getattr(e.response, "text", "") or "").lower()
            too_big = code in (413, 422) or ("too many" in text) or ("max" in text and "input" in text)
            if too_big and len(cur) > 1:
                half = max(1, len(cur) // 2)
                pending[:0] = [cur[:half], cur[half:]]
                continue
            # 401 / any other error: re-raise
            raise
    return out

This way self._batch_size itself is never modified.
if isinstance(input, str):
    return data[0].get('embedding', [])
else:
    return [res.get('embedding', []) for res in data]
Force-pushed from 844fd7b to 1dcc928
ChenJiahaoST
left a comment
- Fix the loop-exit bug and complete the retry mechanism.
- The requests should also include a timeout.
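A sketch of a request helper with an explicit timeout and a bounded retry; the 30-second timeout and 3 attempts are placeholder values:

import requests

def post_with_timeout(url: str, payload: dict, headers: dict, proxies=None,
                      timeout: float = 30.0, retries: int = 3) -> dict:
    # Every attempt carries an explicit timeout; after the last failed
    # attempt the exception is re-raised instead of looping forever.
    last_exc: Exception = RuntimeError('no attempt made')
    for _ in range(retries):
        try:
            r = requests.post(url, json=payload, headers=headers,
                              proxies=proxies, timeout=timeout)
            r.raise_for_status()
            return r.json()
        except requests.RequestException as exc:
            last_exc = exc
    raise last_exc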
self._batch_size = max(self._batch_size // 2, 1)
data = self._encapsulated_data(input, **kwargs)
break
break
After shrinking the batch, a break immediately followed by another break looks like a logic error: it exits both loops right away instead of retrying.
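A sketch of a loop shape that retries with the smaller batch instead of falling through two breaks; send_fn and the exception type are placeholders:

def send_with_shrink(send_fn, input, batch_size: int):
    # continue restarts the attempt with the halved batch size; only
    # success, or a batch size already at 1, leaves the loop.
    while True:
        try:
            return send_fn(input, batch_size)
        except RuntimeError:
            if batch_size <= 1:
                raise
            batch_size = max(batch_size // 2, 1)
            continue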
vec = self._parse_response(r.json(), input)
start = i * self._batch_size
end = start + len(vec)
ret[start:end] = vec
ret[start: start + len(vec)] = vec
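Note that start = i * self._batch_size also goes stale once the batch size shrinks mid-run; a running cursor sidesteps both problems (a sketch; embed_chunk is a placeholder):

from typing import List

def collect(chunks: List[List[str]], embed_chunk, total: int) -> List[List[float]]:
    # Advance a cursor by however many vectors each chunk returned,
    # instead of recomputing offsets from a batch size that may change.
    ret: List[List[float]] = [[] for _ in range(total)]
    cursor = 0
    for chunk in chunks:
        vec = embed_chunk(chunk)
        ret[cursor:cursor + len(vec)] = vec
        cursor += len(vec)
    return ret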
Fix OnlineEmbeddingModuleBase always taking index 0 when parsing results, which previously blocked batch-input support.