Skip to content

Commit 7377e41

Browse files
committed
format code
1 parent db811d7 commit 7377e41

12 files changed

Lines changed: 51 additions & 31 deletions

File tree

src/accelerate/optimizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ def zero_grad(self, set_to_none=None):
122122
self.optimizer.zero_grad()
123123

124124
def step(self, closure=None):
125-
if not self.gradient_state.is_xla_gradients_synced and self.accelerator_state.distributed_type == DistributedType.XLA:
125+
if (
126+
not self.gradient_state.is_xla_gradients_synced
127+
and self.accelerator_state.distributed_type == DistributedType.XLA
128+
):
126129
gradients = xm._fetch_gradients(self.optimizer)
127130
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
128131
self.gradient_state.is_xla_gradients_synced = True

src/accelerate/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818

19-
from .state import AcceleratorState, DistributedType, GradientState
19+
from .state import AcceleratorState, GradientState
2020

2121

2222
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")

src/accelerate/state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,10 +990,10 @@ class GradientState:
990990
accumulation
991991
- **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader
992992
iteration and the number of total steps reset
993-
- **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized as
994-
False. Once gradients have been reduced before the optimizer step, this flag is set to True.
995-
Subsequently, after each step, the flag is reset to False. FSDP will always synchronize the gradients,
996-
hence is_xla_gradients_synced is always true.
993+
- **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized
994+
as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently,
995+
after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence
996+
is_xla_gradients_synced is always true.
997997
"""
998998

999999
_shared_state = SharedDict()

src/accelerate/test_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
require_multi_gpu,
1313
require_multi_xpu,
1414
require_non_cpu,
15-
require_no_torch_xla,
15+
require_non_torch_xla,
1616
require_single_device,
1717
require_single_gpu,
1818
require_single_xpu,

src/accelerate/test_utils/scripts/external_deps/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from accelerate import Accelerator, DistributedType
2929
from accelerate.data_loader import DataLoaderDispatcher
3030
from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
31-
from accelerate.utils import is_torch_xla_available, set_seed
31+
from accelerate.utils import set_seed
3232

3333

3434
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

src/accelerate/test_utils/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def require_tpu(test_case):
175175
return unittest.skipUnless(is_torch_xla_available(check_is_tpu=True), "test requires TPU")(test_case)
176176

177177

178-
def require_no_torch_xla(test_case):
178+
def require_non_torch_xla(test_case):
179179
"""
180180
Decorator marking a test as requiring an environment without TorchXLA. These tests are skipped when TorchXLA is
181181
available.

tests/fsdp/test_fsdp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
require_cuda,
3232
require_fsdp,
3333
require_multi_gpu,
34-
require_no_torch_xla,
34+
require_non_torch_xla,
3535
slow,
3636
)
3737
from accelerate.utils.constants import (
@@ -171,7 +171,7 @@ def test_cpu_offload(self):
171171

172172

173173
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
174-
@require_no_torch_xla
174+
@require_non_torch_xla
175175
@require_fsdp
176176
@require_multi_gpu
177177
@slow

tests/test_accelerator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from accelerate.accelerator import Accelerator
1313
from accelerate.state import GradientState, PartialState
1414
from accelerate.test_utils import require_bnb, require_multi_gpu, slow
15-
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_no_torch_xla
15+
from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
1616
from accelerate.utils import patch_environment
1717
from accelerate.utils.modeling import load_checkpoint_in_model
1818

@@ -63,7 +63,6 @@ def test_accelerator_can_be_reinstantiated(self):
6363
with self.assertRaises(ValueError):
6464
_ = Accelerator(cpu=True)
6565

66-
6766
def test_mutable_states(self):
6867
accelerator = Accelerator()
6968
state = GradientState()
@@ -105,7 +104,7 @@ def test_free_memory_dereferences_prepared_components(self):
105104
self.assertTrue(len(accelerator._schedulers) == 0)
106105
self.assertTrue(len(accelerator._dataloaders) == 0)
107106

108-
@require_no_torch_xla
107+
@require_non_torch_xla
109108
def test_env_var_device(self):
110109
"""Tests that setting the torch device with ACCELERATE_TORCH_DEVICE overrides default device."""
111110
PartialState._reset_state()
@@ -285,7 +284,7 @@ def test_is_accelerator_prepared(self):
285284
"Valid Dataloader is missing `_is_accelerator_prepared` or is set to `False`",
286285
)
287286

288-
@require_no_torch_xla
287+
@require_non_torch_xla
289288
@slow
290289
@require_bnb
291290
def test_accelerator_bnb(self):
@@ -302,7 +301,7 @@ def test_accelerator_bnb(self):
302301
# This should work
303302
model = accelerator.prepare(model)
304303

305-
@require_no_torch_xla
304+
@require_non_torch_xla
306305
@slow
307306
@require_bnb
308307
def test_accelerator_bnb_cpu_error(self):
@@ -328,7 +327,7 @@ def test_accelerator_bnb_cpu_error(self):
328327
with self.assertRaises(ValueError):
329328
model = accelerator.prepare(model)
330329

331-
@require_no_torch_xla
330+
@require_non_torch_xla
332331
@slow
333332
@require_bnb
334333
@require_multi_gpu
@@ -359,7 +358,7 @@ def test_accelerator_bnb_multi_gpu(self):
359358

360359
PartialState._reset_state()
361360

362-
@require_no_torch_xla
361+
@require_non_torch_xla
363362
@slow
364363
@require_bnb
365364
@require_multi_gpu

tests/test_big_modeling.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
load_checkpoint_and_dispatch,
3131
)
3232
from accelerate.hooks import remove_hook_from_submodules
33-
from accelerate.test_utils import require_bnb, require_cuda, require_mps, require_multi_gpu, require_no_torch_xla, slow
33+
from accelerate.test_utils import (
34+
require_bnb,
35+
require_cuda,
36+
require_mps,
37+
require_multi_gpu,
38+
require_non_torch_xla,
39+
slow,
40+
)
3441
from accelerate.utils import is_torch_version, offload_state_dict
3542

3643

@@ -708,7 +715,7 @@ def test_cpu_offload_with_hook(self):
708715
hook2.offload()
709716
self.assertEqual(model2.weight.device, torch.device("cpu"))
710717

711-
@require_no_torch_xla
718+
@require_non_torch_xla
712719
@slow
713720
@require_bnb
714721
@require_multi_gpu
@@ -740,7 +747,7 @@ def test_dispatch_model_bnb(self):
740747
self.assertTrue(model.h[-1].self_attention.query_key_value.weight.dtype == torch.int8)
741748
self.assertTrue(model.h[-1].self_attention.query_key_value.weight.device.index == 1)
742749

743-
@require_no_torch_xla
750+
@require_non_torch_xla
744751
@slow
745752
@require_bnb
746753
def test_dispatch_model_int8_simple(self):
@@ -803,7 +810,7 @@ def test_dispatch_model_int8_simple(self):
803810
self.assertTrue(model.h[0].self_attention.query_key_value.weight.dtype == torch.int8)
804811
self.assertTrue(model.h[0].self_attention.query_key_value.weight.device.index == 0)
805812

806-
@require_no_torch_xla
813+
@require_non_torch_xla
807814
@slow
808815
@require_bnb
809816
def test_dipatch_model_fp4_simple(self):

tests/test_multigpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import accelerate
2222
from accelerate import Accelerator
2323
from accelerate.big_modeling import dispatch_model
24-
from accelerate.test_utils import assert_exception, execute_subprocess_async, require_multi_gpu, require_no_torch_xla
24+
from accelerate.test_utils import assert_exception, execute_subprocess_async, require_multi_gpu, require_non_torch_xla
2525
from accelerate.utils import patch_environment
2626

2727

@@ -55,7 +55,7 @@ def test_pad_across_processes(self):
5555
with patch_environment(omp_num_threads=1):
5656
execute_subprocess_async(cmd, env=os.environ.copy())
5757

58-
@require_no_torch_xla
58+
@require_non_torch_xla
5959
@require_multi_gpu
6060
def test_distributed_data_loop(self):
6161
"""

0 commit comments

Comments (0)