@@ -61,130 +61,124 @@ def __init__(self, od_config: OmniDiffusionConfig):
6161 raise e
6262
def step(self, requests: list[OmniDiffusionRequest]):
    """Run one generation pass for *requests* and package the results.

    NOTE(review): this block was reconstructed from a garbled unified-diff
    paste; the post-image ("+") lines were taken as the current code —
    confirm against the repository before merging.

    Applies the optional pre-processing hook, submits the requests to the
    engine and waits for the response, applies the optional post-processing
    hook, and converts the raw engine output into ``OmniRequestOutput``
    objects.

    Args:
        requests: batch of diffusion requests to execute.

    Returns:
        A single ``OmniRequestOutput`` when exactly one request was given,
        a list of ``OmniRequestOutput`` for a batch, or ``None`` when the
        engine produced no output and there was no request to attach an
        empty result to.

    Raises:
        Exception: when the engine response reports an error.
    """

    def first_prompt(request):
        # A request's prompt may be a list; only the first entry is reported.
        prompt = request.prompt
        if isinstance(prompt, list):
            prompt = prompt[0] if prompt else None
        return prompt

    def build_output(request, request_outputs):
        # Shared conversion used by both the single- and multi-request paths
        # (the original duplicated this whole construction in both branches).
        metrics = {}
        if output.trajectory_timesteps is not None:
            metrics["trajectory_timesteps"] = output.trajectory_timesteps

        if supports_audio_output(self.od_config.model_class_name):
            # Audio models: a lone output is returned bare, several as a list.
            audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
            return OmniRequestOutput.from_diffusion(
                request_id=request.request_id or "",
                images=[],
                prompt=first_prompt(request),
                metrics=metrics,
                latents=output.trajectory_latents,
                multimodal_output={"audio": audio_payload},
                final_output_type="audio",
            )
        return OmniRequestOutput.from_diffusion(
            request_id=request.request_id or "",
            images=request_outputs,
            prompt=first_prompt(request),
            metrics=metrics,
            latents=output.trajectory_latents,
        )

    # Apply pre-processing if available.
    if self.pre_process_func is not None:
        preprocess_start_time = time.time()
        requests = self.pre_process_func(requests)
        preprocess_time = time.time() - preprocess_start_time
        logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds")

    output = self.add_req_and_wait_for_response(requests)
    if output.error:
        raise Exception(f"{output.error}")
    logger.info("Generation completed successfully.")

    if output.output is None:
        logger.warning("Output is None, returning empty OmniRequestOutput")
        # Return an empty but well-formed result tied to the first request so
        # the caller still receives a response object.
        if len(requests) > 0:
            request = requests[0]
            return OmniRequestOutput.from_diffusion(
                request_id=request.request_id or "",
                images=[],
                prompt=first_prompt(request),
                metrics={},
                latents=None,
            )
        return None

    postprocess_start_time = time.time()
    outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
    postprocess_time = time.time() - postprocess_start_time
    logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")

    # Normalize to a list so the slicing/indexing below is uniform.
    if not isinstance(outputs, list):
        outputs = [outputs] if outputs is not None else []

    if len(requests) == 1:
        # Single request: every produced output belongs to it.
        return build_output(requests[0], outputs)

    # Multiple requests: split the outputs across requests, in order,
    # according to each request's num_outputs_per_prompt.
    results = []
    output_idx = 0
    for request in requests:
        num_outputs = request.num_outputs_per_prompt
        request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
        output_idx += num_outputs
        results.append(build_output(request, request_outputs))
    return results
189183 @staticmethod
190184 def make_engine (config : OmniDiffusionConfig ) -> "DiffusionEngine" :
0 commit comments