Skip to content

Commit 7543d76

Browse files
committed
precommit
1 parent f97d002 commit 7543d76

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

python/nutpie/compile_pymc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def expand(_x, **shared):
576576
dims=dims,
577577
coords=coords,
578578
raw_logp_fn=orig_logp_fn,
579-
force_single_core=(gradient_backend == "mlx")
579+
force_single_core=(gradient_backend == "mlx"),
580580
)
581581

582582

python/nutpie/compiled_pyfunc.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,16 @@ def with_data(self, **updates):
4343
raise ValueError(f"Unknown data variable: {name}")
4444

4545
updated = self._shared_data.copy()
46-
46+
4747
# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
4848
if self._force_single_core:
4949
import mlx.core as mx
50+
5051
for name, value in updates.items():
5152
updated[name] = mx.array(value)
5253
else:
5354
updated.update(**updates)
54-
55+
5556
return dataclasses.replace(self, _shared_data=updated)
5657

5758
def with_transform_adapt(self, **kwargs):

0 commit comments

Comments
 (0)