
Commit 1c9ac6a

Merge pull request #3 from taeungshin/claude/integrate-comet-evaluation-2iFB6
Add enhanced TensorBoard logging to finetune_lora.py
2 parents 4f0b08f + ea7d228 commit 1c9ac6a

File tree: 1 file changed (+45, -6)


scripts/finetune_lora.py

Lines changed: 45 additions & 6 deletions
@@ -195,8 +195,14 @@ def _move_inputs_to_device(batch_input, device):
    return {k: v.to(device) for k, v in batch_input.items()}


-def train_epoch(model, dataloader, optimizer, device, epoch, writer=None, global_step=0, log_interval=100):
-    """Train one epoch"""
+def train_epoch(model, dataloader, optimizer, device, epoch,
+                writer=None, global_step=0, log_interval=100,
+                val_loader=None, eval_fn=None, eval_interval=0):
+    """Train one epoch.
+
+    Args:
+        eval_interval: run mid-epoch validation every N steps (0 = disabled)
+    """
    model.train()
    total_loss = 0
    num_batches = 0
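The three new parameters all default to off (val_loader=None, eval_fn=None, eval_interval=0), so existing call sites keep working unchanged. A minimal sketch of both call styles, where the variable names (model, train_loader, val_loader, optimizer, device, writer, evaluate) are assumed from the rest of this script:

```python
# Sketch only: names are assumed from the surrounding finetune_lora.py script.
# Existing call style, no mid-epoch validation (new arguments keep their defaults):
train_loss, global_step = train_epoch(
    model, train_loader, optimizer, device, epoch=1,
    writer=writer, global_step=0,
)

# Opting in to mid-epoch validation; the interval of 500 steps is illustrative:
train_loss, global_step = train_epoch(
    model, train_loader, optimizer, device, epoch=1,
    writer=writer, global_step=global_step,
    val_loader=val_loader, eval_fn=evaluate, eval_interval=500,
)
```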
@@ -222,21 +228,40 @@ def train_epoch(model, dataloader, optimizer, device, epoch, writer=None, global_step=0, log_interval=100):

        loss.backward()

-        # Gradient clipping
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        # Gradient norm (before clipping)
+        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        global_step += 1

-        # TensorBoard: per-step loss
+        # TensorBoard: per-step logging
        if writer is not None:
            writer.add_scalar("train/step_loss", loss.item(), global_step)
+            writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
+            writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)

        if (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / num_batches
-            logger.info(f" Epoch {epoch} [{batch_idx+1}/{len(dataloader)}] loss={avg_loss:.6f}")
+            logger.info(f" Epoch {epoch} [{batch_idx+1}/{len(dataloader)}] "
+                        f"loss={avg_loss:.6f} grad_norm={grad_norm:.4f}")
+
+        # Mid-epoch validation
+        if eval_interval > 0 and val_loader is not None and eval_fn is not None:
+            if (batch_idx + 1) % eval_interval == 0:
+                logger.info(f" [Mid-epoch validation at step {global_step}]")
+                mid_metrics = eval_fn(model, val_loader, device)
+                logger.info(f" Pearson={mid_metrics['pearson']:.4f} "
+                            f"Spearman={mid_metrics['spearman']:.4f} "
+                            f"Kendall={mid_metrics['kendall']:.4f} "
+                            f"MSE={mid_metrics['mse']:.6f}")
+                if writer is not None:
+                    writer.add_scalar("val_mid/pearson", mid_metrics["pearson"], global_step)
+                    writer.add_scalar("val_mid/spearman", mid_metrics["spearman"], global_step)
+                    writer.add_scalar("val_mid/kendall", mid_metrics["kendall"], global_step)
+                    writer.add_scalar("val_mid/mse", mid_metrics["mse"], global_step)
+                model.train()  # restore train mode, since evaluate() switches the model to eval()

    return total_loss / max(num_batches, 1), global_step
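One detail worth noting: torch.nn.utils.clip_grad_norm_ clips in place and returns the total gradient norm computed before clipping, so a single call both clips and supplies the logged grad_norm value. A self-contained sketch of this per-step logging pattern; the toy model, data, and run directory are placeholders, not part of the repository:

```python
# Standalone sketch of the step-level logging pattern shown in the hunk above.
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
writer = SummaryWriter("runs/sketch")  # placeholder log directory

global_step = 0
for _ in range(10):
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # clip_grad_norm_ clips in place and returns the total norm measured
    # *before* clipping, so it can be logged without an extra pass.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    global_step += 1
    writer.add_scalar("train/step_loss", loss.item(), global_step)
    writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
    writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)
writer.close()
```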

@@ -276,6 +301,8 @@ def evaluate(model, dataloader, device):
        "spearman": spearman_r,
        "kendall": kendall_tau,
        "mse": mse,
+        "preds": preds,
+        "targets": targets,
    }
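The diff shows only the tail of the dict returned by evaluate, so the following is a hedged reconstruction of how such metrics could be computed from the collected preds/targets arrays using numpy and scipy.stats; the helper name compute_metrics is hypothetical and the actual evaluate() body may differ.

```python
# Hedged sketch: plausible metric computation for arrays collected by evaluate().
import numpy as np
from scipy import stats

def compute_metrics(preds: np.ndarray, targets: np.ndarray) -> dict:
    pearson_r, _ = stats.pearsonr(preds, targets)
    spearman_r, _ = stats.spearmanr(preds, targets)
    kendall_tau, _ = stats.kendalltau(preds, targets)
    mse = float(np.mean((preds - targets) ** 2))
    return {
        "pearson": float(pearson_r),
        "spearman": float(spearman_r),
        "kendall": float(kendall_tau),
        "mse": mse,
        "preds": preds,      # raw arrays, newly exposed for histogram logging
        "targets": targets,
    }
```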

@@ -298,6 +325,8 @@ def main():
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--max_train_rows", type=int, default=0,
                        help="Max training rows (0=all)")
+    parser.add_argument("--eval_interval", type=int, default=0,
+                        help="Mid-epoch validation interval in steps (0 = evaluate only at epoch end)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()
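With the default of 0 the behaviour is unchanged and validation runs only at the end of each epoch; passing, for example, `python scripts/finetune_lora.py --epochs 3 --eval_interval 500` (alongside whatever data and model arguments the script already requires) enables the mid-epoch validation path every 500 training steps.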
@@ -378,6 +407,8 @@ def collate_fn_val(batch):
    logger.info(f"Training samples: {len(train_dataset)}")
    logger.info(f"Validation samples: {len(val_dataset)}")
    logger.info(f"Epochs: {args.epochs}, LR: {args.learning_rate}")
+    if args.eval_interval > 0:
+        logger.info(f"Mid-epoch validation every {args.eval_interval} steps")

    # ========================================
    # 5. Initialize TensorBoard
@@ -411,6 +442,8 @@ def collate_fn_val(batch):
        train_loss, global_step = train_epoch(
            model, train_loader, optimizer, device, epoch + 1,
            writer=writer, global_step=global_step,
+            val_loader=val_loader, eval_fn=evaluate,
+            eval_interval=args.eval_interval,
        )
        logger.info(f"  Train loss: {train_loss:.6f}")

@@ -427,6 +460,12 @@ def collate_fn_val(batch):
            writer.add_scalar("val/kendall", metrics["kendall"], epoch + 1)
            writer.add_scalar("val/mse", metrics["mse"], epoch + 1)

+            # TensorBoard: prediction distribution histograms (for detecting score collapse)
+            writer.add_histogram("val/pred_distribution", metrics["preds"], epoch + 1)
+            writer.add_histogram("val/target_distribution", metrics["targets"], epoch + 1)
+            writer.add_scalar("val/pred_std", float(np.std(metrics["preds"])), epoch + 1)
+            writer.add_scalar("val/pred_mean", float(np.mean(metrics["preds"])), epoch + 1)
+
        # Save checkpoint
        if metrics["kendall"] > best_kendall:
            best_kendall = metrics["kendall"]
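Score collapse here means the model starts predicting a nearly constant score for every example: the prediction histogram shrinks toward a spike and val/pred_std trends toward zero while val/pred_mean stays roughly fixed. A hedged sketch of how the newly exposed preds/targets could feed an automated check; the helper name and the 0.05 ratio are illustrative, not taken from the repository:

```python
# Hedged sketch: a simple collapse check built on the newly logged statistics.
import numpy as np

def check_score_collapse(preds: np.ndarray, targets: np.ndarray,
                         min_std_ratio: float = 0.05) -> bool:
    """Flag collapse when prediction spread shrinks far below the target spread."""
    pred_std = float(np.std(preds))
    target_std = float(np.std(targets))
    collapsed = target_std > 0 and pred_std < min_std_ratio * target_std
    if collapsed:
        print(f"Possible score collapse: pred_std={pred_std:.4f} "
              f"vs target_std={target_std:.4f}")
    return collapsed
```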
