Skip to content

Commit 90d6b91

Browse files
authored
OnlineChatModule type auto-detection logic enhancement [skip ci] (#844)
1 parent 85c921c commit 90d6b91

2 files changed

Lines changed: 64 additions & 30 deletions

File tree

lazyllm/module/llms/onlinemodule/map_model_type.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from lazyllm import LOG
2-
2+
import re
3+
from typing import Optional, Dict
34
MODEL_MAPPING = {
45
# ===== OpenAI (LLM) =====
56
'gpt-5': 'vlm',
@@ -435,45 +436,75 @@
435436
'deepseek-chat': 'llm',
436437
'deepseek-reasoner': 'llm'
437438
}
439+
_TOKEN_MAP = {
440+
'embed': ('embedding', 'embed'),
441+
'stt': ('whisper', 'paraformer', 'asr', 'stt', 'transcribe'),
442+
'tts': ('tts', 'cosyvoice', 'nova-tts'),
443+
'vlm': ('qwen-vl', 'vl', 'vision', 'caption', 'omni', 'vlm', 'seed'),
444+
'ocr': ('ocr',),
445+
'rerank': ('rerank',),
446+
'cross_modal_embed': ('cross_modal', 'multimodal-embedding', 'embedding-vision'),
447+
'sd': ('dall', 'wan', 'sora', 'image', 'video', 't2i', 't2v'),
448+
}
449+
_SUFFIX_RE = re.compile(
450+
r'(?:'
451+
r'|[-_.]\d{4}(?:-\d{2}-\d{2})?' # date-like suffix
452+
r'|[-_.]v?\d+[a-z0-9\-]*' # version-like suffix
453+
r')+$',
454+
flags=re.I
455+
)
438456

457+
def _normalize_key(name: str) -> str:
    '''Canonicalize a model name for lookup: lowercase, collapse every run
    of non-alphanumeric characters to a single '-', trim edge dashes, and
    drop trailing version/date suffixes via _SUFFIX_RE.

    Returns '' for a falsy input.
    '''
    if not name:
        return ''
    key = re.sub(r'[^a-z0-9]+', '-', name.strip().lower()).strip('-')
    return _SUFFIX_RE.sub('', key)
439466

440-
def special_model_rule(model_name: str):
441-
'''Keyword-based matching'''
442-
return MODEL_MAPPING.get(model_name)
467+
def _contains_token(name: str, token: str) -> bool:
468+
'''Check if token exists as a separate unit in the model name.'''
469+
if not token:
470+
return False
471+
patterns = [
472+
rf'(^|[-_.]){re.escape(token)}($|[-_.])',
473+
rf'{re.escape(token)}$',
474+
rf'^{re.escape(token)}'
475+
]
476+
return any(re.search(p, name) for p in patterns)
443477

478+
# MODEL_MAPPING re-keyed by normalized name so lookups tolerate case,
# separator, and version/date-suffix differences.
# NOTE(review): normalization can collide distinct keys (e.g. 'gpt-4' and
# 'gpt-5' both normalize to 'gpt' after suffix stripping); the later
# MODEL_MAPPING entry silently wins — confirm this is intended.
NORMALIZED_MODEL_MAPPING: Dict[str, str] = {
    _normalize_key(k): v for k, v in MODEL_MAPPING.items()
}
444481

445-
def feature_keyword_rule(model_name: str):
446-
'''Exact match'''
447-
lower_name = model_name.lower()
448-
if 'embedding' in lower_name:
449-
return 'embed'
450-
if 'whisper' in lower_name or 'paraformer' in lower_name or 'asr' in lower_name or 'stt' in lower_name:
451-
return 'stt'
452-
if 'tts' in lower_name or 'cosyvoice' in lower_name:
453-
return 'tts'
454-
if 'vl' in lower_name or 'vision' in lower_name or 'caption' in lower_name:
455-
return 'vlm'
456-
if 'ocr' in lower_name:
457-
return 'ocr'
458-
if 'rerank' in lower_name:
459-
return 'rerank'
460-
if 'cross_modal' in lower_name or 'multimodal-embedding' in lower_name:
461-
return 'cross_modal_embed'
462-
if any(kw in lower_name for kw in ('dall', 'cogview', 'wan', 'sd', 'image', 'video')):
463-
return 'sd'
464-
return None
482+
def feature_keyword_rule(model_name: str) -> Optional[str]:
    '''Classify a model by its normalized name.

    First tries an exact lookup in NORMALIZED_MODEL_MAPPING, then scans
    _TOKEN_MAP in declaration order, returning the first model type whose
    token appears in the name. Returns None when nothing matches.
    '''
    key = _normalize_key(model_name)
    exact = NORMALIZED_MODEL_MAPPING.get(key)
    if exact is not None:
        return exact
    for mtype, tokens in _TOKEN_MAP.items():
        if any(_contains_token(key, tok) for tok in tokens):
            return mtype
    return None
493+
def special_model_rule(model_name: str) -> Optional[str]:
    '''Exact-match lookup: return the MODEL_MAPPING type for *model_name*
    (already lower-cased by get_model_type), or None when unknown.
    '''
    # FIX: previous docstring ('Determine the model category') described
    # get_model_type, not this exact-lookup helper.
    return MODEL_MAPPING.get(model_name)
496+
497+
def get_model_type(model_name: str) -> str:
    '''Determine the model category for *model_name*.

    Applies special_model_rule (exact match) then feature_keyword_rule
    (normalized/keyword match); a rule that raises is logged and skipped.
    Falls back to 'llm' when the name is empty or no rule matches.
    '''
    # FIX: guard before .lower() so a None/empty name falls back to 'llm'
    # instead of raising AttributeError.
    if not model_name:
        return 'llm'
    model_name = model_name.lower()
    for rule in (special_model_rule, feature_keyword_rule):
        try:
            result = rule(model_name)
            if result:
                LOG.info(f'Model: {model_name} classified as type: {result} by rule: {rule.__name__}')
                return result
        except Exception as e:
            LOG.warning(f'Rule {rule.__name__} failed: {e}')
    LOG.warning(f'Cannot classify model type for: {model_name}. Defaulting to "llm" instead.')
    return 'llm'

tests/basic_tests/Modules/test_map_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
test_models = {
55
'llm': [
6+
'GPT5',
7+
'Qwen3-32B',
8+
'qwen3-coder-plus',
69
'sensechat-128k',
710
'glm-4-5-airx',
811
'qwen3-coder-plus-2025-09-23',

0 commit comments

Comments
 (0)