diff --git a/xinference/constants.py b/xinference/constants.py index 8db65ac13e..6b5758a86d 100644 --- a/xinference/constants.py +++ b/xinference/constants.py @@ -44,11 +44,12 @@ def get_xinference_home() -> str: home_path = os.environ.get(XINFERENCE_ENV_HOME_PATH) if home_path is None: home_path = str(Path.home() / ".xinference") - else: - # if user has already set `XINFERENCE_HOME` env, change huggingface and modelscope default download path - os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(home_path, "huggingface") - os.environ["MODELSCOPE_CACHE"] = os.path.join(home_path, "modelscope") - os.environ["XDG_CACHE_HOME"] = os.path.join(home_path, "openmind_hub") + # Always change the huggingface, modelscope, and openmind_hub default download paths + # to ensure the xinference process has write permissions for downloading dependencies + # (e.g., Qwen3-ASR's forced aligner model downloaded from Hugging Face Hub) + os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(home_path, "huggingface") + os.environ["MODELSCOPE_CACHE"] = os.path.join(home_path, "modelscope") + os.environ["XDG_CACHE_HOME"] = os.path.join(home_path, "openmind_hub") # In multi-tenant mode, # gradio's temporary files are stored in their respective home directories, # to prevent insufficient permissions diff --git a/xinference/model/audio/core.py b/xinference/model/audio/core.py index 373500b882..38089121be 100644 --- a/xinference/model/audio/core.py +++ b/xinference/model/audio/core.py @@ -29,6 +29,7 @@ from .kokoro_zh import KokoroZHModel from .megatts import MegaTTSModel from .melotts import MeloTTSModel +from .qwen3_asr import Qwen3ASRModel from .whisper import WhisperModel from .whisper_mlx import WhisperMLXModel @@ -155,6 +156,7 @@ def create_audio_model_instance( KokoroZHModel, MegaTTSModel, Indextts2, + Qwen3ASRModel, ]: from ..cache_manager import CacheManager @@ -178,6 +180,7 @@ def create_audio_model_instance( KokoroZHModel, MegaTTSModel, Indextts2, + Qwen3ASRModel, ] if model_spec.model_family == 
"whisper": if not model_spec.engine: @@ -208,6 +211,8 @@ def create_audio_model_instance( model = MegaTTSModel(model_uid, model_path, model_spec, **kwargs) elif model_spec.model_family == "IndexTTS2": model = Indextts2(model_uid, model_path, model_spec, **kwargs) + elif model_spec.model_family == "qwen3_asr": + model = Qwen3ASRModel(model_uid, model_path, model_spec, **kwargs) else: raise Exception(f"Unsupported audio model family: {model_spec.model_family}") return model diff --git a/xinference/model/audio/qwen3_asr.py b/xinference/model/audio/qwen3_asr.py new file mode 100644 index 0000000000..215c61d8f5 --- /dev/null +++ b/xinference/model/audio/qwen3_asr.py @@ -0,0 +1,154 @@ +# Copyright 2022-2026 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import tempfile +from typing import TYPE_CHECKING, List, Optional, Tuple + +from ...device_utils import ( + get_available_device, + get_device_preferred_dtype, + is_device_available, +) + +if TYPE_CHECKING: + from .core import AudioModelFamilyV2 + +logger = logging.getLogger(__name__) + + +class Qwen3ASRModel: + def __init__( + self, + model_uid: str, + model_path: str, + model_spec: "AudioModelFamilyV2", + device: Optional[str] = None, + **kwargs, + ): + self.model_family = model_spec + self._model_uid = model_uid + self._model_path = model_path + self._model_spec = model_spec + self._device = device + self._model = None + self._kwargs = kwargs + + @property + def model_ability(self): + return self._model_spec.model_ability + + def load(self): + try: + from qwen_asr import Qwen3ASRModel as QwenASR + except ImportError: + error_message = "Failed to import module 'qwen_asr'" + installation_guide = [ + "Please make sure 'qwen-asr' is installed. ", + "You can install it by `pip install qwen-asr`\n", + ] + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + + if self._device is None: + self._device = get_available_device() + else: + if not is_device_available(self._device): + raise ValueError(f"Device {self._device} is not available!") + + init_kwargs = ( + self._model_spec.default_model_config.copy() + if getattr(self._model_spec, "default_model_config", None) + else {} + ) + init_kwargs.update(self._kwargs) + init_kwargs.setdefault("device_map", self._device) + init_kwargs.setdefault("dtype", get_device_preferred_dtype(self._device)) + if "forced_aligner" in init_kwargs: + forced_aligner_kwargs = init_kwargs.get("forced_aligner_kwargs") or {} + forced_aligner_kwargs.setdefault("device_map", self._device) + forced_aligner_kwargs.setdefault( + "dtype", get_device_preferred_dtype(self._device) + ) + init_kwargs["forced_aligner_kwargs"] = forced_aligner_kwargs + logger.debug("Loading Qwen3-ASR model with kwargs: %s", init_kwargs) + 
self._model = QwenASR.from_pretrained(self._model_path, **init_kwargs) + + def _extract_text_and_language(self, result) -> Tuple[str, Optional[str]]: + if isinstance(result, list): + if not result: + return "", None + result = result[0] + + if hasattr(result, "text"): + text = result.text + language = getattr(result, "language", None) + return text, language + + if isinstance(result, dict): + text = result.get("text") or result.get("transcript") or "" + language = result.get("language") + return text, language + + return str(result), None + + def transcriptions( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, + **kwargs, + ): + if temperature != 0: + raise RuntimeError("`temperature` is not supported for Qwen3-ASR") + if timestamp_granularities is not None: + raise RuntimeError( + "`timestamp_granularities` is not supported for Qwen3-ASR" + ) + if prompt is not None: + logger.warning( + "Prompt for Qwen3-ASR transcriptions will be ignored: %s", prompt + ) + + kw = dict(getattr(self._model_spec, "default_transcription_config", None) or {}) + kw.update(kwargs) + + with tempfile.NamedTemporaryFile(buffering=0) as f: + f.write(audio) + assert self._model is not None + result = self._model.transcribe(audio=f.name, language=language, **kw) + text, detected_language = self._extract_text_and_language(result) + + if response_format == "json": + return {"text": text} + if response_format == "verbose_json": + return { + "task": "transcribe", + "language": detected_language, + "text": text, + } + raise ValueError(f"Unsupported response format: {response_format}") + + def translations( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, + ): + raise RuntimeError("Qwen3-ASR 
does not support translations API")