
support OnlineEmbeddingModuleBase batch input #725

Merged

wzh1994 merged 16 commits into LazyAGI:main from
wangtianxiong-sensetime:fix_onlineEmbeddingModule_parse_response
Sep 5, 2025
Conversation

@wangtianxiong-sensetime
Contributor

Fix OnlineEmbeddingModuleBase, which currently hardcodes index 0 when parsing responses and therefore does not support batch input.

wangtianxiong added 2 commits August 26, 2025 10:14
Output according to the input type

fix flake8

fix

onlineEmbed: support batch input

fix _parse_response

fix
@wzh1994 requested a review from ChenJiahaoST August 26, 2025 03:55
@wzh1994
Contributor

wzh1994 commented Aug 26, 2025

Please improve the test cases.

Collaborator

@ChenJiahaoST left a comment

commented
Also:

  1. Should the locally deployed embed module be updated as well? If so, please make the change across the board.
  2. Fill in the PR description template correctly.
  3. For list inputs, consider adding an extra batch parameter (default 64), so that oversized request bodies do not cause OOM in some scenarios.
  4. Add batch scenarios to the test cases.
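Point 3 above (bounding request size with a default batch of 64) can be sketched roughly as follows; `embed_in_batches` and `embed_one_batch` are illustrative names, not the PR's actual API:

```python
from typing import Callable, List

def embed_in_batches(inputs: List[str],
                     embed_one_batch: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 64) -> List[List[float]]:
    # Send at most `batch_size` texts per request so each payload stays bounded.
    results: List[List[float]] = []
    for i in range(0, len(inputs), batch_size):
        results.extend(embed_one_batch(inputs[i:i + batch_size]))
    return results
```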

if isinstance(input, str):
    return response['data'][0]['embedding']
else:
    return [res['embedding'] for res in response['data']]
Collaborator

@ChenJiahaoST Aug 26, 2025

Suggest using .get, checking the list length up front, etc., to make _parse_response more robust. Going further, also verify that the number of returned vectors matches the input.
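A hedged sketch of what such a defensive parse might look like, assuming a response schema of `{'data': [{'embedding': [...]}, ...]}`; the names here are illustrative, not the merged code:

```python
from typing import Any, Dict, List, Union

def parse_response(response: Dict[str, Any],
                   input: Union[List, str]) -> Union[List[float], List[List[float]]]:
    # Use .get so a malformed response fails with a clear error, not a KeyError.
    data = response.get('data', [])
    if not data:
        raise ValueError('embedding response contains no data')
    # For list input, verify the vector count matches the input count.
    if isinstance(input, list) and len(data) != len(input):
        raise ValueError(f'expected {len(input)} embeddings, got {len(data)}')
    if isinstance(input, str):
        return data[0].get('embedding', [])
    return [item.get('embedding', []) for item in data]
```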


def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['data'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Collaborator

-> List[float]
Fix the return type annotation to match the actual return type.


def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['output']['embeddings'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Collaborator

same; while at it, improve the value extraction logic, preferring .get for robustness


def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['embeddings'][0]['embedding']
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Collaborator

same

return json_data

def _parse_response(self, response: Dict[str, Any]) -> List[float]:
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Collaborator

Does this also need to support multimodal list input?

Contributor Author

Doubao's input is limited to at most one text segment plus one image.

return json_data

def _parse_response(self, response: Dict[str, Any]) -> List[float]:
def _parse_response(self, response: Dict[str, Any], input: Union[List, str]) -> List[float]:
Collaborator

same; does rerank need the same change?

@wangtianxiong-sensetime changed the title from "支持OnlineEmbeddingModuleBase batch输入" to "support OnlineEmbeddingModuleBase batch input" Aug 27, 2025
Collaborator

@ChenJiahaoST left a comment

Implement batching and parallelism in the base class.

    return []
if isinstance(input, str):
    return response['data'][0]['embedding']
    return data[0].get("embedding", [])
Collaborator

Prefer single quotes.

def _parse_response(self, response: Dict, input: Union[List, str]) -> Union[List[List[float]], List[float]]:
    data = response.get("data", [])
    if not data:
        return []
Collaborator

Is empty data an error condition? Should it raise?

Collaborator

Add batchsize and num_workers in the base class.

else:
    raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))

def run_embed_batch(self, input: Union[List, str], data: List, proxies):
Collaborator

As discussed, add a mechanism to run_embed_batch that dynamically reduces the batch size when the limit is exceeded; see FlagEmbedding's implementation for reference.

Collaborator

@ChenJiahaoST left a comment

  1. Make the test cases pass.
  2. Add test cases covering batch-size reduction and parallel/non-parallel execution.

}

def forward(self, input: Union[List, str], **kwargs) -> List[float]:
def forward(self, input: Union[List, str], **kwargs):
Collaborator

Please still add a return type annotation to forward.

with requests.post(self._embed_url, json=data, headers=self._headers, proxies=proxies) as r:
    if r.status_code == 200:
        return self._parse_response(r.json())
if isinstance(data, List):
Collaborator

Avoid typing constructs for runtime type checks; use if isinstance(data, list).
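The distinction matters at runtime: builtin types are the reliable check, and subscripted typing generics cannot be used with isinstance at all:

```python
from typing import List

data = [1, 2, 3]
assert isinstance(data, list)  # preferred: builtin type for runtime checks

# Subscripted generics like List[int] raise TypeError in isinstance checks.
try:
    isinstance(data, List[int])
    raised = False
except TypeError:
    raised = True
assert raised
```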


def _parse_response(self, response: Dict[str, Any]) -> List[float]:
    return response['data'][0]['embedding']
def run_embed_batch_parallel(self, input: Union[List, str], data: List, proxies, **kwargs):
Collaborator

I thought about this again: consider extracting the common logic between the parallel and non-parallel paths, for example:

if workers == 1:
    with requests.Session() as sess:
        for start, chunk in chunks:
            embeds = self._send_chunk_with_autoshrink(sess, chunk, proxies, **kwargs)
            results[start:start + len(embeds)] = embeds
else:
    with ThreadPoolExecutor(max_workers=workers) as ex:
        # map each future back to its start index; as_completed yields futures
        # in completion order, which need not match submission order
        futs = {ex.submit(self._send_chunk_with_autoshrink, None, chunk, proxies, **kwargs): start
                for start, chunk in chunks}
        for fut in as_completed(futs):
            embeds = fut.result()
            start = futs[fut]
            results[start:start + len(embeds)] = embeds

_send_chunk_with_autoshrink contains the logic for shrinking within the current batch:

def _send_chunk_with_autoshrink(self, sess: requests.Session, chunk: List[str], proxies, **kwargs) -> List[List[float]]:
    # Send a single micro-batch; on "batch too large / payload too large"
    # errors, split in half and retry both halves.
    cur = list(chunk)
    local_sess = sess or requests.Session()
    try:
        resp = self._post_once(local_sess, self._build_payload(cur, **kwargs), proxies)
        return self._parse_response(resp, cur)  # pass the current batch here, not the whole input
    except requests.HTTPError as e:
        code = getattr(e.response, 'status_code', None)
        text = (getattr(e.response, 'text', '') or '').lower()
        too_big = code in (413, 422) or ('too many' in text) or ('max' in text and 'input' in text)
        if too_big and len(cur) > 1:
            mid = max(1, len(cur) // 2)
            # recurse on both halves so the tail of the batch is not dropped
            return (self._send_chunk_with_autoshrink(local_sess, cur[:mid], proxies, **kwargs)
                    + self._send_chunk_with_autoshrink(local_sess, cur[mid:], proxies, **kwargs))
        # 401 / other errors: re-raise
        raise

so that self._batch_size is not modified directly.
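For completeness, the chunks iterated over in the sketch above could be built by pairing each start offset with its slice; this helper is illustrative, not from the PR:

```python
from typing import List, Tuple

def make_chunks(inputs: List[str], batch_size: int) -> List[Tuple[int, List[str]]]:
    # Pair each slice with its start index so results can be written back
    # into a preallocated `results` list in input order.
    return [(i, inputs[i:i + batch_size]) for i in range(0, len(inputs), batch_size)]
```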

if isinstance(input, str):
    return data[0].get('embedding', [])
else:
    return [res.get('embedding', []) for res in data]
(no newline at end of file)
Collaborator

Add a trailing blank line at the end of the file.

@wangtianxiong-sensetime force-pushed the fix_onlineEmbeddingModule_parse_response branch from 844fd7b to 1dcc928 on September 2, 2025 07:25
Collaborator

@ChenJiahaoST left a comment

  1. Fix the loop-exit bug and improve the retry mechanism.
  2. Preferably add a timeout to the requests as well.

self._batch_size = max(self._batch_size // 2, 1)
data = self._encapsulated_data(input, **kwargs)
break
break
Collaborator

After shrinking the batch size, a break immediately followed by another break looks like a logic error: it jumps straight out of the outer loop as well.
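One way the control flow could be restructured, mirroring the PR's halve-and-retry idea: on a too-large error, shrink the batch size and restart the chunk loop instead of breaking out of both loops. This is a hypothetical sketch; `send` is assumed to raise RuntimeError when a batch is too large.

```python
from typing import Callable, List

def embed_all(send: Callable[[List], List[List[float]]],
              inputs: List, batch_size: int = 64) -> List[List[float]]:
    while True:
        try:
            results: List[List[float]] = []
            for i in range(0, len(inputs), batch_size):
                results.extend(send(inputs[i:i + batch_size]))
            return results
        except RuntimeError:
            if batch_size == 1:
                raise  # cannot shrink further; surface the error
            batch_size = max(1, batch_size // 2)  # shrink and retry the whole loop
```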

vec = self._parse_response(r.json(), input)
start = i * self._batch_size
end = start + len(vec)
ret[start:end] = vec
Collaborator

ret[start: start + len(vec)] = vec

@mergify bot added the lint_pass label Sep 5, 2025
@wzh1994 merged commit a3db624 into LazyAGI:main Sep 5, 2025
19 of 23 checks passed