
Commit 3e56fb9

refactor: configure seeding and tests
1 parent 6e41d48 commit 3e56fb9

File tree

4 files changed: +52 −59 lines changed

pyproject.toml
src/pruna/engine/handler/handler_diffuser.py
src/pruna/engine/handler/handler_inference.py
tests/engine/test_handler.py


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ possibly-missing-attribute = "ignore" # mypy is more permissive with attribute a
 possibly-missing-import = "ignore" # mypy is more permissive with imports
 no-matching-overload = "ignore" # mypy is more permissive with overloads
 unresolved-reference = "ignore" # mypy is more permissive with references
+missing-argument = "ignore"

 [tool.coverage.run]
 source = ["src/pruna"]

src/pruna/engine/handler/handler_diffuser.py

Lines changed: 13 additions & 21 deletions
@@ -19,7 +19,7 @@

 import torch

-from pruna.engine.handler.handler_inference import InferenceHandler, validate_seed_strategy
+from pruna.engine.handler.handler_inference import InferenceHandler
 from pruna.logging.logger import pruna_logger


@@ -93,35 +93,27 @@ def process_output(self, output: Any) -> torch.Tensor:
         # Maybe the user is calling the pipeline with return_dict = False,
         # which then returns the generated image / video in a tuple
         generated = output[0]
-        return generated
+        return generated.float()

     def log_model_info(self) -> None:
         """Log information about the inference handler."""
         pruna_logger.info(
             "Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n"
+            "Inference outputs are expected to either have an `images` attribute or a `frames` attribute, "
+            "or be a tuple with the generated image / video as the first element."
         )

-    def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
+    def set_seed(self, seed: int) -> None:
         """
-        Set the random seed according to the chosen strategy.
-
-        - If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each
-          sample. This ensures reproducibility while still producing variation across samples,
-          making it the preferred option for benchmarking.
-        - If `seed_strategy="no_seed"`, no seed is set internally.
-          The user is responsible for managing seeds if reproducibility is required.
+        Set the random seed for the current process.

         Parameters
         ----------
-        seed_strategy : Literal["per_sample", "no_seed"]
-            The seeding strategy to apply.
-        global_seed : int | None
-            The base seed value to use (if applicable).
+        seed : int
+            The seed to set.
         """
-        self.seed_strategy = seed_strategy
-        validate_seed_strategy(seed_strategy, global_seed)
-        if global_seed is not None:
-            self.global_seed = global_seed
-            self.model_args["generator"] = torch.Generator("cpu").manual_seed(global_seed)
-        else:
-            self.model_args["generator"] = None  # Remove the seed.
+        self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed)
+
+    def remove_seed(self) -> None:
+        """Remove the seed from the current process."""
+        self.model_args["generator"] = None
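
To make the change above concrete, here is a minimal, self-contained sketch of the generator-based seeding that the new `set_seed` / `remove_seed` implement. The class below is a toy stand-in, not the library's `DiffuserHandler`; it only assumes what the diff shows, namely that the handler keeps pipeline kwargs in a `model_args` dict and passes `model_args["generator"]` to the pipeline call.

import torch

# Toy stand-in for the handler's generator-based seeding (illustrative only).
class _ToySeedableHandler:
    def __init__(self) -> None:
        self.model_args: dict = {}

    def set_seed(self, seed: int) -> None:
        # A fixed CPU generator makes the pipeline's latent sampling reproducible.
        self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed)

    def remove_seed(self) -> None:
        # Without a generator, sampling falls back to the global (nondeterministic) RNG.
        self.model_args["generator"] = None


handler = _ToySeedableHandler()
handler.set_seed(42)
a = torch.randn(4, generator=handler.model_args["generator"])
handler.set_seed(42)
b = torch.randn(4, generator=handler.model_args["generator"])
assert torch.equal(a, b)  # same seed, same draws

The upside of this approach over process-wide seeding is that determinism stays local to the pipeline call instead of touching every RNG in the process.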

src/pruna/engine/handler/handler_inference.py

Lines changed: 27 additions & 29 deletions
@@ -122,38 +122,36 @@ def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global
         validate_seed_strategy(seed_strategy, global_seed)
         if global_seed is not None:
             self.global_seed = global_seed
-            set_seed(global_seed)
+            self.set_seed(global_seed)
         else:
-            remove_seed()
+            self.remove_seed()

+    def set_seed(self, seed: int) -> None:
+        """
+        Set the random seed for the current process.

-def set_seed(seed: int) -> None:
-    """
-    Set the random seed for the current process.
-
-    Parameters
-    ----------
-    seed : int
-        The seed to set.
-    """
-    # With the default handler, we can't assume anything about the model,
-    # so we are setting the seed for all RNGs available.
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-
-def remove_seed() -> None:
-    """Remove the seed from the current process."""
-    random.seed(None)
-    np.random.seed(None)
-    # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed().
-    # torch.seed() creates a non-deterministic random number.
-    torch.manual_seed(torch.seed())
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(torch.seed())
+        Parameters
+        ----------
+        seed : int
+            The seed to set.
+        """
+        # With the default handler, we can't assume anything about the model,
+        # so we are setting the seed for all RNGs available.
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+
+    def remove_seed(self) -> None:
+        """Remove the seed from the current process."""
+        random.seed(None)
+        np.random.seed(None)
+        # We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed().
+        # torch.seed() creates a non-deterministic random number.
+        torch.manual_seed(torch.seed())
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(torch.seed())


 def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
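
Because `set_seed` and `remove_seed` are now methods on the handler, `configure_seed` stays generic while subclasses such as `DiffuserHandler` can override how a seed is applied. The default implementation above seeds every process-wide RNG. As a quick standalone sanity check of that behaviour (a sketch, not the library's code; CUDA seeding is omitted to keep it CPU-only):

import random

import numpy as np
import torch

def _seed_everything(seed: int) -> None:
    # Same RNGs the default handler seeds above.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


_seed_everything(0)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
_seed_everything(0)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
assert first == second  # all three RNG streams replay identically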

tests/engine/test_handler.py

Lines changed: 11 additions & 9 deletions
@@ -1,11 +1,11 @@
-import types
+import numpy as np
 import pytest
 import torch
 from pruna.engine.handler.handler_inference import (
-    set_seed,
     validate_seed_strategy,
 )
 from pruna.engine.handler.handler_diffuser import DiffuserHandler
+from pruna.engine.handler.handler_standard import StandardHandler
 from pruna.engine.pruna_model import PrunaModel
 from pruna.engine.utils import move_to_device
 # Default handler tests, mainly for checking seeding.

@@ -25,14 +25,16 @@ def test_validate_seed_strategy_invalid(strategy, seed):
     with pytest.raises(ValueError):
         validate_seed_strategy(strategy, seed)

-
 def test_set_seed_reproducibility():
-    '''Test to see set_seed is reproducible'''
-    set_seed(42)
-    a = torch.randn(3)
-    set_seed(42)
-    b = torch.randn(3)
-    assert torch.equal(a, b)
+    inference_handler = StandardHandler()
+    inference_handler.set_seed(42)
+    torch_random_tensor = torch.randn(3)
+    numpy_random_tensor = np.random.randn(3)
+    inference_handler.set_seed(42)
+    torch_expected = torch.randn(3)
+    numpy_expected = np.random.randn(3)
+    assert torch.equal(torch_random_tensor, torch_expected)
+    assert np.array_equal(numpy_random_tensor, numpy_expected)


 # Diffuser handler tests, checking output processing and seeding.
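
The updated test checks reproducibility through `StandardHandler.set_seed` for both the torch and NumPy streams. A possible companion check, hypothetical and not part of this commit, would assert that `remove_seed` restores nondeterminism, since the handler reseeds torch with `torch.seed()`:

import torch

from pruna.engine.handler.handler_standard import StandardHandler

def test_remove_seed_restores_nondeterminism():
    # Hypothetical companion test: after remove_seed(), replaying from the same
    # starting seed should (almost surely) produce different draws.
    handler = StandardHandler()
    handler.set_seed(42)
    first = torch.randn(1000)
    handler.set_seed(42)
    handler.remove_seed()
    second = torch.randn(1000)
    assert not torch.equal(first, second)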
