import os
import re
import subprocess
import threading
from collections import deque


def _terminate_with_psutil(proc, timeout):
    """Try to stop ``proc`` and all of its descendants using psutil.

    Returns True when the process tree was handled and ``proc`` has been
    reaped, False when psutil is unavailable or anything went wrong (the
    caller then falls back to other strategies).
    """
    try:
        import psutil
    except ImportError:
        return False

    try:
        parent = psutil.Process(proc.pid)
        # Children first, parent last, so the parent cannot respawn workers
        # while we are still walking the tree.
        targets = parent.children(recursive=True) + [parent]

        for target in targets:
            try:
                target.terminate()
            except psutil.Error:
                # Already gone or not ours to signal; keep going so one
                # missing child does not leave the rest of the tree running.
                pass

        _, survivors = psutil.wait_procs(targets, timeout=timeout / 2)

        # Anything that ignored the polite request gets killed outright.
        for target in survivors:
            try:
                target.kill()
            except psutil.Error:
                pass
        if survivors:
            psutil.wait_procs(survivors, timeout=timeout / 2)

        # Reap through the Popen handle as well so proc.returncode is set.
        proc.wait(timeout=1)
        return True
    except Exception:
        # Best effort only: any failure here hands over to the fallbacks.
        return False


def terminate_tree(proc: subprocess.Popen, timeout=15):
    """Stop a subprocess together with its child processes (best effort).

    Strategies are tried in order:

    1. psutil, if importable: terminate every descendant and the parent,
       wait, then kill whatever is still alive.
    2. On Windows: ``taskkill /T /F`` on the process id.
    3. ``proc.kill()`` on the process itself.

    Args:
        proc: The process to stop. ``None`` or an already-exited process is
            a no-op.
        timeout: Overall budget in seconds used by the psutil strategy
            (half for the graceful phase, half for the forced phase).

    Returns:
        None. Failures to stop the process are not raised.
    """
    if proc is None or proc.poll() is not None:
        return

    if _terminate_with_psutil(proc, timeout):
        return

    if os.name == 'nt':
        try:
            subprocess.run(
                ['taskkill', '/T', '/F', '/PID', str(proc.pid)],
                capture_output=True,
                timeout=5,
            )
            proc.wait(timeout=1)
            return
        except (OSError, subprocess.SubprocessError):
            pass

    try:
        proc.kill()
    except OSError:
        # The process may have exited between the checks above and now.
        pass
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        pass


class PipeCapture:
    """Drain a pipe on a background daemon thread without blocking the caller.

    Every line read is stored (without its line ending) in a bounded buffer,
    optionally echoed to stdout, and optionally matched against a readiness
    pattern. The reader keeps draining after readiness is detected and only
    stops at end of stream, at which point it closes the pipe.

    Args:
        pipe: File-like object exposing ``readline()`` and ``close()``.
        keep_lines: Maximum number of most recent lines kept for ``tail()``.
        echo: When true, print each line as ``"{name}: {line}"`` (lines
            containing ``"platform is"`` are not printed).
        name: Prefix used when echoing.
        text: True if ``pipe`` yields ``str``; False if it yields ``bytes``
            that must be decoded.
        encoding: Codec used to decode lines when ``text`` is False.
        errors: Decode error policy used when ``text`` is False.
        ready_regex: Pattern string or compiled pattern; the first line it
            matches (via ``search``) marks the capture as ready.
    """

    def __init__(self, pipe, keep_lines=2000, echo=False, name="", text=True,
                 encoding='utf-8', errors='replace', ready_regex=None):
        self.pipe = pipe
        self.buf = deque(maxlen=keep_lines)
        self.lock = threading.Lock()
        self.echo = echo
        self.name = name
        self.text = text
        self.encoding = encoding
        self.errors = errors

        self.ready_event = threading.Event()
        self.closed_event = threading.Event()
        # Set as soon as the outcome is known: either the readiness line was
        # seen or the stream ended. Lets wait_for_ready() return early.
        self._settled_event = threading.Event()

        self.ready_regex = None
        if ready_regex is not None:
            if not hasattr(ready_regex, "search"):
                ready_regex = re.compile(ready_regex)
            self.ready_regex = ready_regex

        self.t = threading.Thread(target=self._reader, daemon=True)
        self.t.start()

    def _reader(self):
        """Thread body: read lines until EOF, recording and matching each."""
        try:
            sentinel = '' if self.text else b''
            for raw_line in iter(self.pipe.readline, sentinel):
                if not self.text:
                    line = raw_line.decode(self.encoding, self.errors)
                else:
                    line = raw_line
                line = line.rstrip('\r\n')

                if self.echo:
                    if "platform is" not in line:
                        print(f"{self.name}: {line}")

                with self.lock:
                    self.buf.append(line)

                if (
                    self.ready_regex is not None
                    and not self.ready_event.is_set()
                    and self.ready_regex.search(line)
                ):
                    # ready_event first, so a waiter woken by the settled
                    # event always observes the ready state.
                    self.ready_event.set()
                    self._settled_event.set()
        finally:
            try:
                self.pipe.close()
            except Exception:
                pass
            self.closed_event.set()
            self._settled_event.set()

    def wait_for_ready(self, timeout=None):
        """Block until the readiness line is seen, the stream ends, or timeout.

        Returns True only if the readiness pattern matched. A False result
        means either the timeout elapsed or the stream closed first; use
        ``has_closed()`` to tell the two apart.
        """
        self._settled_event.wait(timeout)
        return self.ready_event.is_set()

    def has_closed(self):
        """Return True once the reader has reached end of stream."""
        return self.closed_event.is_set()

    def wait_until_closed(self, timeout=None):
        """Block until end of stream; return False if ``timeout`` elapsed."""
        return self.closed_event.wait(timeout)

    def tail(self, n=200):
        """Return the last ``n`` captured lines joined by newlines.

        ``n <= 0`` yields an empty string.
        """
        if n <= 0:
            return ''
        with self.lock:
            lines = list(self.buf)
        return '\n'.join(lines[-n:])


# NOTE(review): the patch these helpers were recovered from also modifies
# SyntheticDataKit.__init__ (adds a `timeout = 1200` keyword and replaces the
# blocking stdout readline loop with `stdout_capture.wait_for_ready(...)`).
# Those hunks are only partially visible and are not reproduced here.