Skip to content

Commit 9b93ef7

Browse files
committed
lint
Signed-off-by: bitliu <[email protected]>
1 parent f0fdec5 commit 9b93ef7

File tree

1 file changed

+73
-42
lines changed

1 file changed

+73
-42
lines changed

e2e/scripts/generate_test_data.py

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,23 @@ def generate_synthetic_pii_data(num_samples: int = 20) -> List[Dict]:
107107
}
108108

109109
pii_values = {
110-
110+
"EMAIL_ADDRESS": [
111+
112+
113+
114+
],
111115
"PHONE_NUMBER": ["555-123-4567", "(555) 987-6543", "555.246.8135"],
112116
"US_SSN": ["123-45-6789", "987-65-4321", "456-78-9012"],
113-
"CREDIT_CARD": ["4532-1234-5678-9010", "5425-2334-3010-9903", "3782-822463-10005"],
117+
"CREDIT_CARD": [
118+
"4532-1234-5678-9010",
119+
"5425-2334-3010-9903",
120+
"3782-822463-10005",
121+
],
114122
"PERSON": ["Jane Smith", "John Doe", "Michael Johnson"],
115-
"LOCATION": ["123 Main Street, Springfield, IL 62701", "456 Oak Avenue, Portland, OR 97201"],
123+
"LOCATION": [
124+
"123 Main Street, Springfield, IL 62701",
125+
"456 Oak Avenue, Portland, OR 97201",
126+
],
116127
"DATE_TIME": ["January 15, 1985", "March 22, 1990", "July 4, 1988"],
117128
"US_DRIVER_LICENSE": ["D1234567", "DL98765432", "A9876543"],
118129
"US_PASSPORT": ["AB1234567", "CD9876543", "EF5432109"],
@@ -131,19 +142,23 @@ def generate_synthetic_pii_data(num_samples: int = 20) -> List[Dict]:
131142
value = random.choice(pii_values[entity_type])
132143
text = template.format(value=value)
133144

134-
test_cases.append({
135-
"description": f"{entity_type} in text",
136-
"pii_type": entity_type,
137-
"question": text,
138-
"expected_blocked": True
139-
})
145+
test_cases.append(
146+
{
147+
"description": f"{entity_type} in text",
148+
"pii_type": entity_type,
149+
"question": text,
150+
"expected_blocked": True,
151+
}
152+
)
140153

141154
# Generate multi-PII samples (40%)
142155
num_multi = num_samples - num_single
143156
for i in range(num_multi):
144157
# Select 2-3 entity types
145158
num_entities = random.randint(2, 3)
146-
selected_types = random.sample(entity_types, min(num_entities, len(entity_types)))
159+
selected_types = random.sample(
160+
entity_types, min(num_entities, len(entity_types))
161+
)
147162

148163
# Build combined text
149164
parts = []
@@ -155,12 +170,14 @@ def generate_synthetic_pii_data(num_samples: int = 20) -> List[Dict]:
155170
text = " ".join(parts)
156171
primary_type = selected_types[0]
157172

158-
test_cases.append({
159-
"description": f"Multiple PII types: {', '.join(selected_types)}",
160-
"pii_type": primary_type,
161-
"question": text,
162-
"expected_blocked": True
163-
})
173+
test_cases.append(
174+
{
175+
"description": f"Multiple PII types: {', '.join(selected_types)}",
176+
"pii_type": primary_type,
177+
"question": text,
178+
"expected_blocked": True,
179+
}
180+
)
164181

165182
random.shuffle(test_cases)
166183
return test_cases
@@ -246,7 +263,9 @@ def generate_pii_test_data(num_samples: int = 20) -> List[Dict]:
246263
if has_required_type:
247264
english_samples.append(sample)
248265

249-
print(f"✅ Loaded {len(english_samples)} English samples with required PII types")
266+
print(
267+
f"✅ Loaded {len(english_samples)} English samples with required PII types"
268+
)
250269

251270
# Separate by number of PII types (for diversity)
252271
single_pii_samples = []
@@ -274,10 +293,9 @@ def generate_pii_test_data(num_samples: int = 20) -> List[Dict]:
274293
num_single = int(num_samples * 0.6)
275294
num_multi = num_samples - num_single
276295

277-
selected_samples = (
278-
random.sample(single_pii_samples, min(num_single, len(single_pii_samples))) +
279-
random.sample(multi_pii_samples, min(num_multi, len(multi_pii_samples)))
280-
)
296+
selected_samples = random.sample(
297+
single_pii_samples, min(num_single, len(single_pii_samples))
298+
) + random.sample(multi_pii_samples, min(num_multi, len(multi_pii_samples)))
281299

282300
# Shuffle to mix single and multi PII samples
283301
random.shuffle(selected_samples)
@@ -296,7 +314,9 @@ def generate_pii_test_data(num_samples: int = 20) -> List[Dict]:
296314
if label in REQUIRED_PII_TYPES:
297315
mapped_type = REQUIRED_PII_TYPES[label]
298316
mapped_types.append(mapped_type)
299-
entity_type_counts[mapped_type] = entity_type_counts.get(mapped_type, 0) + 1
317+
entity_type_counts[mapped_type] = (
318+
entity_type_counts.get(mapped_type, 0) + 1
319+
)
300320

301321
# Get unique mapped types
302322
unique_types = sorted(set(mapped_types))
@@ -310,18 +330,21 @@ def generate_pii_test_data(num_samples: int = 20) -> List[Dict]:
310330
# Use the primary entity type (most frequent)
311331
primary_type = max(entity_type_counts, key=entity_type_counts.get)
312332

313-
test_cases.append({
314-
"description": description,
315-
"pii_type": primary_type,
316-
"question": text,
317-
"expected_blocked": True
318-
})
333+
test_cases.append(
334+
{
335+
"description": description,
336+
"pii_type": primary_type,
337+
"question": text,
338+
"expected_blocked": True,
339+
}
340+
)
319341

320342
print(f"✅ Generated {len(test_cases)} PII test cases from ai4privacy dataset")
321343

322344
# Show distribution of PII types
323345
from collections import Counter
324-
type_counts = Counter(case['pii_type'] for case in test_cases)
346+
347+
type_counts = Counter(case["pii_type"] for case in test_cases)
325348
print(f" PII type distribution:")
326349
for pii_type, count in sorted(type_counts.items()):
327350
print(f" {pii_type}: {count}")
@@ -334,7 +357,9 @@ def generate_pii_test_data(num_samples: int = 20) -> List[Dict]:
334357
return generate_synthetic_pii_data(num_samples)
335358

336359

337-
def generate_domain_classification_test_data(samples_per_category: int = 20) -> List[Dict]:
360+
def generate_domain_classification_test_data(
361+
samples_per_category: int = 20,
362+
) -> List[Dict]:
338363
"""Generate domain classification test data from MMLU-Pro dataset.
339364
340365
Args:
@@ -343,7 +368,9 @@ def generate_domain_classification_test_data(samples_per_category: int = 20) ->
343368
Returns:
344369
List of test cases with balanced distribution across categories
345370
"""
346-
print(f"\n📚 Generating {samples_per_category} samples per category for domain classification...")
371+
print(
372+
f"\n📚 Generating {samples_per_category} samples per category for domain classification..."
373+
)
347374

348375
# Load MMLU-Pro dataset
349376
print("📥 Loading MMLU-Pro dataset...")
@@ -355,6 +382,7 @@ def generate_domain_classification_test_data(samples_per_category: int = 20) ->
355382

356383
# Group samples by category
357384
from collections import defaultdict
385+
358386
category_samples = defaultdict(list)
359387
for question, category in zip(questions, categories):
360388
category_samples[category].append(question)
@@ -373,17 +401,16 @@ def generate_domain_classification_test_data(samples_per_category: int = 20) ->
373401
selected_questions = random.sample(available_questions, num_to_sample)
374402

375403
for question in selected_questions:
376-
test_cases.append({
377-
"category": category,
378-
"question": question
379-
})
404+
test_cases.append({"category": category, "question": question})
380405

381406
print(f" ✓ {category}: {num_to_sample} samples")
382407

383408
# Shuffle to mix categories
384409
random.shuffle(test_cases)
385410

386-
print(f"✅ Generated {len(test_cases)} domain classification test cases ({samples_per_category} per category)")
411+
print(
412+
f"✅ Generated {len(test_cases)} domain classification test cases ({samples_per_category} per category)"
413+
)
387414
return test_cases
388415

389416

@@ -409,7 +436,9 @@ def generate_jailbreak_test_data(num_samples: int = 20) -> List[Dict]:
409436
# 2. Load salad-data dataset
410437
print("📥 Loading salad-data dataset...")
411438
try:
412-
salad_dataset = load_dataset("OpenSafetyLab/Salad-Data", "attack_enhanced_set", split="train")
439+
salad_dataset = load_dataset(
440+
"OpenSafetyLab/Salad-Data", "attack_enhanced_set", split="train"
441+
)
413442
salad_count = 0
414443
for sample in salad_dataset:
415444
text = sample.get("attack", "")
@@ -456,11 +485,13 @@ def generate_jailbreak_test_data(num_samples: int = 20) -> List[Dict]:
456485

457486
test_cases = []
458487
for idx, text in enumerate(selected_samples, 1):
459-
test_cases.append({
460-
"description": f"Jailbreak attempt {idx}",
461-
"question": text,
462-
"expected_blocked": True
463-
})
488+
test_cases.append(
489+
{
490+
"description": f"Jailbreak attempt {idx}",
491+
"question": text,
492+
"expected_blocked": True,
493+
}
494+
)
464495

465496
print(f"✅ Generated {len(test_cases)} jailbreak test cases")
466497
return test_cases

0 commit comments

Comments
 (0)