Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 125 additions & 17 deletions unsloth/dataprep/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
"SyntheticDataKit",
]
import subprocess
import threading
from collections import deque
import time
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import requests
import torch
import gc
import time
import re
from unsloth_zoo.vllm_utils import (
load_vllm,
patch_vllm,
Expand All @@ -35,6 +38,100 @@
synthetic_qa_config,
)

def terminate_tree(proc: subprocess.Popen, timeout=15):
    """Best-effort shutdown of ``proc`` and, where possible, its descendants.

    Strategies are tried in order, each falling through to the next on failure:

    1. If ``psutil`` can be imported, send a terminate request to every
       descendant and then to the process itself, and wait up to
       ``timeout / 2`` seconds for the process to exit.
    2. On Windows (``os.name == 'nt'``), run ``taskkill /T /F`` on the pid.
    3. Call ``proc.kill()`` and wait up to 5 seconds.

    Args:
        proc: The process to stop. ``None`` or an already-finished process is
            a no-op.
        timeout: Overall budget in seconds; half of it is used for the
            graceful wait in step 1.

    Returns:
        None. Failures inside a strategy are suppressed so that the next
        strategy can run.
    """
    if proc is None or proc.poll() is not None:
        return

    try:
        import psutil
        parent = psutil.Process(proc.pid)
        for child in parent.children(recursive=True):
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                # This child exited on its own; keep going so the remaining
                # children and the parent are still signalled.
                continue
        parent.terminate()
        parent.wait(timeout=timeout / 2)
        return
    except Exception:
        # psutil missing, process already gone, access denied, or the wait
        # timed out: fall through to the harsher strategies below.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate.
        pass

    if os.name == 'nt':
        try:
            subprocess.run(
                ['taskkill', '/T', '/F', '/PID', str(proc.pid)],
                capture_output=True,
                timeout=5
            )
            proc.wait(timeout=1)
            return
        except Exception:
            pass

    proc.kill()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # Nothing more can be done here; leave the process to the OS.
        pass

class PipeCapture:
    """Non blocking pipe capture

    A daemon thread reads ``pipe`` line by line, keeps the most recent
    ``keep_lines`` lines in memory, optionally echoes them, and sets
    ``ready_event`` the first time a line matches ``ready_regex``. When the
    stream ends the pipe is closed and ``closed_event`` is set.
    """
    def __init__(self, pipe, keep_lines=2000, echo=False, name="", text=True, encoding='utf-8', errors='replace', ready_regex=None):
        """Start capturing ``pipe`` in a background thread.

        Args:
            pipe: File-like object exposing ``readline()`` and ``close()``.
            keep_lines: Maximum number of lines retained for ``tail``.
            echo: When true, print each line prefixed with ``name``.
            name: Label used as the echo prefix.
            text: Hint for whether ``pipe`` yields ``str``. The reader also
                checks the real type of each line, so a mismatch is tolerated.
            encoding: Encoding used to decode ``bytes`` lines.
            errors: Decode error policy for ``bytes`` lines.
            ready_regex: Pattern string or compiled pattern (anything with a
                ``search`` method) marking readiness; ``None`` disables it.
        """
        self.pipe = pipe
        self.buf = deque(maxlen=keep_lines)
        self.lock = threading.Lock()
        self.echo = echo
        self.name = name
        self.text = text
        self.encoding = encoding
        self.errors = errors

        self.ready_event = threading.Event()
        self.closed_event = threading.Event()
        # Set as soon as the outcome is known (ready seen OR stream ended), so
        # wait_for_ready does not sleep through its timeout on a dead stream.
        self._settled_event = threading.Event()

        self.ready_regex = None
        if ready_regex is not None:
            if not hasattr(ready_regex, "search"):
                ready_regex = re.compile(ready_regex)
            self.ready_regex = ready_regex

        self.t = threading.Thread(target=self._reader, daemon=True)
        self.t.start()

    def _reader(self):
        """Thread body: drain the pipe until EOF, then close it."""
        try:
            while True:
                raw_line = self.pipe.readline()
                if not raw_line:
                    # '' or b'' both mean end of stream.
                    break
                # Decide by the actual type rather than the ``text`` flag so a
                # binary pipe with text=True does not crash on str-only calls.
                if isinstance(raw_line, bytes):
                    line = raw_line.decode(self.encoding, self.errors)
                else:
                    line = raw_line
                line = line.rstrip('\r\n')
                if self.echo:
                    if "platform is" not in line:
                        print(f"{self.name}: {line}")

                with self.lock:
                    self.buf.append(line)

                if self.ready_regex is not None and self.ready_regex.search(line):
                    self.ready_event.set()
                    self._settled_event.set()

        finally:
            try:
                self.pipe.close()
            except Exception:
                pass
            self.closed_event.set()
            self._settled_event.set()

    def wait_for_ready(self, timeout=None):
        """Block until the readiness line is seen, the stream ends, or ``timeout``.

        Returns:
            True if a line matching ``ready_regex`` has been seen, else False.
            Returns as soon as the stream closes instead of waiting out the
            remaining timeout.
        """
        self._settled_event.wait(timeout)
        return self.ready_event.is_set()

    def has_closed(self):
        """Return True once the reader thread has finished with the pipe."""
        return self.closed_event.is_set()

    def wait_until_closed(self, timeout=None):
        """Block until the stream ends or ``timeout`` elapses; return whether it ended."""
        return self.closed_event.wait(timeout)

    def tail(self, n=200):
        """Return the last ``n`` captured lines joined by newlines.

        ``n <= 0`` yields an empty string (a plain ``[-0:]`` slice would
        return the whole buffer).
        """
        if n <= 0:
            return ''
        with self.lock:
            return '\n'.join(list(self.buf)[-n:])

class SyntheticDataKit:
def __init__(
self,
Expand All @@ -44,6 +141,7 @@ def __init__(
float8_kv_cache = False,
conservativeness = 1.0,
token = None,
timeout = 1200, # maybe this is not enough for large models if we need to download
**kwargs,
):
assert(type(model_name) is str)
Expand Down Expand Up @@ -128,30 +226,40 @@ def __init__(
stderr = subprocess.PIPE,
start_new_session = True,
)
ready_re = re.compile(r"Starting vLLM API server(?:\s+\d+)?\s+on\b")
self.vllm_process = vllm_process
self.stdout_capture = PipeCapture(vllm_process.stdout, keep_lines = 1000,
echo = True, name = "vLLM STDOUT",
ready_regex = ready_re, text = False)
self.stderr_capture = PipeCapture(vllm_process.stderr, keep_lines = 2000,
echo = False, name = "vLLM STDERR",
ready_regex = None, text = False)
# we don't print stderr to console but self.stderr_capture.tail(200) will print the last 200 lines

ready_message_part = b"Starting vLLM API server on"
ready = False
while vllm_process.poll() is None:
output = vllm_process.stdout.readline()
if not output:
ready = self.stdout_capture.wait_for_ready(timeout = timeout)
if not ready:
if self.stdout_capture.has_closed() or self.vllm_process.poll() is not None:
print("Stdout stream ended before readiness message detected.")
break
output_str = output.decode('utf-8', errors='ignore').strip()
if "platform is" not in output_str:
print(f"vLLM STDOUT: {output_str}")
if ready_message_part in output:
print(f"\n--- vLLM Server Ready (Detected: '{ready_message_part.decode()}') ---")
ready = True
break
pass
print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
else:
print(f"Unsloth: vllm_process failed to load! (timeout={timeout})")
print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
terminate_tree(self.vllm_process)
return
else:
print("vLLM Server Ready Detected")
pass
if vllm_process is None:
raise RuntimeError("Unsloth: vllm_process failed to load!")

trial = 0
while not self.check_vllm_status():
if trial >= 100:
raise RuntimeError("Unsloth: vllm_process failed to load!")
print("Unsloth: vllm_process failed to load!")
print("\n--- stdout tail ---\n", self.stdout_capture.tail(50))
print("\n--- stderr tail ---\n", self.stderr_capture.tail(50))
terminate_tree(self.vllm_process)
return
trial += 1
time.sleep(1)
return
Expand Down