 import importlib.util
 import os
 import sys
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -32,10 +33,12 @@ class RewardScore(TypedDict):
     accuracy: Optional[float]
 
 
-RewardFunction = Callable[[str, str], RewardScore]
+SequentialRewardFunction = Callable[[str, str], RewardScore]
 
+BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
 
-class FunctionRewardManager:
+
+class FunctionRewardManager(ABC):
     """Reward manager for rule-based reward."""
 
     def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
@@ -56,29 +59,60 @@ def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
         if not hasattr(module, config.reward_function_name):
             raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
 
-        reward_fn: RewardFunction = getattr(module, config.reward_function_name)
+        reward_fn = getattr(module, config.reward_function_name)
         print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
         self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
         self.config = config
         self.tokenizer = tokenizer
 
+    @abstractmethod
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        """Compute reward for a batch of data."""
+        ...
+
+
+class SequentialFunctionRewardManager(FunctionRewardManager):
+    reward_fn: SequentialRewardFunction
+
     def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
         for i in range(len(data)):
-            data_item = data[i]  # DataProtoItem
-            response_ids = data_item.batch["responses"]
-            response_mask = data_item.batch["response_mask"]
-            valid_response_length = response_mask.sum()
-            valid_response_ids = response_ids[:valid_response_length]
-
+            valid_response_ids = response_ids[i][: response_length[i]]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
-            ground_truth = data_item.non_tensor_batch["ground_truth"]
+            ground_truth = data.non_tensor_batch["ground_truth"][i]
 
             score = self.reward_fn(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)
+
+        return reward_tensor, reward_metrics
+
+
+class BatchFunctionRewardManager(FunctionRewardManager):
+    reward_fn: BatchRewardFunction
+
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        response_str, ground_truth = [], []
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
+        for i in range(len(data)):
+            valid_response_ids = response_ids[i][: response_length[i]]
+            response_str.append(
+                self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
+            )
+            ground_truth.append(data.non_tensor_batch["ground_truth"][i])
+
+        scores = self.reward_fn(response_str, ground_truth)
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+        reward_metrics = defaultdict(list)
+        for i, score in enumerate(scores):
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
 
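For reference, a minimal sketch of user-supplied reward functions matching the two aliases introduced above, SequentialRewardFunction and BatchRewardFunction. The exact-match scoring and the function names are illustrative only, not part of this change; the inlined RewardScore simply mirrors the TypedDict in this file. Any extra keyword arguments declared in config.reward_function_kwargs are bound onto the chosen function with functools.partial before the managers call it.

# Illustrative sketch only: shows the call signatures the managers above expect.
from typing import List, Optional, TypedDict


class RewardScore(TypedDict):
    # Mirrors the TypedDict defined in this file; "overall" is written into the
    # reward tensor, every key is also logged into reward_metrics.
    overall: float
    accuracy: Optional[float]


def exact_match_reward(response: str, ground_truth: str) -> RewardScore:
    # SequentialRewardFunction: called once per decoded response by
    # SequentialFunctionRewardManager.
    acc = float(response.strip() == ground_truth.strip())
    return {"overall": acc, "accuracy": acc}


def exact_match_batch_reward(responses: List[str], ground_truths: List[str]) -> List[RewardScore]:
    # BatchRewardFunction: receives the whole batch at once (useful when scoring
    # is cheaper in bulk) and must return one score per sample, in order, for
    # BatchFunctionRewardManager.
    return [exact_match_reward(r, g) for r, g in zip(responses, ground_truths)]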