-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugment_questions_vllm.py
More file actions
393 lines (330 loc) · 15.7 KB
/
augment_questions_vllm.py
File metadata and controls
393 lines (330 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
#!/usr/bin/env python3
"""Augment questions in a metadata-enriched JSONL dataset with paraphrased variants.
This script expects the input JSONL to match the structure produced by ``add_metadata.py``:
each line is a JSON object containing (at least) ``question``, ``query``, and ``schema`` fields,
plus optional metadata such as ``token_count``. For every record, the script:
1. Writes the original record to the output JSONL (unless ``--skip-original`` is supplied).
2. Builds an LLM prompt with the original question, its SQL query, and instructions to produce
alternative phrasings that preserve intent.
3. Calls a vLLM model in configurable batches to generate paraphrases, enforcing a plaintext
code-block response for easy parsing.
4. Appends a copy of the original record for each paraphrase with updated question text and
token counts, along with augmentation metadata (``augmentation_type``, ``augmentation_of``
etc.).
Example usage:
python augment_questions_vllm.py \
--input data_with_metadata.jsonl \
--output data_with_paraphrases.jsonl \
--model your-vllm-model \
--num-alternatives 3
The output file will contain all original records (unless skipped) followed by the generated
paraphrased variants, sharing the same schema and SQL query fields.
"""
import argparse
import copy
import json
import os
import random
import re
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
from vllm import LLM, SamplingParams
from datasets import load_dataset
try:
from tqdm import tqdm
except Exception: # pragma: no cover - fallback when tqdm is unavailable
tqdm = None # type: ignore
from ReFoRCE.utils import extract_code_blocks # re-use shared helper
from add_metadata import count_tokens # ensure consistent token accounting
BATCH_CAP = 32
DEFAULT_BATCH_SIZE = 12
DEFAULT_NUM_ALTERNATIVES = 2
SYSTEM_PREFIX = (
"You are an expert at paraphrasing natural-language questions for SQL tasks. "
"Produce faithful, diverse rewrites that keep the exact analytical intent of the provided SQL query."
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Generate paraphrased variants of questions in an add_metadata-enriched JSONL using vLLM. "
"Original records are kept and paraphrases are appended with updated metadata."
)
)
parser.add_argument("--input", required=True, help="Path to the JSONL produced by add_metadata.py")
parser.add_argument("--output", required=True, help="Destination JSONL for originals + paraphrases")
parser.add_argument("--model", required=True, help="Model name or path loadable by vLLM")
parser.add_argument("--num-alternatives", type=int, default=DEFAULT_NUM_ALTERNATIVES,
help="Paraphrases to request per record (default: %(default)s)")
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
help="Records per vLLM batch (capped at 32; default: %(default)s)")
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
parser.add_argument("--max-tokens", type=int, default=512, help="Maximum generation tokens")
parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling (default disables top-k)")
parser.add_argument("--tensor-parallel-size", type=int, default=1,
help="Tensor parallel size for vLLM engine")
parser.add_argument("--max-model-len", type=int, default=9182, help="Maximum model context length")
parser.add_argument("--retries", type=int, default=3,
help="Retries for responses that fail formatting checks (default: %(default)s)")
parser.add_argument("--seed", type=int, default=None, help="Seed for deterministic prompt sampling")
parser.add_argument("--start", type=int, default=0,
help="Start index (inclusive) in the dataset/JSONL to begin processing (default: 0)")
parser.add_argument("--end", type=int, default=None,
help="End index (exclusive) in the dataset/JSONL to stop processing (default: None, meaning iterate to the end)")
parser.add_argument("--skip-original", action="store_true",
help="Do not copy original records to the output (only write paraphrases)")
parser.add_argument("--no-progress", action="store_true",
help="Disable progress tracking even if tqdm is available")
return parser.parse_args()
def normalize(text: str) -> str:
return " ".join(text.lower().split())
def parse_plaintext_block(raw: str) -> List[str]:
blocks = extract_code_blocks(raw, "plaintext")
text = blocks[-1] if blocks else raw
paraphrases: List[str] = []
for line in text.splitlines():
candidate = line.strip().strip("-•").strip()
# Strip numbering or bullet prefixes that occasionally slip through.
while candidate and (candidate[0].isdigit() or candidate[0] in {'.', ')'}):
candidate = candidate[1:].lstrip()
candidate = candidate.strip()
if candidate:
paraphrases.append(candidate)
return paraphrases
def build_prompt(question: str, query: str, num_requested: int) -> str:
prompt = SYSTEM_PREFIX + "\n\n"
prompt += "Original question:\n"
prompt += question.strip() + "\n\n"
prompt += "SQL query that answers it:\n"
prompt += "```sql\n" + query.strip() + "\n```\n\n"
prompt += (
f"Task: Provide {num_requested} alternative phrasings of the original question. "
"Each paraphrase must stay fully faithful to the SQL query's intent, maintain the same requested outputs, "
"and avoid adding new requirements. Use varied wording and sentence structure.\n\n"
"Return format:\n"
"```plaintext\nParaphrase 1\nParaphrase 2\n...\n```\n"
"No commentary, no explanations, and no references to the paraphrasing process."
)
return prompt
def build_retry_prompt(base_prompt: str, attempt: int) -> str:
return base_prompt + (
"\n\n# FORMAT REMINDER (retry {n})\n"
"Respond with exactly one fenced block labeled ```plaintext``` containing only the paraphrases, "
"one per line. Remove bullets, numbering, or extra commentary."
).format(n=attempt)
def make_augmented_record(record: Dict[str, Any], paraphrase: str, model_name: str,
timestamp: str) -> Dict[str, Any]:
augmented = copy.deepcopy(record)
original_question = record.get("question", "")
augmented["question"] = paraphrase
token_count = augmented.get("token_count") or {}
# Reuse existing query/schema token counts if present; otherwise recompute.
query_tokens = token_count.get("query") if isinstance(token_count, dict) else None
schema_tokens = token_count.get("schema") if isinstance(token_count, dict) else None
if not isinstance(token_count, dict):
token_count = {}
token_count["question"] = count_tokens(paraphrase)
token_count["query"] = query_tokens if query_tokens is not None else count_tokens(record.get("query", ""))
token_count["schema"] = schema_tokens if schema_tokens is not None else count_tokens(record.get("schema", ""))
token_count["total"] = token_count["question"] + token_count["query"] + token_count["schema"]
augmented["token_count"] = token_count
augmented["augmentation_type"] = "question_paraphrase"
augmented["augmentation_of"] = original_question
augmented["paraphrase_model"] = model_name
augmented["paraphrased_at"] = timestamp
return augmented
def process_batch(
llm: LLM,
sampling: SamplingParams,
batch_entries: List[Tuple[Dict[str, Any], str]],
num_alternatives: int,
retries: int,
model_name: str,
out_stream,
skip_duplicate: bool = True,
) -> int:
if not batch_entries or num_alternatives <= 0:
return 0
base_prompts = [entry[1] for entry in batch_entries]
pending = list(range(len(batch_entries)))
paraphrase_map: Dict[int, List[str]] = {idx: [] for idx in range(len(batch_entries))}
last_raw: Dict[int, str] = {}
for attempt in range(retries + 1):
if not pending:
break
prompts_for_attempt = []
for idx in pending:
prompt = base_prompts[idx]
if attempt > 0:
prompt = build_retry_prompt(prompt, attempt)
prompts_for_attempt.append(prompt)
try:
outputs = llm.generate(prompts_for_attempt, sampling_params=sampling, use_tqdm=False)
except Exception as exc: # pragma: no cover - external engine failure
print(f"[WARN] vLLM generation failed on attempt {attempt}: {exc}", file=sys.stderr)
continue
next_pending: List[int] = []
for local_idx, global_idx in enumerate(pending):
out = outputs[local_idx]
raw = out.outputs[0].text if out.outputs else ""
last_raw[global_idx] = raw
paraphrases = parse_plaintext_block(raw)
# Remove items that contain 'Paraphrase'
paraphrases = [p for p in paraphrases if 'paraphrase' not in p.lower()]
paraphrase_map[global_idx] = paraphrases
if not paraphrases:
next_pending.append(global_idx)
pending = next_pending
for idx in pending:
entry = batch_entries[idx][0]
original_question = entry.get("question", "")
print(
f"[WARN] Failed to parse paraphrases after retries for question: {original_question}",
file=sys.stderr,
)
written = 0
timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
for idx, (record, _) in enumerate(batch_entries):
paraphrases = paraphrase_map.get(idx, [])
if not paraphrases:
continue
original_question = record.get("question", "")
seen = {normalize(original_question)} if skip_duplicate else set()
kept: List[str] = []
for candidate in paraphrases:
cleaned = candidate.strip()
if not cleaned:
continue
norm = normalize(cleaned)
if skip_duplicate and norm in seen:
continue
kept.append(cleaned)
seen.add(norm)
if len(kept) >= num_alternatives:
break
for paraphrase in kept:
augmented = make_augmented_record(record, paraphrase, model_name, timestamp)
out_stream.write(json.dumps(augmented, ensure_ascii=False) + "\n")
written += 1
return written
def stream_records(input_spec: str, start: int = 0, end: Optional[int] = None):
"""Stream records either from a local JSONL file or a Hugging Face dataset.
Usage:
- Local JSONL: pass a filesystem path as before.
- Hugging Face: pass a dataset id like 'cwolff/small-text-to-sql' or prefix with 'hf://'.
"""
# Heuristics: explicit hf:// or hf: prefix or a slash with no local path -> treat as HF dataset id.
is_hf = False
dataset_id = None
if isinstance(input_spec, str) and (input_spec.startswith("hf://") or input_spec.startswith("hf:")):
dataset_id = input_spec.split("//", 1)[-1] if input_spec.startswith("hf://") else input_spec.split(":", 1)[1]
is_hf = True
elif isinstance(input_spec, str) and "/" in input_spec and not os.path.exists(input_spec):
# looks like an HF repo id and not a local path
dataset_id = input_spec
is_hf = True
if is_hf:
try:
ds = load_dataset(dataset_id)
except Exception as exc:
print(f"[ERROR] Failed to load Hugging Face dataset '{dataset_id}': {exc}", file=sys.stderr)
raise
# Choose a split: prefer 'train' if present, otherwise the first available split.
if isinstance(ds, dict):
split_name = "train" if "train" in ds else list(ds.keys())[0]
dataset = ds[split_name]
else:
dataset = ds
for idx, item in enumerate(dataset):
if idx < start:
continue
if end is not None and idx >= end:
break
# Items from datasets are typically already dict-like
yield dict(item)
return
# Fallback to local JSONL reading for existing behavior
with open(input_spec, "r", encoding="utf-8") as handle:
for idx, line in enumerate(handle):
if idx < start:
continue
if end is not None and idx >= end:
break
line = line.strip()
if not line:
continue
try:
yield json.loads(line)
except json.JSONDecodeError as exc:
print(f"[WARN] Skipping malformed JSON line {idx + 1}: {exc}", file=sys.stderr)
continue
def main() -> None:
args = parse_args()
if args.seed is not None:
random.seed(args.seed)
if os.path.abspath(args.input) == os.path.abspath(args.output):
print("[ERROR] Input and output paths must differ to avoid clobbering the source file.", file=sys.stderr)
sys.exit(1)
batch_size = min(max(1, args.batch_size), BATCH_CAP)
sampling = SamplingParams(
temperature=float(args.temperature),
max_tokens=int(args.max_tokens),
top_k=(args.top_k if args.top_k is not None else -1),
)
llm = LLM(
model=args.model,
max_model_len=int(args.max_model_len),
max_num_seqs=32,
tensor_parallel_size=int(args.tensor_parallel_size),
gpu_memory_utilization=0.85,
)
progress_iter = stream_records(args.input, start=args.start, end=args.end)
if tqdm and not args.no_progress:
# tqdm needs a known total; best-effort (may be None if limit enforced).
total = args.end - args.start if args.end is not None else None
progress_iter = tqdm(progress_iter, total=total, desc="Augmenting questions", unit="rec") # type: ignore
records_processed = 0
paraphrases_written = 0
with open(args.output, "w", encoding="utf-8") as fout:
batch: List[Tuple[Dict[str, Any], str]] = []
for record in progress_iter:
records_processed += 1
if not args.skip_original:
fout.write(json.dumps(record, ensure_ascii=False) + "\n")
question = record.get("question", "")
query = record.get("query", "")
if not question or not query:
print(
f"[WARN] Record {records_processed} missing question or query; skipping paraphrase.",
file=sys.stderr,
)
continue
prompt = build_prompt(question, query, args.num_alternatives)
batch.append((record, prompt))
if len(batch) >= batch_size:
paraphrases_written += process_batch(
llm=llm,
sampling=sampling,
batch_entries=batch,
num_alternatives=args.num_alternatives,
retries=args.retries,
model_name=args.model,
out_stream=fout,
)
batch = []
if batch:
paraphrases_written += process_batch(
llm=llm,
sampling=sampling,
batch_entries=batch,
num_alternatives=args.num_alternatives,
retries=args.retries,
model_name=args.model,
out_stream=fout,
)
if tqdm and not args.no_progress:
progress_iter.close() # type: ignore[attr-defined]
print(
f"Processed {records_processed} records and wrote {paraphrases_written} paraphrased variants to {args.output}"
)
if __name__ == "__main__":
main()