
Conversation

Member

@begumcig begumcig commented Oct 6, 2025

Description

This PR introduces several updates to the inference handlers:

DiffusersHandler Enhancements:

  • Added support for video models.
  • Added support for return_dict=False, which skips wrapping outputs in a pipeline output object and instead returns a tuple directly (see the sketch below).
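
For illustration, this is roughly what return_dict=False changes in a diffusers pipeline call (a sketch; the prompt is a placeholder and the tiny model is borrowed from the example at the end of this description):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("katuni4ka/tiny-random-flux")

# Default: results come wrapped in a pipeline output object with named fields.
result = pipe("photo of a cute prune.")
image = result.images[0]

# With return_dict=False, the same call returns a plain tuple instead.
result = pipe("photo of a cute prune.", return_dict=False)
image = result[0][0]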

Seeding Support Across Handlers:

  • All inference handlers now support optional seeding.

For the DiffusersHandler, seeding is crucial since it directly affects the generation process.
For other models, inference is usually deterministic, but seeding is available for reproducibility in cases where randomness is involved.

Users can explicitly set or unset the seed; by default, no seed is applied, so nothing changes unless the user opts in.
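
For context, seeding here boils down to the generator driving the diffusion process: with the same seed, the pipeline starts from the same initial latents and produces the same output. A sketch of that underlying mechanism with a raw diffusers pipeline (not of the handler's internals):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("katuni4ka/tiny-random-flux")

# Identically seeded generators give identical initial latents,
# so these two runs should produce the same image.
image_a = pipe("photo of a cute prune.", generator=torch.Generator().manual_seed(42)).images[0]
image_b = pipe("photo of a cute prune.", generator=torch.Generator().manual_seed(42)).images[0]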

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Unit tests were added in tests/test_handler.py to validate the new functionality:

  • Verified that seeding and unseeding work as expected for both standard inference handlers and the DiffusersHandler.

  • Added tests for image and video model outputs using lightweight/tiny models to ensure correctness and efficiency (see the sketch below).
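
A minimal sketch of the kind of reproducibility test described above (the exact assertion and output shape are assumptions; the configure_seed call follows the example at the end of this description):

import torch
from diffusers import FluxPipeline
from pruna import PrunaModel

def test_seeded_inference_is_reproducible():
    pipe = FluxPipeline.from_pretrained("katuni4ka/tiny-random-flux")
    model = PrunaModel(pipe)
    model.inference_handler.configure_seed("per_sample", global_seed=42)
    # With seeding configured, two identical calls should match exactly
    # (exact comparison assumes tensor outputs, i.e. output_type="pt").
    out_a = model.run_inference("photo of a cute prune.")
    out_b = model.run_inference("photo of a cute prune.")
    assert torch.equal(out_a[0], out_b[0])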

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Now you can configure the seed like this:

from diffusers import FluxPipeline
from pruna import PrunaModel

input_text = "photo of a cute prune."
pipe = FluxPipeline.from_pretrained("katuni4ka/tiny-random-flux")
pruna_model = PrunaModel(pipe)
pruna_model.inference_handler.configure_seed("per_sample", global_seed=42)
output_image = pruna_model.run_inference(input_text)

@begumcig begumcig changed the title Feat/video models and seeding for inference Feat: video models and seeding for inference Oct 6, 2025
@begumcig begumcig changed the title Feat: video models and seeding for inference feat: video models and seeding for inference Oct 6, 2025

@begumcig begumcig force-pushed the feat/video-models-and-seeding-for-inference branch from bb515c8 to 914b790 on October 14, 2025 12:12

@begumcig begumcig force-pushed the feat/video-models-and-seeding-for-inference branch from 914b790 to e73912e on October 14, 2025 12:17

@begumcig begumcig force-pushed the feat/video-models-and-seeding-for-inference branch from 1a7caeb to 7a79bd6 on October 14, 2025 13:33
Member

@davidberenstein1957 davidberenstein1957 left a comment


Some small comments, but feel free to merge after resolving.

-        self.model_args = default_args
+        self.model_args = model_args if model_args else {}
+        # We want the default output type to be pytorch tensors.
+        self.model_args["output_type"] = "pt"


Is this intended to overwrite? Perhaps it would be better to check whether it is already present and only set it if it is not?
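
E.g. something like this (a sketch, reusing the names from the snippet above):

self.model_args = model_args if model_args else {}
# Only fall back to torch tensors if the caller didn't choose an output type.
self.model_args.setdefault("output_type", "pt")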

Member Author


Yes! In the previous implementation we already converted everything to tensors automatically using torchvision.transforms.ToTensor (line 87), so this is just me refactoring it to use the more "official" way.
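
Roughly the before/after (a sketch; pipe and prompt stand in for the real call site):

from torchvision import transforms

# Before: the pipeline returned PIL images, converted to tensors by hand.
image = transforms.ToTensor()(pipe(prompt).images[0])

# After: ask the pipeline for torch tensors directly.
image = pipe(prompt, output_type="pt").images[0]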

)
inference_function = getattr(self, inference_function_name)

self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args)


sassy 😂

Comment on lines +73 to +77
if output_attr != "none":
    pipe_output = getattr(pipe_output, output_attr)
pipe_output = pipe_output[0]


Why were we doing this again? This is only a problem when using them in tests like this, but this is not the behavior we expect users to deal with, correct?

Member Author


Exactly, we are doing all this here because the pipe is not a PrunaModel. Indeed, when we use PrunaModel and the Evaluation Agent we don't have to deal with this at all. This is the exact feature I am testing here :D

Member

@simlang simlang left a comment


only have minor comments. super clean 💅

"""
inputs = metric_data_processor(x, gt, outputs, self.call_type)
if inputs[1].dtype == torch.bfloat16:
inputs[1] = inputs[1].to(torch.float16)
Member


I would just cast it to f32, tbh.
Casting to f16 could actually change the value of the input, because bf16 and f16 trade off precision and range differently (a large bf16 value can overflow f16's range), while f32 can represent any bf16 value exactly.
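
To make that concrete (a small sketch):

import torch

# bf16 keeps f32's exponent range with fewer mantissa bits; f16 has more
# mantissa bits but a much narrower range. Casting bf16 -> f16 can therefore
# overflow, while bf16 -> f32 is always exact.
x = torch.tensor([1e38], dtype=torch.bfloat16)
print(x.to(torch.float16))  # inf: out of f16 range
print(x.to(torch.float32))  # ~1.0141e+38: exact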

Member Author


Actually moving this entire type-casting part to the inference handler! And yes, will use f32, tysm.

@begumcig begumcig force-pushed the feat/video-models-and-seeding-for-inference branch from deb8d45 to 3e56fb9 on October 28, 2025 14:27