
Commit c940428

fix: fix the deprecated usage of tuple slicing, wandb's start_method, and transformers's dtype (#462)
1 parent: e0d09c3

File tree: 5 files changed (+33, -39 lines)


areal/engine/base_hf_engine.py

Lines changed: 29 additions & 34 deletions
@@ -154,7 +154,7 @@ def create_device_model(self):
             model = AutoModelForImageTextToText.from_pretrained(
                 pretrained_model_name_or_path=self.config.path,
                 trust_remote_code=True,
-                torch_dtype=dtype,
+                dtype=dtype,
                 attn_implementation=self.config.attn_impl,
             )
             if self.config.disable_dropout:
@@ -179,40 +179,35 @@ def create_device_model(self):

     def _create_llm_actor_or_critic(self):
         dtype = getattr(torch, self.config.dtype)
-        if not self.config.is_critic:
-            if self.config.init_from_scratch:
-                # initialize model from config
-                # NOTE: VLM cannot directly load state dict using this
-                # random initialized model, so otherwise we call
-                # from_pretrained rather than loading weights into this random model.
-                model = AutoModelForCausalLM.from_config(
-                    self.model_config,
-                    torch_dtype=dtype,
-                    attn_implementation=self.config.attn_impl,
-                )
-            else:
-                model = AutoModelForCausalLM.from_pretrained(
-                    pretrained_model_name_or_path=self.config.path,
-                    trust_remote_code=True,
-                    torch_dtype=dtype,
-                    attn_implementation=self.config.attn_impl,
-                )
+
+        if self.config.is_critic:
+            model_class = AutoModelForTokenClassification
+            model_kwargs = {"num_labels": 1}
         else:
-            if self.config.init_from_scratch:
-                model = AutoModelForTokenClassification.from_config(
-                    self.model_config,
-                    torch_dtype=dtype,
-                    num_labels=1,
-                    attn_implementation=self.config.attn_impl,
-                )
-            else:
-                model = AutoModelForTokenClassification.from_pretrained(
-                    pretrained_model_name_or_path=self.config.path,
-                    trust_remote_code=True,
-                    torch_dtype=dtype,
-                    num_labels=1,
-                    attn_implementation=self.config.attn_impl,
-                )
+            model_class = AutoModelForCausalLM
+            model_kwargs = {}
+
+        common_kwargs = {
+            "dtype": dtype,
+            "attn_implementation": self.config.attn_impl,
+        }
+        model_kwargs.update(common_kwargs)
+
+        if self.config.init_from_scratch:
+            # initialize model from config
+            # NOTE: VLM cannot directly load state dict using this
+            # random initialized model, so otherwise we call
+            # from_pretrained rather than loading weights into this random model.
+            model = model_class.from_config(
+                self.model_config,
+                **model_kwargs,
+            )
+        else:
+            model = model_class.from_pretrained(
+                pretrained_model_name_or_path=self.config.path,
+                trust_remote_code=True,
+                **model_kwargs,
+            )
         return model

     def destroy(self):
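
Context for this file's change: recent transformers releases renamed the torch_dtype keyword of from_pretrained/from_config to dtype and deprecated the old spelling, and the four near-duplicate call sites collapse into a single model_class/model_kwargs dispatch. A minimal sketch of the renamed keyword, assuming a transformers version new enough to accept dtype; the checkpoint name and attention implementation below are placeholders, not values from this repo:

    import torch
    from transformers import AutoModelForCausalLM

    # "dtype" replaces the deprecated "torch_dtype" keyword; both control the
    # precision the checkpoint weights are loaded in.
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path="gpt2",  # placeholder for self.config.path
        trust_remote_code=True,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",  # placeholder for self.config.attn_impl
    )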

areal/utils/stats_logger.py

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ def init(self):
                 force=True,
                 id=f"{self.config.experiment_name}_{self.config.trial_name}_{suffix}",
                 resume="allow",
-                settings=wandb.Settings(start_method="fork"),
             )

         swanlab_config = self.config.swanlab
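
Context for this file's change: recent wandb releases deprecated the start_method setting (run process management moved into wandb's service layer), so passing wandb.Settings(start_method="fork") now hits a deprecation path; the commit simply drops the argument and relies on the default. A sketch of the surviving call, with placeholder values standing in for the config-derived fields:

    import wandb

    # settings=wandb.Settings(start_method="fork") is omitted; the remaining
    # keywords are unchanged from the original call.
    run = wandb.init(
        project="areal-demo",          # placeholder
        force=True,
        id="experiment_trial_suffix",  # placeholder for the formatted run id
        resume="allow",
    )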

areal/utils/ulysses.py

Lines changed: 2 additions & 2 deletions
@@ -100,7 +100,7 @@ def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
         return x
     slc = [slice(None)] * len(x.shape)
     slc[dim] = slice(0, -padding_size)
-    return x[slc]
+    return x[tuple(slc)]


 def slice_input_tensor(
@@ -118,7 +118,7 @@ def slice_input_tensor(
     parts = x.size(dim) // sp_world_size
     slc = [slice(None)] * len(x.shape)
     slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
-    return x[slc].contiguous()
+    return x[tuple(slc)].contiguous()


 def all_to_all_tensor(
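
Context for this file's change: PyTorch deprecated indexing a tensor with a plain Python list of slices; x[slc] with a list emits a UserWarning because list indices are reserved for advanced (fancy) indexing, while x[tuple(slc)] keeps the intended per-dimension basic slicing. A self-contained illustration of the pattern used by both functions above:

    import torch

    x = torch.arange(24).reshape(2, 3, 4)

    # Slice only dim=2 and leave the other dimensions untouched, mirroring
    # _unpad_tensor / slice_input_tensor.
    slc = [slice(None)] * len(x.shape)
    slc[2] = slice(0, 2)

    y = x[tuple(slc)]   # supported: a tuple of slices is basic indexing
    print(y.shape)      # torch.Size([2, 3, 2])
    # x[slc] (list form) is deprecated and warns before behaving the same way.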

evaluation/model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def load_hf_lm_and_tokenizer(
     # defaul load in float16
     model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
-        torch_dtype=torch.float16,
+        dtype=torch.float16,
         device_map=device_map,
         trust_remote_code=True,
         use_safetensors=use_safetensors,

examples/docs/debug/cmp_rollout.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
 model_id = "google/gemma-3-4b-it"

 model = Gemma3ForConditionalGeneration.from_pretrained(
-    model_id, torch_dtype=torch.bfloat16
+    pretrained_model_name_or_path=model_id, dtype=torch.bfloat16
 ).to("cuda")
 model.eval()
