diff --git a/examples/qwen3_14b_dapo17k_dapo.sh b/examples/qwen3_14b_dapo17k_dapo.sh
index 45e5ee3d..18cc7b34 100644
--- a/examples/qwen3_14b_dapo17k_dapo.sh
+++ b/examples/qwen3_14b_dapo17k_dapo.sh
@@ -10,7 +10,7 @@ python3 -m verl.trainer.main \
     config=examples/config.yaml \
     data.train_files=Saigyouji-Yuyuko1000/dapo17k@train \
     data.val_files=Saigyouji-Yuyuko1000/dapo17k@test \
-    data.format_prompt=./examples/format_prompt/math.jinja \
+    data.format_prompt=./examples/format_prompt/dapo.jinja \
     data.max_prompt_length=2048 \
     data.max_response_length=20480 \
     data.rollout_batch_size=512 \
diff --git a/examples/reward_function/dapo.py b/examples/reward_function/dapo.py
index 35c79166..63a3997e 100644
--- a/examples/reward_function/dapo.py
+++ b/examples/reward_function/dapo.py
@@ -12,14 +12,135 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from typing import Any, Dict, List
 
-from mathruler.grader import extract_boxed_content, grade_answer
+
+# Constants for normalization
+SUBSTITUTIONS = [
+    ("an ", ""),
+    ("a ", ""),
+    (".$", "$"),
+    ("\\$", ""),
+    (r"\ ", ""),
+    (" ", ""),
+    ("mbox", "text"),
+    (",\\text{and}", ","),
+    ("\\text{and}", ","),
+    ("\\text{m}", "\\text{}"),
+]
+
+REMOVED_EXPRESSIONS = [
+    "square",
+    "ways",
+    "integers",
+    "dollars",
+    "mph",
+    "inches",
+    "hours",
+    "km",
+    "units",
+    "\\ldots",
+    "sue",
+    "points",
+    "feet",
+    "minutes",
+    "digits",
+    "cents",
+    "degrees",
+    "cm",
+    "gm",
+    "pounds",
+    "meters",
+    "meals",
+    "edges",
+    "students",
+    "childrentickets",
+    "multiples",
+    "\\text{s}",
+    "\\text{.}",
+    "\\text{\ns}",
+    "\\text{}^2",
+    "\\text{}^3",
+    "\\text{\n}",
+    "\\text{}",
+    r"\mathrm{th}",
+    r"^\circ",
+    r"^{\circ}",
+    r"\;",
+    r",\!",
+    "{,}",
+    '"',
+    "\\dots",
+]
+
+
+def normalize_final_answer(final_answer: str) -> str:
+    """Normalize a final answer to a quantitative reasoning question.
+
+    Args:
+        final_answer: The answer string to normalize
+
+    Returns:
+        Normalized answer string
+    """
+    final_answer = final_answer.split("=")[-1]
+
+    # Apply substitutions and removals
+    for before, after in SUBSTITUTIONS:
+        final_answer = final_answer.replace(before, after)
+    for expr in REMOVED_EXPRESSIONS:
+        final_answer = final_answer.replace(expr, "")
+
+    # Extract and normalize LaTeX math
+    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+    # Normalize shorthand TeX:
+    #  \fracab -> \frac{a}{b}
+    #  \frac{abc}{bef} -> \frac{abc}{bef}
+    #  \fracabc -> \frac{a}{b}c
+    #  \sqrta -> \sqrt{a}
+    #  \sqrtab -> \sqrt{a}b
+    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+    final_answer = final_answer.replace("$", "")
+
+    # Normalize numbers
+    if final_answer.replace(",", "").isdigit():
+        final_answer = final_answer.replace(",", "")
+
+    return final_answer.strip()
+
+
+def is_correct_minerva(
+    solution_str: str, gt: str, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)"
+) -> bool:
+    """Check if the solution is correct according to Minerva criteria.
+
+    Args:
+        solution_str: The solution string to check
+        gt: The ground truth answer
+        answer_pattern: Regex pattern to extract the answer
+
+    Returns:
+        True if the normalized prediction equals the normalized ground truth.
+    """
+    # Use the last pattern match; fall back to the "[INVALID]" sentinel when none is found
+    match = re.findall(answer_pattern, solution_str)
+    extracted_answer = match[-1] if match else "[INVALID]"
+    pred = normalize_final_answer(extracted_answer)
+    gt = normalize_final_answer(gt)
+
+    return pred == gt
 
 
 def accuracy_reward(response: str, ground_truth: str) -> float:
-    answer = extract_boxed_content(response)
-    return 1.0 if grade_answer(answer, ground_truth) else -1.0
+    """Return 1.0 when the response's final answer matches ground_truth, else -1.0."""
+    return 1.0 if is_correct_minerva(response, ground_truth) else -1.0
 
 
 def soft_overlong_punishment(response_length: int, max_response_length: int, overlong_buffer_length: int):