
Commit 6cbc1b7

Add gpt-oss gen data script (#190)
* add gendata script
* fix
* fix
* fix
* fix
1 parent 14060c0 commit 6cbc1b7

scripts/gen_oss_dataset.py

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
#!/usr/bin/env python3
"""
Simple script to generate responses using a local SGLang API from a JSONL file.

Data: https://huggingface.co/datasets/mlabonne/open-perfectblend

Environment variables:
    SGLANG_BASE_URL  # optional, default: http://localhost:30000/v1/completions

Usage:
step 1: data splitting
```
#!/bin/bash
input="your_file.txt"
lines_per_file=20000
prefix="shard"
ext=".json"
total=$(( ($(wc -l < "$input") + lines_per_file - 1) / lines_per_file ))
split -l $lines_per_file -d -a 4 "$input" tmp_shard_
i=0
for f in tmp_shard_*; do
    shard_num=$((i+1))
    mv "$f" "${prefix}_${shard_num}_of_${total}${ext}"
    i=$((i+1))
done
```
step 2: python3 -m sglang.launch_server --model-path openai/gpt-oss-20b --tp 8
step 3: python gen_oss_dataset.py <shard>
Example: python gen_oss_dataset.py 9
"""
import argparse
import json
import os
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

import requests
from openai_harmony import (
    Conversation,
    HarmonyEncodingName,
    Message,
    ReasoningEffort,
    Role,
    SystemContent,
    load_harmony_encoding,
)
from tqdm.auto import tqdm

# Configuration
BASE_URL = os.getenv("SGLANG_BASE_URL", "http://localhost:30000/v1/completions")
HEADERS = {"Content-Type": "application/json"}

MODEL = "openai/gpt-oss-20b"
MAX_TOKENS = 2048
BATCH_SIZE = 128
TEMPERATURE = 0.7

# Load harmony encoding once at module level to avoid repeated loading
_harmony_encoding = None

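# Example of pointing the script at a different server via the SGLANG_BASE_URL
# variable read above. Illustrative shell invocation only: the host/port are
# placeholders, and the shard argument follows the docstring example.
#   SGLANG_BASE_URL=http://10.0.0.2:30000/v1/completions python gen_oss_dataset.py 9
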
def get_random_reasoning_effort() -> ReasoningEffort:
    """Get a random reasoning effort level for the model with weighted probabilities."""
    # Reasoning effort levels with weights: LOW(7), MEDIUM(2), HIGH(1)
    reasoning_efforts = [
        ReasoningEffort.LOW,
        ReasoningEffort.MEDIUM,
        ReasoningEffort.HIGH,
    ]
    weights = [7, 2, 1]  # 7:2:1 probability ratio
    return random.choices(reasoning_efforts, weights=weights, k=1)[0]


def get_harmony_encoding():
    """Get the harmony encoding, loading it only once."""
    global _harmony_encoding
    if _harmony_encoding is None:
        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def build_prompt(user_msg: str, reasoning_effort) -> str:
    """Embed user message into the required prompt template."""
    system_message = (
        SystemContent.new()
        .with_model_identity(
            "You are ChatGPT, a large language model trained by OpenAI."
        )
        .with_reasoning_effort(reasoning_effort)
        .with_conversation_start_date("2025-06-28")
        .with_knowledge_cutoff("2024-06")
        .with_required_channels(["analysis", "commentary", "final"])
    )
    convo = []
    convo.append(Message.from_role_and_content(Role.SYSTEM, system_message))
    convo.append(Message.from_role_and_content(Role.USER, user_msg))
    convo = Conversation.from_messages(convo)
    enc = get_harmony_encoding()  # Use cached encoding
    tokens = enc.render_conversation_for_completion(convo, Role.ASSISTANT)
    prompt_text = enc.decode_utf8(tokens)
    return prompt_text

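# Rough sketch of what build_prompt() returns (illustrative only; the exact text is
# produced by openai_harmony and details such as the system metadata lines may differ):
#   <|start|>system<|message|>You are ChatGPT, ... Reasoning: low ...<|end|>
#   <|start|>user<|message|><the user message><|end|>
#   <|start|>assistant
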
def build_prompt_batch_parallel(
    batch_data: List[tuple], max_workers: int = 8
) -> List[tuple]:
    """
    Build prompts in parallel for a batch of data.

    Args:
        batch_data: List of (item, human_msg) tuples
        max_workers: Maximum number of worker threads

    Returns:
        List of (item, human_msg, reasoning_effort, prompt) tuples for successful builds
    """

    def build_single_prompt(item_data):
        item, human_msg = item_data
        try:
            reasoning_effort = get_random_reasoning_effort()
            prompt = build_prompt(human_msg, reasoning_effort)
            return (item, human_msg, reasoning_effort, prompt, None)
        except Exception as e:
            return (item, human_msg, None, None, str(e))

    results = []

    # Use ThreadPoolExecutor for parallel processing
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_data = {
            executor.submit(build_single_prompt, item_data): item_data
            for item_data in batch_data
        }

        # Collect results as they complete
        for future in as_completed(future_to_data):
            item, human_msg, reasoning_effort, prompt, error = future.result()
            if error:
                print(f"Error building prompt: {error}")
            else:
                results.append((item, human_msg, reasoning_effort, prompt))

    return results

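# Note: as_completed() yields futures in completion order, so the list returned by
# build_prompt_batch_parallel() is not necessarily in the same order as batch_data.
# The caller in main() builds batch_prompts and batch_items from that same list, so
# prompt/item pairing stays consistent even though the original order is not preserved.
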
def call_sglang_batch(prompts: List[str]) -> List[str]:
    """Send a batch of prompts to sglang /v1/completions."""
    payload = {
        "model": MODEL,
        "prompt": prompts,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "skip_special_tokens": False,
    }

    resp = requests.post(BASE_URL, headers=HEADERS, json=payload, timeout=600)
    resp.raise_for_status()
    data = resp.json()
    return [choice["text"].strip() for choice in data["choices"]]

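# Expected response shape (OpenAI-style completions API; abbreviated and illustrative):
#   {"choices": [{"index": 0, "text": "<|channel|>analysis<|message|>..."}, ...]}
# The caller pairs these outputs positionally with its prompts, which assumes the
# server returns one choice per prompt in the same order as the prompt list.
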
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load data from JSONL file."""
    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data


def extract_human_message(item: Dict[str, Any]) -> str:
    """Extract human message from data item."""
    # Try common formats
    if "conversations" in item:
        conv = item["conversations"]
        if isinstance(conv, list) and len(conv) > 0:
            return conv[0].get("value", conv[0].get("content", ""))

    # Try other common fields
    for field in ["message", "instruction", "question", "input", "text"]:
        if field in item:
            return item[field]

    return str(item)

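# Illustrative item shapes handled by extract_human_message (field names follow the
# checks above; the concrete values are made up):
#   {"conversations": [{"from": "human", "value": "What is RLHF?"}, ...]}  -> "What is RLHF?"
#   {"instruction": "Summarize this paragraph."}                           -> "Summarize this paragraph."
# Anything else falls back to str(item).
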
def parse_channel_output(output: str) -> Dict[str, Optional[str]]:
    """Parse the channel-based output format into analysis and final parts."""
    result = {"analysis": None, "final": None}

    # Find analysis channel
    analysis_start = output.find("<|channel|>analysis<|message|>")
    if analysis_start != -1:
        analysis_start += len("<|channel|>analysis<|message|>")
        analysis_end = output.find("<|end|>", analysis_start)
        if analysis_end != -1:
            result["analysis"] = output[analysis_start:analysis_end].strip()

    # Find final channel
    final_start = output.find("<|channel|>final<|message|>")
    if final_start != -1:
        final_start += len("<|channel|>final<|message|>")
        # Final content goes to the end of the string
        result["final"] = output[final_start:].strip()

    return result

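# Illustrative example (the completion text is made up; note that any trailing special
# tokens after the final message would be left attached to "final" by the parser above):
#   parse_channel_output(
#       "<|channel|>analysis<|message|>User asks for the capital of France.<|end|>"
#       "<|start|>assistant<|channel|>final<|message|>Paris."
#   )
#   -> {"analysis": "User asks for the capital of France.", "final": "Paris."}
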
def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Generate GPT-OSS data from JSONL files"
    )
    parser.add_argument("shard", type=int, help="Starting shard number")
    parser.add_argument(
        "--input-dir", default="/data/", help="Input directory path (default: /data/)"
    )
    parser.add_argument(
        "--output-dir", default="/data/", help="Output directory path (default: /data/)"
    )
    parser.add_argument(
        "--shard-step",
        type=int,
        default=5,
        help="Process every Nth shard; step size (default: 5)",
    )

    args = parser.parse_args()

    start_shard = args.shard
    max_shard = 72  # Based on the filename pattern shard_X_of_72
    shard_step = max(1, args.shard_step)

    for shard in range(start_shard, max_shard + 1, shard_step):
        input_file = os.path.join(args.input_dir, f"shard_{shard}_of_72.json")
        output_file = os.path.join(args.output_dir, f"shard_{shard}_of_72.json")

        # Ensure output directory exists
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        # Check if input file exists
        if not os.path.exists(input_file):
            print(f"Input file not found: {input_file}")
            print(f"Stopping at shard {shard}")
            break

        try:
            data = load_jsonl(input_file)
            print(f"Loaded {len(data)} items")

            if not data:
                print("No data found in input file, skipping.")
                continue

            # Process data in batches
            total_saved = 0

            # Prepare all valid data first
            valid_items = []
            for item in data:
                human_msg = extract_human_message(item)
                if human_msg.strip():
                    valid_items.append((item, human_msg))

            # Open output file once and write each batch result immediately
            with open(output_file, "w", encoding="utf-8") as f:
                # Process in batches
                for i in tqdm(
                    range(0, len(valid_items), BATCH_SIZE),
                    desc=f"Processing shard {shard}",
                ):
                    batch = valid_items[i : i + BATCH_SIZE]

                    # Build prompts in parallel for the entire batch
                    try:
                        batch_results = build_prompt_batch_parallel(
                            batch, max_workers=8
                        )

                        if not batch_results:
                            continue

                        batch_prompts = []
                        batch_items = []

                        for item, human_msg, reasoning_effort, prompt in batch_results:
                            batch_prompts.append(prompt)
                            batch_items.append((item, human_msg, reasoning_effort))

                    except Exception as e:
                        print(f"Error in parallel prompt building: {e}")
                        continue

                    if not batch_prompts:
                        continue

                    try:
                        # Process entire batch at once
                        outputs = call_sglang_batch(batch_prompts)

                        # Process each response in the batch
                        for j, output in enumerate(outputs):
                            if (
                                j < len(batch_items) and output
                            ):  # Check bounds and valid response
                                item, human_msg, reasoning_effort = batch_items[j]

                                # Parse the channel-based output
                                parsed_output = parse_channel_output(output)

                                row = {
                                    "conversations": [
                                        {"from": "human", "value": human_msg},
                                        {"from": "assistant", "value": output},
                                        {
                                            "from": "assistant_analysis",
                                            "value": parsed_output["analysis"],
                                        },
                                        {
                                            "from": "assistant_final",
                                            "value": parsed_output["final"],
                                        },
                                        {
                                            "from": "assistant_reasoning_effort",
                                            "value": reasoning_effort.value,
                                        },
                                    ],
                                }
                                f.write(json.dumps(row, ensure_ascii=False) + "\n")
                                total_saved += 1
                            else:
                                print(f"Warning: Empty response for batch item {j}")

                        f.flush()  # Ensure data is written to disk after each batch

                    except Exception as e:
                        print(f"Error processing batch starting at index {i}: {e}")
                        continue

            # Show results for this shard
            if total_saved > 0:
                print(f"✅ Saved {total_saved} responses to {output_file}")
                print(
                    f"Success rate: {total_saved}/{len(data)} ({total_saved/len(data)*100:.1f}%)"
                )
            else:
                print("No responses were generated for this shard.")
        except Exception as e:
            print(f"Error processing shard {shard}: {e}")
            print("Continuing to next shard...")
            continue

    print(f"\n{'='*60}")
    print(
        f"Completed processing shards starting from {start_shard} (every {shard_step}th shard)"
    )
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
