@@ -195,8 +195,14 @@ def _move_inputs_to_device(batch_input, device):
     return {k: v.to(device) for k, v in batch_input.items()}
 
 
-def train_epoch(model, dataloader, optimizer, device, epoch, writer=None, global_step=0, log_interval=100):
-    """Train for one epoch"""
+def train_epoch(model, dataloader, optimizer, device, epoch,
+                writer=None, global_step=0, log_interval=100,
+                val_loader=None, eval_fn=None, eval_interval=0):
+    """Train for one epoch.
+
+    Args:
+        eval_interval: run mid-epoch validation every N steps (0=disabled)
+    """
     model.train()
     total_loss = 0
     num_batches = 0
@@ -222,21 +228,40 @@ def train_epoch(model, dataloader, optimizer, device, epoch, writer=None, global
 
         loss.backward()
 
-        # Gradient clipping
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        # Gradient norm (before clipping)
+        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
 
         total_loss += loss.item()
         num_batches += 1
         global_step += 1
 
-        # TensorBoard: per-step loss
+        # TensorBoard: per-step logging
         if writer is not None:
             writer.add_scalar("train/step_loss", loss.item(), global_step)
+            writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)
+            writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_step)
 
         if (batch_idx + 1) % log_interval == 0:
             avg_loss = total_loss / num_batches
-            logger.info(f"  Epoch {epoch} [{batch_idx + 1}/{len(dataloader)}] loss={avg_loss:.6f}")
+            logger.info(f"  Epoch {epoch} [{batch_idx + 1}/{len(dataloader)}] "
+                        f"loss={avg_loss:.6f} grad_norm={grad_norm:.4f}")
+
+        # Mid-epoch validation
+        if eval_interval > 0 and val_loader is not None and eval_fn is not None:
+            if (batch_idx + 1) % eval_interval == 0:
+                logger.info(f"  [Mid-epoch validation at step {global_step}]")
+                mid_metrics = eval_fn(model, val_loader, device)
+                logger.info(f"  Pearson={mid_metrics['pearson']:.4f} "
+                            f"Spearman={mid_metrics['spearman']:.4f} "
+                            f"Kendall={mid_metrics['kendall']:.4f} "
+                            f"MSE={mid_metrics['mse']:.6f}")
+                if writer is not None:
+                    writer.add_scalar("val_mid/pearson", mid_metrics["pearson"], global_step)
+                    writer.add_scalar("val_mid/spearman", mid_metrics["spearman"], global_step)
+                    writer.add_scalar("val_mid/kendall", mid_metrics["kendall"], global_step)
+                    writer.add_scalar("val_mid/mse", mid_metrics["mse"], global_step)
+                model.train()  # restore train mode; evaluate() switches the model to eval()
 
     return total_loss / max(num_batches, 1), global_step
 
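Note on the grad-norm logging added above: torch.nn.utils.clip_grad_norm_ returns the total gradient norm measured before clipping, so train/grad_norm tracks the raw norm even on steps where clipping rescales the gradients. A minimal standalone check of that behavior (toy tensors, not code from this repo):

    import torch

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.full((4,), 10.0)      # raw L2 norm = 20.0
    norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
    print(norm.item())                   # 20.0 -- the pre-clip norm is returned
    print(p.grad.norm().item())          # 1.0  -- gradients were rescaled in place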
@@ -276,6 +301,8 @@ def evaluate(model, dataloader, device):
         "spearman": spearman_r,
         "kendall": kendall_tau,
         "mse": mse,
+        "preds": preds,
+        "targets": targets,
     }
 
 
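This hunk shows only the tail of evaluate's return dict; the new preds/targets entries expose the raw arrays so callers can log distributions. For orientation, a minimal sketch with the same return contract, assuming scipy.stats for the correlations and a dict-of-tensors batch layout (evaluate_sketch, the model call, and the batch format are assumptions, not this repo's actual code):

    import numpy as np
    import torch
    from scipy.stats import kendalltau, pearsonr, spearmanr

    @torch.no_grad()
    def evaluate_sketch(model, dataloader, device):
        # Hypothetical stand-in illustrating the expected return keys.
        model.eval()
        preds, targets = [], []
        for batch_input, batch_target in dataloader:
            batch_input = {k: v.to(device) for k, v in batch_input.items()}
            out = model(**batch_input)
            preds.append(out.squeeze(-1).cpu().numpy())
            targets.append(batch_target.numpy())
        preds, targets = np.concatenate(preds), np.concatenate(targets)
        return {
            "pearson": pearsonr(preds, targets)[0],
            "spearman": spearmanr(preds, targets)[0],
            "kendall": kendalltau(preds, targets)[0],
            "mse": float(np.mean((preds - targets) ** 2)),
            "preds": preds,
            "targets": targets,
        }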
@@ -298,6 +325,8 @@ def main():
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
     parser.add_argument("--max_train_rows", type=int, default=0,
                         help="Max training rows (0=all)")
+    parser.add_argument("--eval_interval", type=int, default=0,
+                        help="Mid-epoch validation interval in steps (0=evaluate at epoch end only)")
     parser.add_argument("--seed", type=int, default=42, help="Random seed")
 
     args = parser.parse_args()
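The default of 0 preserves the previous behavior (validation only at the end of each epoch); passing, say, --eval_interval 500 runs a full validation pass every 500 optimizer steps within each epoch.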
@@ -378,6 +407,8 @@ def collate_fn_val(batch):
     logger.info(f"Training samples: {len(train_dataset)}")
     logger.info(f"Validation samples: {len(val_dataset)}")
     logger.info(f"Epochs: {args.epochs}, LR: {args.learning_rate}")
+    if args.eval_interval > 0:
+        logger.info(f"Mid-epoch validation every {args.eval_interval} steps")
 
     # ========================================
     # 5. Initialize TensorBoard
@@ -411,6 +442,8 @@ def collate_fn_val(batch):
         train_loss, global_step = train_epoch(
             model, train_loader, optimizer, device, epoch + 1,
             writer=writer, global_step=global_step,
+            val_loader=val_loader, eval_fn=evaluate,
+            eval_interval=args.eval_interval,
         )
         logger.info(f"  Train loss: {train_loss:.6f}")
 
@@ -427,6 +460,12 @@ def collate_fn_val(batch):
             writer.add_scalar("val/kendall", metrics["kendall"], epoch + 1)
             writer.add_scalar("val/mse", metrics["mse"], epoch + 1)
 
+            # TensorBoard: prediction distribution histograms (to detect score collapse)
+            writer.add_histogram("val/pred_distribution", metrics["preds"], epoch + 1)
+            writer.add_histogram("val/target_distribution", metrics["targets"], epoch + 1)
+            writer.add_scalar("val/pred_std", float(np.std(metrics["preds"])), epoch + 1)
+            writer.add_scalar("val/pred_mean", float(np.mean(metrics["preds"])), epoch + 1)
+
         # Save checkpoint
         if metrics["kendall"] > best_kendall:
             best_kendall = metrics["kendall"]
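The histogram and pred_std curves exist to catch score collapse, where the model drifts toward predicting a near-constant value: rank correlations then become unstable or meaningless even while the loss still looks reasonable. If an explicit alarm is preferred over eyeballing TensorBoard, here is a sketch of a guard built on the same arrays (the helper name and threshold are illustrative, not part of this commit):

    import numpy as np

    def check_score_collapse(preds, targets, rel_threshold=0.1):
        # Warn when prediction spread shrinks far below target spread.
        # Illustrative only; 0.1 is an arbitrary cutoff.
        pred_std, target_std = float(np.std(preds)), float(np.std(targets))
        collapsed = target_std > 0 and pred_std < rel_threshold * target_std
        if collapsed:
            print(f"WARNING: possible score collapse "
                  f"(pred_std={pred_std:.4f}, target_std={target_std:.4f})")
        return collapsed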