
Commit 2a761ec

Author: Amit Raj (committed)
Modification of Pipeline-2
Signed-off-by: Amit Raj <[email protected]>
1 parent 5f6d8bd commit 2a761ec

File tree: 8 files changed, +271 −201 lines


QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 2 additions & 1 deletion
@@ -35,6 +35,7 @@
     QEffFluxAttnProcessor,
     QEffFluxSingleTransformerBlock,
     QEffFluxTransformer2DModel,
+    QEffFluxTransformer2DModelOF,
     QEffFluxTransformerBlock,
 )
 
@@ -83,7 +84,7 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
 
 
 class OnnxFunctionTransform(ModuleMappingTransform):
-    _module_mapping = {FluxTransformer2DModel: QEffFluxTransformer2DModel}
+    _module_mapping = {QEffFluxTransformer2DModel, QEffFluxTransformer2DModelOF}
 
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
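
For context, a transform declared this way is typically applied by walking the model and re-pointing matching modules to their QEff counterparts. Below is a minimal sketch of that pattern, assuming a dict-style `_module_mapping` of `{SourceModule: ReplacementModule}` as in the pre-change line above; `ModuleSwapSketch` and its `apply` body are illustrative, not the repository's actual `ModuleMappingTransform`.

```python
# Hedged sketch of a class-swapping transform. Assumes a dict-style
# _module_mapping of {SourceModule: ReplacementModule}; everything except
# nn.Module is illustrative and not the repository's implementation.
from typing import Tuple

import torch.nn as nn


class ModuleSwapSketch:
    _module_mapping = {}  # e.g. {FluxTransformer2DModel: QEffFluxTransformer2DModel}

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        for module in model.modules():
            replacement = cls._module_mapping.get(type(module))
            if replacement is not None:
                # Re-point the instance's class so the replacement's methods
                # (e.g. a patched forward) take effect without copying weights.
                module.__class__ = replacement
                transformed = True
        return model, transformed
```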

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 28 additions & 53 deletions
@@ -22,7 +22,6 @@
 )
 
 from QEfficient.diffusers.models.normalization import (
-    QEffAdaLayerNormContinuous,
     QEffAdaLayerNormZero,
     QEffAdaLayerNormZeroSingle,
 )

@@ -253,58 +252,6 @@ def forward(
 
 
 class QEffFluxTransformer2DModel(FluxTransformer2DModel):
-    def __init__(
-        self,
-        patch_size: int = 1,
-        in_channels: int = 64,
-        out_channels: Optional[int] = None,
-        num_layers: int = 19,
-        num_single_layers: int = 38,
-        attention_head_dim: int = 128,
-        num_attention_heads: int = 24,
-        joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
-        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
-    ):
-        super().__init__(
-            patch_size=patch_size,
-            in_channels=in_channels,
-            out_channels=out_channels,
-            num_layers=num_layers,
-            num_single_layers=num_single_layers,
-            attention_head_dim=attention_head_dim,
-            num_attention_heads=num_attention_heads,
-            joint_attention_dim=joint_attention_dim,
-            pooled_projection_dim=pooled_projection_dim,
-            guidance_embeds=guidance_embeds,
-            axes_dims_rope=axes_dims_rope,
-        )
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                QEffFluxTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                QEffFluxSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for _ in range(num_single_layers)
-            ]
-        )
-
-        self.norm_out = QEffAdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -448,3 +395,31 @@ def forward(
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
+
+
+class QEffFluxTransformer2DModelOF(QEffFluxTransformer2DModel):
+    def __qeff_init__(self):
+        self.transformer_blocks = nn.ModuleList()
+        self._block_classes = set()
+
+        for _ in range(self.config.num_layers):
+            BlockClass = QEffFluxTransformerBlock
+            block = BlockClass(
+                dim=self.inner_dim,
+                num_attention_heads=self.config.num_attention_heads,
+                attention_head_dim=self.config.attention_head_dim,
+            )
+            self.transformer_blocks.append(block)
+            self._block_classes.add(BlockClass)
+
+        self.single_transformer_blocks = nn.ModuleList()
+
+        for _ in range(self.config.num_single_layers):
+            SingleBlockClass = QEffFluxSingleTransformerBlock
+            single_block = SingleBlockClass(
+                dim=self.inner_dim,
+                num_attention_heads=self.config.num_attention_heads,
+                attention_head_dim=self.config.attention_head_dim,
+            )
+            self.single_transformer_blocks.append(single_block)
+            self._block_classes.add(SingleBlockClass)
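
The `OF` variant rebuilds its block lists from `self.config` rather than from constructor arguments, which matters when the class is swapped onto an already-constructed model. Below is a toy sketch of that flow, assuming `__qeff_init__` is invoked after the class swap (the actual call site is not shown in this diff); `Base` and `Variant` are hypothetical stand-ins, not repository classes.

```python
# Hedged sketch: how a swapped-in class with a __qeff_init__ hook could be
# activated on an already-built model. The call site of __qeff_init__ in the
# repository is an assumption; Base/Variant are illustrative only.
import torch.nn as nn


class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])


class Variant(Base):
    def __qeff_init__(self):
        # Rebuild submodules from scratch, mirroring how
        # QEffFluxTransformer2DModelOF repopulates its block lists
        # from self.config instead of constructor arguments.
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


model = Base()
model.__class__ = Variant  # class swap, as a mapping transform might do
model.__qeff_init__()      # re-initialize the block lists
print(len(model.blocks))   # 3
```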

Lines changed: 2 additions & 4 deletions

@@ -1,5 +1,5 @@
 {
-    "description": "Example compilation configuration for Flux pipeline",
+    "description": "Default configuration for Flux pipeline",
     "model_type": "flux",
 
     "modules":

@@ -52,9 +52,7 @@
     {
         "batch_size": 1,
         "seq_len": 256,
-        "steps": 1,
-        "num_layers": 1,
-        "num_single_layers": 1
+        "steps": 1
     },
     "compilation":
     {
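
For orientation, this configuration is plain JSON and can be inspected directly; a small sketch that reads the two top-level keys visible in these hunks follows (the file path and any structure beyond those keys are assumptions).

```python
# Hedged sketch: reading a Flux compile config of the shape shown above.
# Only "description" and "model_type" are taken from the diff; the file
# path is an assumption.
import json

with open("flux_config.json") as f:
    config = json.load(f)

print(config["description"])  # "Default configuration for Flux pipeline"
print(config["model_type"])   # "flux"
```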

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 31 additions & 39 deletions
@@ -19,7 +19,6 @@
 
 from QEfficient.diffusers.pipelines.config_manager import config_manager, set_module_device_ids
 from QEfficient.diffusers.pipelines.pipeline_utils import (
-    QEffClipTextEncoder,
     QEffFluxTransformerModel,
     QEffTextEncoder,
     QEffVAE,

@@ -38,11 +37,13 @@ class QEFFFluxPipeline(FluxPipeline):
     """
 
     def __init__(self, model, use_onnx_function, *args, **kwargs):
-        self.text_encoder = QEffClipTextEncoder(model.text_encoder)
+        self.text_encoder = QEffTextEncoder(model.text_encoder)
         self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
         self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
         self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function
+
+        # Add all modules of FluxPipeline
         self.has_module = [
             ("text_encoder", self.text_encoder),
             ("text_encoder_2", self.text_encoder_2),

@@ -78,6 +79,10 @@ def __init__(self, model, use_onnx_function, *args, **kwargs):
         self.latent_width = self.width // self.vae_scale_factor
         self.cl = (self.latent_height * self.latent_width) // 4
 
+        self.text_encoder_2.model.config.max_position_embeddings = (
+            self.text_encoder.model.config.max_position_embeddings
+        )
+
     @classmethod
     def from_pretrained(
         cls,

@@ -140,6 +145,15 @@ def export(self, export_dir: Optional[str] = None) -> str:
             export_kwargs=export_kwargs,
         )
 
+    def get_default_config_path():
+        """
+        Returns the default configuration file path for Flux pipeline.
+
+        Returns:
+            str: Path to the default flux_config.json file.
+        """
+        return os.path.join(os.path.dirname(__file__), "flux_config.json")
+
     def compile(
         self,
         compile_config: Optional[str] = None,

@@ -193,7 +207,6 @@ def _get_t5_prompt_embeds(
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 512,
         device_ids: Optional[List[int]] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         """
         Get T5 prompt embeddings for the given prompt(s).

@@ -203,7 +216,6 @@ def _get_t5_prompt_embeds(
             num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt.
             max_sequence_length (int, defaults to 256): Maximum sequence length for tokenization.
             device ids (Optional[torch.device], optional): The device to place tensors on QAIC device ids.
-            dtype (Optional[torch.dtype], optional): The data type for tensors.
 
         Returns:
             torch.Tensor: The T5 prompt embeddings with shape (batch_size * num_images_per_prompt, seq_len, hidden_size).

@@ -245,12 +257,12 @@ def _get_t5_prompt_embeds(
         self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)
 
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+        import time
 
-        # # # AIC Testing
-        # prompt_embeds_pytorch = self.text_encoder_2.model(text_input_ids, output_hidden_states=False)
-        # mad = torch.abs(prompt_embeds_pytorch["last_hidden_state"] - prompt_embeds).mean()
-        # print(">>>>>>>>>>>> MAD for text-encoder-2 - T5 => Pytorch vs AI 100:", mad)
+        start_time = time.time()
+        prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
+        end_time = time.time()
+        print(f"T5 Text encoder inference time: {end_time - start_time:.4f} seconds")
 
         _, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method

@@ -303,26 +315,21 @@ def _get_clip_prompt_embeds(
         self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder.qpc_path), device_ids=device_ids)
 
         text_encoder_output = {
-            "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32),
             "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.int32),
+            "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32),
         }
 
         self.text_encoder.qpc_session.set_buffers(text_encoder_output)
 
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
-        # aic_text_encoder_emb = aic_embeddings["pooler_output"]
 
-        # # # # [TEMP] CHECK ACC # #
-        # prompt_embeds_pytorch = self.text_encoder.model(text_input_ids, output_hidden_states=False)
-        # pt_pooled_embed = prompt_embeds_pytorch["pooler_output"].detach().numpy()
-        # mad = np.max(np.abs(pt_pooled_embed - aic_text_encoder_emb))
-        # print(f">>>>>>>>>>>> CLIP text encoder pooled embed MAD: ", mad) ## 0.0043082903 ##TODO : Clean up
-        ### END CHECK ACC ###
+        import time
 
-        # Use pooled output of CLIPTextModel
+        start_time = time.time()
+        aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
+        end_time = time.time()
+        print(f"CLIP Text encoder inference time: {end_time - start_time:.4f} seconds")
         prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])
-        # prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)

@@ -491,23 +498,6 @@ def __call__(
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
             images.
-
-        Examples:
-            ```python
-            # Basic text-to-image generation
-            from QEfficient import QEFFFluxPipeline
-            pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
-            pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)
-
-            generator = torch.manual_seed(42)
-            # NOTE: guidance_scale <=1 is not supported
-            image = pipeline("A cat holding a sign that says hello world",
-                             guidance_scale=0.0,
-                             num_inference_steps=4,
-                             max_sequence_length=256,
-                             generator=generator).images[0]
-            image.save("flux-schnell_aic.png")
-            ```
         """
         device = "cpu"
 

@@ -663,7 +653,7 @@ def __call__(
                 start_time = time.time()
                 outputs = self.transformer.qpc_session.run(inputs_aic)
                 end_time = time.time()
-                print(f"Time : {end_time - start_time:.2f} seconds")
+                print(f"Transformers inference time : {end_time - start_time:.2f} seconds")
 
                 noise_pred = torch.from_numpy(outputs["output"])
 

@@ -711,8 +701,10 @@ def __call__(
         self.vae_decode.qpc_session.set_buffers(output_buffer)
 
         inputs = {"latent_sample": latents.numpy()}
+        start_time = time.time()
         image = self.vae_decode.qpc_session.run(inputs)
-
+        end_time = time.time()
+        print(f"Decoder Text encoder inference time: {end_time - start_time:.4f} seconds")
         image_tensor = torch.from_numpy(image["sample"])
         image = self.image_processor.postprocess(image_tensor, output_type=output_type)
 
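
The usage example removed from the `__call__` docstring above remains the quickest picture of the public API this file exposes; it is reproduced here as a standalone sketch, with every identifier, model ID, and argument taken from the removed lines (only the `import torch` line is added so the snippet is self-contained).

```python
# Usage sketch reconstructed from the docstring example removed in this commit;
# argument names and values come from those removed lines, not from new code.
import torch

from QEfficient import QEFFFluxPipeline

pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1)

generator = torch.manual_seed(42)
# NOTE: guidance_scale <=1 is not supported
image = pipeline(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=generator,
).images[0]
image.save("flux-schnell_aic.png")
```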
