|
1 | 1 | from lazyllm import LOG |
2 | | - |
| 2 | +import re |
| 3 | +from typing import Optional, Dict |
3 | 4 | MODEL_MAPPING = { |
4 | 5 | # ===== OpenAI (LLM) ===== |
5 | 6 | 'gpt-5': 'vlm', |
|
435 | 436 | 'deepseek-chat': 'llm', |
436 | 437 | 'deepseek-reasoner': 'llm' |
437 | 438 | } |
| 439 | +_TOKEN_MAP = { |
| 440 | + 'embed': ('embedding', 'embed'), |
| 441 | + 'stt': ('whisper', 'paraformer', 'asr', 'stt', 'transcribe'), |
| 442 | + 'tts': ('tts', 'cosyvoice', 'nova-tts'), |
| 443 | + 'vlm': ('qwen-vl', 'vl', 'vision', 'caption', 'omni', 'vlm', 'seed'), |
| 444 | + 'ocr': ('ocr',), |
| 445 | + 'rerank': ('rerank',), |
| 446 | + 'cross_modal_embed': ('cross_modal', 'multimodal-embedding', 'embedding-vision'), |
| 447 | + 'sd': ('dall', 'wan', 'sora', 'image', 'video', 't2i', 't2v'), |
| 448 | +} |
| 449 | +_SUFFIX_RE = re.compile( |
| 450 | + r'(?:' |
| 451 | + r'|[-_.]\d{4}(?:-\d{2}-\d{2})?' # date-like suffix |
| 452 | + r'|[-_.]v?\d+[a-z0-9\-]*' # version-like suffix |
| 453 | + r')+$', |
| 454 | + flags=re.I |
| 455 | +) |
438 | 456 |
|
| 457 | +def _normalize_key(name: str) -> str: |
| 458 | + '''Normalize model name for consistent comparison.''' |
| 459 | + if not name: |
| 460 | + return '' |
| 461 | + s = name.strip().lower() |
| 462 | + s = re.sub(r'[^a-z0-9]+', '-', s) |
| 463 | + s = s.strip('-') |
| 464 | + s = re.sub(_SUFFIX_RE, '', s) |
| 465 | + return s |
439 | 466 |
|
440 | | -def special_model_rule(model_name: str): |
441 | | - '''Keyword-based matching''' |
442 | | - return MODEL_MAPPING.get(model_name) |
| 467 | +def _contains_token(name: str, token: str) -> bool: |
| 468 | + '''Check if token exists as a separate unit in the model name.''' |
| 469 | + if not token: |
| 470 | + return False |
| 471 | + patterns = [ |
| 472 | + rf'(^|[-_.]){re.escape(token)}($|[-_.])', |
| 473 | + rf'{re.escape(token)}$', |
| 474 | + rf'^{re.escape(token)}' |
| 475 | + ] |
| 476 | + return any(re.search(p, name) for p in patterns) |
443 | 477 |
|
# Pre-normalized lookup table derived from MODEL_MAPPING so lookups ignore
# case, punctuation and trailing version/date suffixes.
# NOTE: if two raw names normalize to the same key, the later entry wins.
NORMALIZED_MODEL_MAPPING: Dict[str, str] = {
    _normalize_key(raw_name): model_type
    for raw_name, model_type in MODEL_MAPPING.items()
}
444 | 481 |
|
def feature_keyword_rule(model_name: str) -> Optional[str]:
    '''Identify the model type by normalized name or keyword match.

    First tries an exact lookup of the normalized name in
    NORMALIZED_MODEL_MAPPING; otherwise scans _TOKEN_MAP in declaration
    order and returns the first type with a matching token.

    Args:
        model_name: Raw model name.

    Returns:
        The matched model type, or None when nothing matches.
    '''
    key = _normalize_key(model_name)
    mapped = NORMALIZED_MODEL_MAPPING.get(key)
    if mapped is not None:
        return mapped
    for model_type, tokens in _TOKEN_MAP.items():
        if any(_contains_token(key, token) for token in tokens):
            return model_type
    return None
def special_model_rule(model_name: str) -> Optional[str]:
    '''Exact lookup of the raw model name in MODEL_MAPPING.

    Fix: the previous docstring ('Determine the model category') described
    get_model_type, not this exact-match rule.

    Args:
        model_name: Model name as passed by get_model_type (lowercased there).

    Returns:
        The mapped model type, or None when the name is not a known key.
    '''
    return MODEL_MAPPING.get(model_name)
def get_model_type(model_name: str) -> str:
    '''Determine the model category for ``model_name``.

    Tries special_model_rule (exact lookup) then feature_keyword_rule
    (normalized/keyword match).  A rule that raises is logged and skipped.
    Falls back to 'llm' when the name is empty or no rule matches.

    Args:
        model_name: Raw model name; lowercased before rules are applied.

    Returns:
        A model type string such as 'llm', 'vlm', 'embed', ...
    '''
    # Guard BEFORE .lower(): the original lowered first, so a None name
    # raised AttributeError instead of falling back to 'llm'.
    if not model_name:
        return 'llm'
    model_name = model_name.lower()
    for rule in (special_model_rule, feature_keyword_rule):
        # Keep the try body minimal: only the rule call can legitimately fail.
        try:
            result = rule(model_name)
        except Exception as e:
            LOG.warning(f'Rule {rule.__name__} failed: {e}')
            continue
        if result:
            LOG.info(f'Model: {model_name} classified as type: {result} by rule: {rule.__name__}')
            return result
    LOG.warning(f'Cannot classify model type for: {model_name}. Defaulting to "llm" instead.')
    return 'llm'
0 commit comments