Skip to content

Commit 7df7eb7

Browse files
committed
revert changes done by linting.
1 parent 01c6982 commit 7df7eb7

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

docs/tutorials/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ These tutorials will guide you through the process of using |pruna| to optimize
7373
.. grid-item-card:: Smashing at Finer Granularity with Target Modules
7474
:text-align: center
7575
:link: ./target_modules_quanto.ipynb
76-
76+
7777
Learn how to use the ``target_modules`` parameter to target specific modules in your model.
7878

7979
.. toctree::
@@ -82,4 +82,4 @@ These tutorials will guide you through the process of using |pruna| to optimize
8282
:caption: Pruna
8383
:glob:
8484

85-
./*
85+
./*

docs/utils/gen_docs.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def generate_algorithm_desc(obj: PrunaAlgorithmBase, name_suffix: str = "") -> s
4545
f"| **Can be applied on**: {compatible_devices_str}.",
4646
f"| **Required**: {required_inputs_str}.",
4747
f"| **Compatible with**: {compatible_algorithms_str}.",
48-
f"| **Required install**: {required_install_str}." if required_install_str else "",
48+
f"| **Required install**: {required_install_str}."
49+
if required_install_str
50+
else "",
4951
]
5052
)
5153

@@ -81,12 +83,20 @@ def format_grid_table(rows: list[list[str]]) -> str:
8183
total_widths = [w + 2 for w in col_widths]
8284

8385
horizontal_border = "+" + "+".join("-" * width for width in total_widths) + "+"
84-
header_line = "|" + "|".join(" " + rows[0][i].ljust(col_widths[i]) + " " for i in range(num_cols)) + "|"
86+
header_line = (
87+
"|"
88+
+ "|".join(" " + rows[0][i].ljust(col_widths[i]) + " " for i in range(num_cols))
89+
+ "|"
90+
)
8591
header_separator = "+" + "+".join("=" * width for width in total_widths) + "+"
8692

8793
data_lines = []
8894
for row in rows[1:]:
89-
row_line = "|" + "|".join(" " + row[i].ljust(col_widths[i]) + " " for i in range(num_cols)) + "|"
95+
row_line = (
96+
"|"
97+
+ "|".join(" " + row[i].ljust(col_widths[i]) + " " for i in range(num_cols))
98+
+ "|"
99+
)
90100
data_lines.append(row_line)
91101
data_lines.append(horizontal_border)
92102

@@ -212,4 +222,4 @@ def get_table_rows(obj: PrunaAlgorithmBase) -> tuple[list[list[str]], int]:
212222
for algorithm_group in PRUNA_ALGORITHMS.values():
213223
for algorithm in algorithm_group.values():
214224
f.write(generate_algorithm_desc(algorithm))
215-
f.write("\n\n")
225+
f.write("\n\n")

src/pruna/evaluation/metrics/metric_elapsed_time.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,9 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] |
198198
# Measurement
199199
list_elapsed_times = []
200200
with tqdm(total=self.n_iterations, desc="Measuring inference time", unit="iter") as pbar:
201-
202201
def measure_with_progress(m, x):
203202
list_elapsed_times.append(self._time_inference(m, x))
204203
pbar.update(1)
205-
206204
self._measure(model, dataloader, self.n_iterations, measure_with_progress)
207205

208206
total_elapsed_time = sum(list_elapsed_times)
@@ -349,4 +347,4 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> MetricResult:
349347
# Use EvaluationAgent to share computation across time metrics.
350348
raw_results = super().compute(model, dataloader)
351349
result = cast(Dict[str, Any], raw_results)[self.metric_name]
352-
return MetricResult(self.metric_name, self.__dict__.copy(), result)
350+
return MetricResult(self.metric_name, self.__dict__.copy(), result)

src/pruna/evaluation/metrics/metric_torch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None:
124124

125125

126126
def ssim_update(
127-
metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any
127+
metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure,
128+
preds: Any,
129+
target: Any
128130
) -> None:
129131
"""
130132
Update handler for SSIM or MS-SSIM metric.
@@ -400,4 +402,4 @@ def get_call_type(call_type: str, metric_name: str) -> str:
400402
return get_single_pairing(TorchMetrics[metric_name].call_type)
401403
else:
402404
pruna_logger.error(f"Invalid call type: {call_type}. Must be one of {CALL_TYPES}.")
403-
raise ValueError(f"Invalid call type: {call_type}. Must be one of {CALL_TYPES}.")
405+
raise ValueError(f"Invalid call type: {call_type}. Must be one of {CALL_TYPES}.")

0 commit comments

Comments
 (0)