Skip to content

Commit 6deb96c

Browse files
committed
Add retrain progress exporter and meeting RL config
1 parent 922d45d commit 6deb96c

File tree

8 files changed

+740
-13
lines changed

8 files changed

+740
-13
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Energy RL meeting run:
2+
# - reduced 64-turn horizon
3+
# - rollout fan-out 2
4+
# - 4096-token safety budget for rare OOD/verbose completions
5+
# Isolated log dir so this run cannot disturb canonical experiments.
6+
7+
[backend]
8+
backend = "tinker"
9+
10+
[model]
11+
model = "Qwen/Qwen3.5-4B"
12+
lora_rank = 128
13+
14+
[algorithm]
15+
advantage_mode = "reinforce_pp"
16+
transform_mode = "none"
17+
18+
[training]
19+
seed = 42
20+
max_steps = 200
21+
sft_warmup_steps = 0
22+
batch_size = 4
23+
group_size = 2
24+
max_tokens = 4096
25+
temperature = 0.7
26+
lr = 1e-5
27+
save_every = 20
28+
batch_advantage_norm = true
29+
adv_clip_max = 5.0
30+
31+
[environment]
32+
provider = "verifiers"
33+
id = "soma_energy"
34+
max_turns = 64
35+
36+
[environment.args]
37+
num_examples = 64
38+
dataset_seed = 7
39+
max_turns = 64
40+
http_url = "http://127.0.0.1:13737"
41+
42+
[environment.args.default_config]
43+
grid_max_import_kw = 10.0
44+
utility_grid_export_limit_kw = 3.0
45+
forecast_demand_base_mape = 0.02
46+
forecast_demand_mape_per_step = 0.01
47+
forecast_solar_base_mape = 0.05
48+
forecast_solar_mape_per_step = 0.025
49+
50+
[environment.args.evaluator]
51+
horizon_ticks = 96
52+
53+
[logging]
54+
log_dir = "logs/energy-rl-turn64-max4096-g2"

retrain/backend_definitions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
from collections.abc import Callable, Mapping
99
from dataclasses import dataclass, field
10+
from pathlib import Path
1011
from typing import TYPE_CHECKING, TypedDict, cast
1112

1213
if TYPE_CHECKING:
@@ -108,6 +109,7 @@ def _create_tinker(config: "TrainConfig") -> "TrainHelper":
108109
clip_eps_high=config.clip_eps_high,
109110
grad_clip_norm=config.grad_clip_norm,
110111
clip_ratio_c=config.clip_ratio_c,
112+
sample_log_dir=str(Path(config.log_dir).resolve()),
111113
)
112114
helper.sft_loss_fn = config.sft_loss_fn
113115
return helper

retrain/config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import difflib
1010
import json
11+
import os
1112
import re
1213
import sys
1314
import tomllib
@@ -32,6 +33,14 @@
3233
_DEFAULT_ADAPTER_PATH = "/tmp/retrain_adapter"
3334

3435

36+
def _first_non_empty_env(*names: str) -> str:
37+
for name in names:
38+
value = os.getenv(name, "").strip()
39+
if value:
40+
return value
41+
return ""
42+
43+
3544
@dataclass
3645
class SqueezeConfig:
3746
"""Configuration for LoRA-Squeeze rank analysis and compression."""
@@ -225,6 +234,29 @@ class TrainConfig:
225234
plugins_strict: bool = True
226235

227236
def __post_init__(self) -> None:
237+
if not self.wandb_project:
238+
self.wandb_project = _first_non_empty_env(
239+
"SOMA_WANDB_PROJECT",
240+
"RETRAIN_WANDB_PROJECT",
241+
"WANDB_PROJECT",
242+
)
243+
if not self.wandb_entity:
244+
self.wandb_entity = _first_non_empty_env(
245+
"SOMA_WANDB_ENTITY",
246+
"RETRAIN_WANDB_ENTITY",
247+
"WANDB_ENTITY",
248+
)
249+
if not self.wandb_group:
250+
self.wandb_group = _first_non_empty_env(
251+
"SOMA_WANDB_GROUP",
252+
"RETRAIN_WANDB_GROUP",
253+
)
254+
if not self.wandb_tags:
255+
self.wandb_tags = _first_non_empty_env(
256+
"SOMA_WANDB_TAGS",
257+
"RETRAIN_WANDB_TAGS",
258+
)
259+
228260
# --- Hard errors (batched) ---
229261
errors: list[str] = []
230262

0 commit comments

Comments
 (0)