 
 from QEfficient.diffusers.pipelines.config_manager import config_manager, set_module_device_ids
 from QEfficient.diffusers.pipelines.pipeline_utils import (
-    QEffClipTextEncoder,
     QEffFluxTransformerModel,
     QEffTextEncoder,
     QEffVAE,
@@ -38,11 +37,13 @@ class QEFFFluxPipeline(FluxPipeline):
     """
 
     def __init__(self, model, use_onnx_function, *args, **kwargs):
-        self.text_encoder = QEffClipTextEncoder(model.text_encoder)
+        self.text_encoder = QEffTextEncoder(model.text_encoder)
         self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
         self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
         self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function
+
+        # Add all modules of FluxPipeline
         self.has_module = [
             ("text_encoder", self.text_encoder),
             ("text_encoder_2", self.text_encoder_2),
@@ -78,6 +79,10 @@ def __init__(self, model, use_onnx_function, *args, **kwargs):
         self.latent_width = self.width // self.vae_scale_factor
         self.cl = (self.latent_height * self.latent_width) // 4
 
+        self.text_encoder_2.model.config.max_position_embeddings = (
+            self.text_encoder.model.config.max_position_embeddings
+        )
+
     @classmethod
     def from_pretrained(
         cls,
@@ -140,6 +145,15 @@ def export(self, export_dir: Optional[str] = None) -> str:
             export_kwargs=export_kwargs,
         )
 
+    def get_default_config_path():
+        """
+        Returns the default configuration file path for the Flux pipeline.
+
+        Returns:
+            str: Path to the default flux_config.json file.
+        """
+        return os.path.join(os.path.dirname(__file__), "flux_config.json")
+
     def compile(
         self,
         compile_config: Optional[str] = None,
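The new `get_default_config_path` helper points `compile` at the bundled `flux_config.json` when no custom configuration is supplied. A minimal sketch of how a caller might inspect that file or pass a custom one; the file name `my_flux_config.json` and the commented `compile` call are illustrative, not part of this change:

```python
import json

from QEfficient import QEFFFluxPipeline

# Inspect the default compile-time configuration bundled with the Flux pipeline.
default_path = QEFFFluxPipeline.get_default_config_path()
with open(default_path) as f:
    print(json.dumps(json.load(f), indent=2))

# A custom file with the same schema could then be passed instead (hypothetical path):
# pipeline.compile(compile_config="./my_flux_config.json")
```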
@@ -193,7 +207,6 @@ def _get_t5_prompt_embeds(
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 512,
         device_ids: Optional[List[int]] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         """
         Get T5 prompt embeddings for the given prompt(s).
@@ -203,7 +216,6 @@ def _get_t5_prompt_embeds(
             num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt.
             max_sequence_length (int, defaults to 512): Maximum sequence length for tokenization.
             device_ids (Optional[List[int]], optional): QAIC device IDs to place tensors on.
-            dtype (Optional[torch.dtype], optional): The data type for tensors.
 
         Returns:
             torch.Tensor: The T5 prompt embeddings with shape (batch_size * num_images_per_prompt, seq_len, hidden_size).
@@ -245,12 +257,12 @@ def _get_t5_prompt_embeds(
         self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
 
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+        import time
 
-        # # # AIC Testing
-        # prompt_embeds_pytorch = self.text_encoder_2.model(text_input_ids, output_hidden_states=False)
-        # mad = torch.abs(prompt_embeds_pytorch["last_hidden_state"] - prompt_embeds).mean()
-        # print(">>>>>>>>>>>> MAD for text-encoder-2 - T5 => Pytorch vs AI 100:", mad)
+        start_time = time.time()
+        prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+        end_time = time.time()
+        print(f"T5 Text encoder inference time: {end_time - start_time:.4f} seconds")
 
         _, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
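The start/end `time.time()` pattern added here is repeated below for the CLIP encoder, the transformer loop, and the VAE decoder. If the instrumentation is meant to stay, it could be factored into a small helper; a minimal sketch under that assumption, with illustrative names, not part of this change:

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str):
    """Print the wall-clock time spent in the wrapped block."""
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label} inference time: {time.perf_counter() - start:.4f} seconds")


# Hypothetical refactor of the T5 call above:
# with timed("T5 text encoder"):
#     prompt_embeds = torch.tensor(
#         self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"]
#     )
```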
@@ -303,26 +315,21 @@ def _get_clip_prompt_embeds(
         self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
 
         text_encoder_output = {
-            "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32),
             "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.int32),
+            "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32),
         }
 
         self.text_encoder.qpc_session.set_buffers(text_encoder_output)
 
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
-        # aic_text_encoder_emb = aic_embeddings["pooler_output"]
 
-        # # # # [TEMP] CHECK ACC # #
-        # prompt_embeds_pytorch = self.text_encoder.model(text_input_ids, output_hidden_states=False)
-        # pt_pooled_embed = prompt_embeds_pytorch["pooler_output"].detach().numpy()
-        # mad = np.max(np.abs(pt_pooled_embed - aic_text_encoder_emb))
-        # print(f">>>>>>>>>>>> CLIP text encoder pooled embed MAD: ", mad)  ## 0.0043082903 ##TODO : Clean up
-        ### END CHECK ACC ###
+        import time
 
-        # Use pooled output of CLIPTextModel
+        start_time = time.time()
+        aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+        end_time = time.time()
+        print(f"CLIP Text encoder inference time: {end_time - start_time:.4f} seconds")
         prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
-        # prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
@@ -491,23 +498,6 @@ def __call__(
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
             images.
-
-        Examples:
-            ```python
-            # Basic text-to-image generation
-            from QEfficient import QEFFFluxPipeline
-            pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
-            pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
-
-            generator = torch.manual_seed(42)
-            # NOTE: guidance_scale <=1 is not supported
-            image = pipeline("A cat holding a sign that says hello world",
-                             guidance_scale=0.0,
-                             num_inference_steps=4,
-                             max_sequence_length=256,
-                             generator=generator).images[0]
-            image.save("flux-schnell_aic.png")
-            ```
         """
         device = "cpu"
 
@@ -663,7 +653,7 @@ def __call__(
                 start_time = time.time()
                 outputs = self.transformer.qpc_session.run(inputs_aic)
                 end_time = time.time()
-                print(f"Time : {end_time - start_time:.2f} seconds")
+                print(f"Transformer inference time: {end_time - start_time:.2f} seconds")
 
                 noise_pred = torch.from_numpy(outputs["output"])
 
@@ -711,8 +701,10 @@ def __call__(
         self.vae_decode.qpc_session.set_buffers(output_buffer)
 
         inputs = {"latent_sample": latents.numpy()}
+        start_time = time.time()
         image = self.vae_decode.qpc_session.run(inputs)
-
+        end_time = time.time()
+        print(f"VAE decoder inference time: {end_time - start_time:.4f} seconds")
         image_tensor = torch.from_numpy(image["sample"])
         image = self.image_processor.postprocess(image_tensor, output_type=output_type)
 
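For reference, end-to-end usage of the pipeline looks roughly like the example this commit removes from the `__call__` docstring; the sketch below mirrors those removed lines and assumes the same `compile` keyword arguments are still accepted:

```python
import torch

from QEfficient import QEFFFluxPipeline

# Basic text-to-image generation on AI 100, mirroring the docstring example removed above.
pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)

generator = torch.manual_seed(42)
image = pipeline(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=generator,
).images[0]
image.save("flux-schnell_aic.png")
```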