|
| 1 | +""" |
| 2 | +OpenMathReasoning Dataset Implementation |
| 3 | +
|
| 4 | +NVIDIA's OpenMathReasoning dataset - high-quality math problems with detailed |
| 5 | +chain-of-thought solutions. Contains 5.68M rows across multiple splits. |
| 6 | +
|
| 7 | +This implementation uses the 'cot' split which has 3.2M examples with detailed reasoning. |
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +import random |
| 12 | +import re |
| 13 | +import sys |
| 14 | +from typing import List, Optional, Tuple |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import pandas as pd |
| 18 | +from datasets import load_dataset |
| 19 | + |
| 20 | +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 21 | + |
| 22 | +from ..dataset_interface import DatasetInfo, DatasetInterface, Question |
| 23 | + |
| 24 | + |
| 25 | +class OpenMathReasoningDataset(DatasetInterface): |
| 26 | + """OpenMathReasoning dataset implementation for advanced mathematical reasoning.""" |
| 27 | + |
| 28 | + def __init__(self): |
| 29 | + """Initialize OpenMathReasoning dataset.""" |
| 30 | + self._dataset_cache = None |
| 31 | + self._categories_cache = None |
| 32 | + |
| 33 | + @property |
| 34 | + def dataset_name(self) -> str: |
| 35 | + return "OpenMathReasoning" |
| 36 | + |
| 37 | + @property |
| 38 | + def supports_cot(self) -> bool: |
| 39 | + return True # Has detailed chain-of-thought solutions |
| 40 | + |
| 41 | + def _load_raw_dataset(self, max_examples: int = 10000): |
| 42 | + """ |
| 43 | + Load raw OpenMathReasoning dataset from Hugging Face. |
| 44 | +
|
| 45 | + Args: |
| 46 | + max_examples: Maximum number of examples to load (default: 10000) |
| 47 | + This prevents loading all 3.2M rows unnecessarily. |
| 48 | + """ |
| 49 | + if self._dataset_cache is not None: |
| 50 | + return self._dataset_cache |
| 51 | + |
| 52 | + # Use STREAMING mode to avoid downloading the full 3.2M dataset |
| 53 | + # This way we only fetch the examples we actually need |
| 54 | + print(f"Loading OpenMathReasoning: {max_examples} examples (out of 3.2M total)") |
| 55 | + print(f" Using streaming mode to avoid downloading full dataset...") |
| 56 | + |
| 57 | + dataset_stream = load_dataset( |
| 58 | + "nvidia/OpenMathReasoning", split="cot", streaming=True |
| 59 | + ) |
| 60 | + |
| 61 | + # Take only the first max_examples from the stream |
| 62 | + examples = [] |
| 63 | + for i, example in enumerate(dataset_stream): |
| 64 | + if i >= max_examples: |
| 65 | + break |
| 66 | + examples.append(example) |
| 67 | + if (i + 1) % 1000 == 0: |
| 68 | + print(f" Loaded {i + 1}/{max_examples} examples...", end="\r") |
| 69 | + |
| 70 | + print(f"\n ✓ Loaded {len(examples)} examples (streamed, not cached)") |
| 71 | + self._dataset_cache = pd.DataFrame(examples) |
| 72 | + return self._dataset_cache |
| 73 | + |
| 74 | + def _get_categories(self, max_examples: int = 10000) -> List[str]: |
| 75 | + """Get available categories in OpenMathReasoning dataset.""" |
| 76 | + if self._categories_cache is not None: |
| 77 | + return self._categories_cache |
| 78 | + |
| 79 | + # OpenMathReasoning has problem_type and problem_source fields |
| 80 | + # We'll use problem_type as categories |
| 81 | + # Load a subset to discover categories |
| 82 | + df = self._load_raw_dataset(max_examples=max_examples) |
| 83 | + self._categories_cache = df["problem_type"].unique().tolist() |
| 84 | + return self._categories_cache |
| 85 | + |
| 86 | + def get_available_categories(self) -> List[str]: |
| 87 | + """Get list of all available categories in the dataset.""" |
| 88 | + return self._get_categories() |
| 89 | + |
| 90 | + def load_dataset( |
| 91 | + self, |
| 92 | + categories: Optional[List[str]] = None, |
| 93 | + samples_per_category: Optional[int] = None, |
| 94 | + seed: int = 42, |
| 95 | + max_cot_length: Optional[int] = None, |
| 96 | + ) -> Tuple[List[Question], DatasetInfo]: |
| 97 | + """ |
| 98 | + Load OpenMathReasoning dataset with optional filtering and sampling. |
| 99 | +
|
| 100 | + Args: |
| 101 | + categories: Filter by problem types |
| 102 | + samples_per_category: Number of samples per category |
| 103 | + seed: Random seed for sampling |
| 104 | + max_cot_length: Maximum character length for CoT solutions (for memory efficiency) |
| 105 | + """ |
| 106 | + # Calculate how many examples we need to load |
| 107 | + # If samples_per_category is specified, we can limit loading |
| 108 | + # Use a buffer factor based on whether we're filtering by length |
| 109 | + if samples_per_category: |
| 110 | + # If filtering by length, load more samples to compensate |
| 111 | + buffer_factor = 15 if max_cot_length else 3 |
| 112 | + estimated_needed = samples_per_category * 3 * buffer_factor |
| 113 | + max_to_load = min( |
| 114 | + estimated_needed, 100000 |
| 115 | + ) # Cap at 100k for length filtering |
| 116 | + else: |
| 117 | + # Load more if no limit specified |
| 118 | + max_to_load = 50000 # Still cap to avoid loading all 3.2M |
| 119 | + |
| 120 | + df = self._load_raw_dataset(max_examples=max_to_load) |
| 121 | + available_categories = self._get_categories(max_examples=max_to_load) |
| 122 | + |
| 123 | + # Filter by CoT length if specified (for memory-efficient training) |
| 124 | + if max_cot_length: |
| 125 | + print( |
| 126 | + f"\n 📏 Filtering samples by CoT length (max: {max_cot_length} chars)" |
| 127 | + ) |
| 128 | + original_count = len(df) |
| 129 | + df["cot_length"] = df["generated_solution"].str.len() |
| 130 | + df = df[df["cot_length"] <= max_cot_length] |
| 131 | + print( |
| 132 | + f" ✓ Kept {len(df)}/{original_count} samples ({len(df)/original_count*100:.1f}%) after length filtering" |
| 133 | + ) |
| 134 | + |
| 135 | + # Print distribution stats |
| 136 | + if len(df) > 0: |
| 137 | + print(f" 📊 CoT Length Stats (filtered):") |
| 138 | + print(f" Min: {df['cot_length'].min()} chars") |
| 139 | + print(f" Max: {df['cot_length'].max()} chars") |
| 140 | + print(f" Mean: {df['cot_length'].mean():.0f} chars") |
| 141 | + print(f" Median: {df['cot_length'].median():.0f} chars") |
| 142 | + |
| 143 | + # Filter categories if specified |
| 144 | + if categories: |
| 145 | + missing_categories = set(categories) - set(available_categories) |
| 146 | + if missing_categories: |
| 147 | + raise ValueError( |
| 148 | + f"Categories not found: {missing_categories}. " |
| 149 | + f"Available: {available_categories}" |
| 150 | + ) |
| 151 | + df = df[df["problem_type"].isin(categories)] |
| 152 | + selected_categories = categories |
| 153 | + else: |
| 154 | + selected_categories = available_categories |
| 155 | + |
| 156 | + # Sample questions if specified (per category) |
| 157 | + if samples_per_category: |
| 158 | + np.random.seed(seed) |
| 159 | + random.seed(seed) |
| 160 | + |
| 161 | + sampled_dfs = [] |
| 162 | + for category in selected_categories: |
| 163 | + category_df = df[df["problem_type"] == category] |
| 164 | + sample_size = min(samples_per_category, len(category_df)) |
| 165 | + if sample_size > 0: |
| 166 | + sampled_df = category_df.sample(n=sample_size, random_state=seed) |
| 167 | + sampled_dfs.append(sampled_df) |
| 168 | + |
| 169 | + if sampled_dfs: |
| 170 | + df = pd.concat(sampled_dfs, ignore_index=True) |
| 171 | + else: |
| 172 | + df = pd.DataFrame() |
| 173 | + |
| 174 | + # Convert to Question objects |
| 175 | + questions = [] |
| 176 | + for _, row in df.iterrows(): |
| 177 | + problem_text = row["problem"] |
| 178 | + solution_text = row["generated_solution"] |
| 179 | + expected_answer = row.get("expected_answer", "") |
| 180 | + problem_type = row.get("problem_type", "default") |
| 181 | + |
| 182 | + # Clean the answer if needed |
| 183 | + correct_answer = str(expected_answer).strip() |
| 184 | + |
| 185 | + question = Question( |
| 186 | + question_id=f"openmr_{len(questions)}", |
| 187 | + question=problem_text, |
| 188 | + options=[], # Free-form, no multiple choice |
| 189 | + correct_answer=correct_answer, |
| 190 | + category=problem_type, |
| 191 | + cot_content=solution_text, # Full solution with detailed reasoning |
| 192 | + metadata={ |
| 193 | + "difficulty": "Advanced", |
| 194 | + "type": "math_problem", |
| 195 | + "problem_source": row.get("problem_source", "unknown"), |
| 196 | + "generation_model": row.get("generation_model", "unknown"), |
| 197 | + "pass_rate_72b_tir": row.get("pass_rate_72b_tir", "unknown"), |
| 198 | + }, |
| 199 | + ) |
| 200 | + questions.append(question) |
| 201 | + |
| 202 | + dataset_info = DatasetInfo( |
| 203 | + name="OpenMathReasoning", |
| 204 | + description="NVIDIA's high-quality math problems with detailed chain-of-thought reasoning", |
| 205 | + categories=selected_categories, |
| 206 | + total_questions=len(questions), |
| 207 | + format_type="free_form", |
| 208 | + difficulty_level="Advanced", |
| 209 | + ) |
| 210 | + |
| 211 | + return questions, dataset_info |
| 212 | + |
| 213 | + def format_prompt(self, question: Question, prompt_style: str = "plain") -> str: |
| 214 | + """Format prompt for OpenMathReasoning questions.""" |
| 215 | + if prompt_style == "plain": |
| 216 | + return f"""Solve this math problem: |
| 217 | +
|
| 218 | +{question.question} |
| 219 | +
|
| 220 | +Please provide your final answer in the following structured format: |
| 221 | +The answer is [your_final_answer] |
| 222 | +
|
| 223 | +For example: The answer is 42""" |
| 224 | + |
| 225 | + elif prompt_style == "explicit_cot": |
| 226 | + return f"""Solve this math problem step by step, showing all your reasoning: |
| 227 | +
|
| 228 | +Problem: {question.question} |
| 229 | +
|
| 230 | +Please work through this step-by-step: |
| 231 | +1. Read the problem carefully and understand what is being asked |
| 232 | +2. Identify the given information and what needs to be found |
| 233 | +3. Choose appropriate methods and formulas |
| 234 | +4. Work through the solution step by step with clear explanations |
| 235 | +5. Verify your answer makes sense |
| 236 | +6. State your final answer clearly |
| 237 | +
|
| 238 | +Please provide your final answer in the following structured format: |
| 239 | +The answer is [your_final_answer] |
| 240 | +
|
| 241 | +For example: The answer is 42""" |
| 242 | + |
| 243 | + else: |
| 244 | + raise ValueError(f"Unknown prompt style: {prompt_style}") |
0 commit comments