
Commit 578ed19

[Bugfix] raise error in diffusion engine and fix offload test (#933)
Signed-off-by: zjy0516 <[email protected]>
Co-authored-by: Hongsheng Liu <[email protected]>
1 parent 7f08258 commit 578ed19

4 files changed

Lines changed: 137 additions & 134 deletions

File tree

tests/e2e/offline_inference/test_diffusion_cpu_offload.py
vllm_omni/diffusion/diffusion_engine.py
vllm_omni/diffusion/scheduler.py
vllm_omni/diffusion/worker/gpu_diffusion_worker.py

tests/e2e/offline_inference/test_diffusion_cpu_offload.py

Lines changed: 31 additions & 25 deletions
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
 
 from tests.utils import GPUMemoryMonitor
 from vllm_omni.utils.platform_utils import is_npu, is_rocm
@@ -17,34 +18,39 @@
 models = ["riverclouds/qwen_image_random"]
 
 
+def inference(model_name: str, offload: bool = True):
+    torch.cuda.empty_cache()
+    device_index = torch.cuda.current_device()
+    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    monitor.start()
+    m = Omni(model=model_name, enable_cpu_offload=offload)
+    torch.cuda.reset_peak_memory_stats(device=device_index)
+    height = 256
+    width = 256
+
+    m.generate(
+        "a photo of a cat sitting on a laptop keyboard",
+        height=height,
+        width=width,
+        num_inference_steps=9,
+        guidance_scale=0.0,
+        generator=torch.Generator("cuda").manual_seed(42),
+    )
+    peak = monitor.peak_used_mb
+    monitor.stop()
+
+    return peak
+
+
 @pytest.mark.skipif(is_npu() or is_rocm(), reason="Hardware not supported")
 @pytest.mark.parametrize("model_name", models)
 def test_cpu_offload_diffusion_model(model_name: str):
-    def inference(offload: bool = True):
-        torch.cuda.empty_cache()
-        device_index = torch.cuda.current_device()
-        monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
-        monitor.start()
-        m = Omni(model=model_name, enable_cpu_offload=offload)
-        torch.cuda.reset_peak_memory_stats(device=device_index)
-        height = 256
-        width = 256
-
-        m.generate(
-            "a photo of a cat sitting on a laptop keyboard",
-            height=height,
-            width=width,
-            num_inference_steps=9,
-            guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
-        )
-        peak = monitor.peak_used_mb
-        monitor.stop()
-
-        return peak
-
-    offload_peak_memory = inference(offload=True)
-    no_offload_peak_memory = inference(offload=False)
+    try:
+        no_offload_peak_memory = inference(model_name, offload=False)
+        cleanup_dist_env_and_memory()
+        offload_peak_memory = inference(model_name, offload=True)
+    except Exception:
+        pytest.fail("Inference failed")
     print(f"Offload peak memory: {offload_peak_memory} MB")
    print(f"No offload peak memory: {no_offload_peak_memory} MB")
    assert offload_peak_memory + 2500 < no_offload_peak_memory, (
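
For context on what the test measures: GPUMemoryMonitor is imported from tests/utils and is not part of this diff. Below is a minimal sketch of the peak-tracking pattern such a monitor presumably follows; the class name, attributes, and polling details are illustrative assumptions, not the project's actual implementation.

    import threading
    import time

    import torch


    class PeakMemoryMonitor:
        """Background thread that samples allocated GPU memory and records the peak in MB."""

        def __init__(self, device_index: int, interval: float = 0.02):
            self.device_index = device_index
            self.interval = interval
            self.peak_used_mb = 0.0
            self._stop = threading.Event()
            self._thread = threading.Thread(target=self._poll, daemon=True)

        def _poll(self) -> None:
            while not self._stop.is_set():
                used_mb = torch.cuda.memory_allocated(self.device_index) / 1024**2
                self.peak_used_mb = max(self.peak_used_mb, used_mb)
                time.sleep(self.interval)

        def start(self) -> None:
            self._thread.start()

        def stop(self) -> None:
            self._stop.set()
            self._thread.join()

The updated test runs the non-offload case first, calls cleanup_dist_env_and_memory() so the second engine does not inherit the first run's distributed state and cached allocations, and then requires the offload run's peak to come in more than 2500 MB below the non-offload peak.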

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 102 additions & 108 deletions
@@ -61,130 +61,124 @@ def __init__(self, od_config: OmniDiffusionConfig):
             raise e
 
     def step(self, requests: list[OmniDiffusionRequest]):
-        try:
-            # Apply pre-processing if available
-            if self.pre_process_func is not None:
-                preprocess_start_time = time.time()
-                requests = self.pre_process_func(requests)
-                preprocess_time = time.time() - preprocess_start_time
-                logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds")
-
-            output = self.add_req_and_wait_for_response(requests)
-            if output.error:
-                raise Exception(f"{output.error}")
-            logger.info("Generation completed successfully.")
-
-            if output.output is None:
-                logger.warning("Output is None, returning empty OmniRequestOutput")
-                # Return empty output for the first request
-                if len(requests) > 0:
-                    request = requests[0]
-                    request_id = request.request_id or ""
-                    prompt = request.prompt
-                    if isinstance(prompt, list):
-                        prompt = prompt[0] if prompt else None
-                    return OmniRequestOutput.from_diffusion(
-                        request_id=request_id,
-                        images=[],
-                        prompt=prompt,
-                        metrics={},
-                        latents=None,
-                    )
-                return None
-
-            postprocess_start_time = time.time()
-            outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
-            postprocess_time = time.time() - postprocess_start_time
-            logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")
+        # Apply pre-processing if available
+        if self.pre_process_func is not None:
+            preprocess_start_time = time.time()
+            requests = self.pre_process_func(requests)
+            preprocess_time = time.time() - preprocess_start_time
+            logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds")
+
+        output = self.add_req_and_wait_for_response(requests)
+        if output.error:
+            raise Exception(f"{output.error}")
+        logger.info("Generation completed successfully.")
+
+        if output.output is None:
+            logger.warning("Output is None, returning empty OmniRequestOutput")
+            # Return empty output for the first request
+            if len(requests) > 0:
+                request = requests[0]
+                request_id = request.request_id or ""
+                prompt = request.prompt
+                if isinstance(prompt, list):
+                    prompt = prompt[0] if prompt else None
+                return OmniRequestOutput.from_diffusion(
+                    request_id=request_id,
+                    images=[],
+                    prompt=prompt,
+                    metrics={},
+                    latents=None,
+                )
+            return None
 
-            # Convert to OmniRequestOutput format
-            # Ensure outputs is a list
-            if not isinstance(outputs, list):
-                outputs = [outputs] if outputs is not None else []
+        postprocess_start_time = time.time()
+        outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
+        postprocess_time = time.time() - postprocess_start_time
+        logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")
+
+        # Convert to OmniRequestOutput format
+        # Ensure outputs is a list
+        if not isinstance(outputs, list):
+            outputs = [outputs] if outputs is not None else []
+
+        # Handle single request or multiple requests
+        if len(requests) == 1:
+            # Single request: return single OmniRequestOutput
+            request = requests[0]
+            request_id = request.request_id or ""
+            prompt = request.prompt
+            if isinstance(prompt, list):
+                prompt = prompt[0] if prompt else None
+
+            metrics = {}
+            if output.trajectory_timesteps is not None:
+                metrics["trajectory_timesteps"] = output.trajectory_timesteps
+
+            if supports_audio_output(self.od_config.model_class_name):
+                audio_payload = outputs[0] if len(outputs) == 1 else outputs
+                return OmniRequestOutput.from_diffusion(
+                    request_id=request_id,
+                    images=[],
+                    prompt=prompt,
+                    metrics=metrics,
+                    latents=output.trajectory_latents,
+                    multimodal_output={"audio": audio_payload},
+                    final_output_type="audio",
+                )
+            else:
+                return OmniRequestOutput.from_diffusion(
+                    request_id=request_id,
+                    images=outputs,
+                    prompt=prompt,
+                    metrics=metrics,
+                    latents=output.trajectory_latents,
+                )
+        else:
+            # Multiple requests: return list of OmniRequestOutput
+            # Split images based on num_outputs_per_prompt for each request
+            results = []
+            output_idx = 0
 
-            # Handle single request or multiple requests
-            if len(requests) == 1:
-                # Single request: return single OmniRequestOutput
-                request = requests[0]
+            for request in requests:
                 request_id = request.request_id or ""
                 prompt = request.prompt
                 if isinstance(prompt, list):
                     prompt = prompt[0] if prompt else None
 
+                # Get images for this request
+                num_outputs = request.num_outputs_per_prompt
+                request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
+                output_idx += num_outputs
+
                 metrics = {}
                 if output.trajectory_timesteps is not None:
                     metrics["trajectory_timesteps"] = output.trajectory_timesteps
 
                 if supports_audio_output(self.od_config.model_class_name):
-                    audio_payload = outputs[0] if len(outputs) == 1 else outputs
-                    return OmniRequestOutput.from_diffusion(
-                        request_id=request_id,
-                        images=[],
-                        prompt=prompt,
-                        metrics=metrics,
-                        latents=output.trajectory_latents,
-                        multimodal_output={"audio": audio_payload},
-                        final_output_type="audio",
+                    audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
+                    results.append(
+                        OmniRequestOutput.from_diffusion(
+                            request_id=request_id,
+                            images=[],
+                            prompt=prompt,
+                            metrics=metrics,
+                            latents=output.trajectory_latents,
+                            multimodal_output={"audio": audio_payload},
+                            final_output_type="audio",
+                        )
                     )
                 else:
-                    return OmniRequestOutput.from_diffusion(
-                        request_id=request_id,
-                        images=outputs,
-                        prompt=prompt,
-                        metrics=metrics,
-                        latents=output.trajectory_latents,
-                    )
-            else:
-                # Multiple requests: return list of OmniRequestOutput
-                # Split images based on num_outputs_per_prompt for each request
-                results = []
-                output_idx = 0
-
-                for request in requests:
-                    request_id = request.request_id or ""
-                    prompt = request.prompt
-                    if isinstance(prompt, list):
-                        prompt = prompt[0] if prompt else None
-
-                    # Get images for this request
-                    num_outputs = request.num_outputs_per_prompt
-                    request_outputs = (
-                        outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
-                    )
-                    output_idx += num_outputs
-
-                    metrics = {}
-                    if output.trajectory_timesteps is not None:
-                        metrics["trajectory_timesteps"] = output.trajectory_timesteps
-
-                    if supports_audio_output(self.od_config.model_class_name):
-                        audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
-                        results.append(
-                            OmniRequestOutput.from_diffusion(
-                                request_id=request_id,
-                                images=[],
-                                prompt=prompt,
-                                metrics=metrics,
-                                latents=output.trajectory_latents,
-                                multimodal_output={"audio": audio_payload},
-                                final_output_type="audio",
-                            )
-                        )
-                    else:
-                        results.append(
-                            OmniRequestOutput.from_diffusion(
-                                request_id=request_id,
-                                images=request_outputs,
-                                prompt=prompt,
-                                metrics=metrics,
-                                latents=output.trajectory_latents,
-                            )
+                    results.append(
+                        OmniRequestOutput.from_diffusion(
+                            request_id=request_id,
+                            images=request_outputs,
+                            prompt=prompt,
+                            metrics=metrics,
+                            latents=output.trajectory_latents,
                         )
+                    )
 
-                return results
-        except Exception as e:
-            logger.error(f"Generation failed: {e}")
-            return None
+            return results
 
     @staticmethod
     def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine":
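
The behavioral change in step() is that failures are no longer caught, logged, and converted to None; they now propagate to the caller. Below is a minimal, self-contained sketch of the difference; the helper names are illustrative stand-ins, not vllm_omni code.

    def step_swallowing(run):
        """Old pattern: log the failure and return None, hiding the cause."""
        try:
            return run()
        except Exception as exc:
            print(f"Generation failed: {exc}")
            return None


    def step_raising(run):
        """New pattern: let the exception propagate so the caller can react."""
        return run()


    def failing_generation():
        raise RuntimeError("worker error")


    print(step_swallowing(failing_generation))  # -> None, the cause is lost
    try:
        step_raising(failing_generation)
    except RuntimeError as exc:
        print(f"caller sees: {exc}")            # -> caller sees: worker error

Callers now decide how to react; in the updated test above, that is an explicit pytest.fail("Inference failed").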

vllm_omni/diffusion/scheduler.py

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
                 raise RuntimeError("Result queue not initialized")
 
             output = self.result_mq.dequeue()
+            # {"status": "error", "error": str(e)}
+            if isinstance(output, dict) and output.get("status") == "error":
+                raise RuntimeError("worker error")
             return output
         except zmq.error.Again:
             logger.error("Timeout waiting for response from scheduler.")

vllm_omni/diffusion/worker/gpu_diffusion_worker.py

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
             return result, should_reply
         except Exception as e:
             logger.error(f"Error executing RPC: {e}", exc_info=True)
-            return {"status": "error", "error": str(e)}, should_reply
+            raise e
 
     def worker_busy_loop(self) -> None:
         """Main busy loop for Multiprocessing Workers."""

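Taken together with the scheduler change, the intended error path is: execute_rpc re-raises instead of returning an error dict, the failure is reported back to the scheduler as a {"status": "error", "error": str(e)} payload (how that payload is produced after this change, for example by an outer layer of the worker loop, is not shown in this diff), and add_req converts it into a RuntimeError that the engine no longer swallows. Below is a minimal runnable sketch of that producer/consumer contract, using queue.Queue as a stand-in for the real shared result queue (result_mq).

    import queue

    # Stand-in for the worker/scheduler result queue; the real result_mq is a
    # shared message queue, not queue.Queue.
    result_mq = queue.Queue()


    def worker_side(task):
        """Worker side: publish either the task result or an error payload."""
        try:
            result_mq.put(task())
        except Exception as exc:
            # Same shape as the payload the scheduler now checks for.
            result_mq.put({"status": "error", "error": str(exc)})


    def scheduler_side():
        """Scheduler side: dequeue the response and re-raise worker failures."""
        output = result_mq.get()
        if isinstance(output, dict) and output.get("status") == "error":
            raise RuntimeError("worker error")
        return output


    worker_side(lambda: 1 / 0)  # the worker hits a ZeroDivisionError
    try:
        scheduler_side()
    except RuntimeError as exc:
        print(exc)              # -> worker error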