-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_pipeline.py
More file actions
134 lines (119 loc) · 5.68 KB
/
run_pipeline.py
File metadata and controls
134 lines (119 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
# Make the project-local packages (collectors/, processors/, utils/)
# importable when this script is run directly from any working directory.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))
from utils.io_utils import load_config, ensure_dir, list_files, merge_jsonl_files, save_json
class Pipeline:
    """Orchestrates the ML-SLM dataset pipeline.

    Stages: per-source collection -> merge -> filtering/dedup/formatting
    -> train/validation finalization. Paths come from the YAML config.
    """

    def __init__(self, config_path: str = "configs/config.yaml"):
        """Load the config and ensure all data directories exist.

        Args:
            config_path: Path to the pipeline YAML configuration.
        """
        self.config = load_config(config_path)
        self.config_path = config_path
        self.raw_dir = ensure_dir(self.config['paths']['raw_data'])
        self.processed_dir = ensure_dir(self.config['paths']['processed_data'])
        self.final_dir = ensure_dir(self.config['paths']['final_data'])

    @staticmethod
    def _banner(title: str, width: int = 50) -> None:
        """Print a section banner identical to the original '='-bar headers."""
        print("\n" + "=" * width + f"\n{title}\n" + "=" * width)

    def _run_collector(self, title: str, collector_cls, create_pairs: bool = True) -> None:
        """Run one collector end-to-end (shared body of all collect_* methods).

        Args:
            title: Source name for the banner (e.g. "ARXIV").
            collector_cls: Collector class taking a config path.
            create_pairs: Whether to also build instruction pairs
                (HuggingFace datasets arrive pre-paired, so it skips this).
        """
        self._banner(f"COLLECTING FROM {title}")
        collector = collector_cls(self.config_path)
        collector.collect_all()
        if create_pairs:
            collector.create_instruction_pairs()

    # Collector imports stay lazy (inside each method) so a missing optional
    # dependency for one source does not break the others.
    def collect_arxiv(self):
        """Collect papers from arXiv and build instruction pairs."""
        from collectors.arxiv_collector import ArxivCollector
        self._run_collector("ARXIV", ArxivCollector)

    def collect_stackexchange(self):
        """Collect Q&A from StackExchange and build instruction pairs."""
        from collectors.stackexchange_collector import StackExchangeCollector
        self._run_collector("STACKEXCHANGE", StackExchangeCollector)

    def collect_wikipedia(self):
        """Collect articles from Wikipedia and build instruction pairs."""
        from collectors.wikipedia_collector import WikipediaCollector
        self._run_collector("WIKIPEDIA", WikipediaCollector)

    def collect_distill(self):
        """Collect articles from distill.pub and build instruction pairs."""
        from collectors.distill_collector import DistillCollector
        self._run_collector("DISTILL.PUB", DistillCollector)

    def collect_huggingface(self):
        """Collect pre-built datasets from HuggingFace (already paired)."""
        from collectors.huggingface_datasets_collector import HuggingFaceCollector
        self._run_collector("HUGGINGFACE", HuggingFaceCollector, create_pairs=False)

    def collect_all(self):
        """Run every collector in sequence."""
        self.collect_arxiv()
        self.collect_stackexchange()
        self.collect_wikipedia()
        self.collect_distill()
        self.collect_huggingface()

    def merge_collected_data(self) -> str:
        """Merge all per-source JSONL files into one file.

        Returns:
            Path to the merged JSONL file under the processed directory.
        """
        self._banner("MERGING DATA")
        files = []
        for sub in ['arxiv', 'stackexchange', 'wikipedia', 'distill', 'huggingface']:
            path = f"{self.raw_dir}/{sub}"
            if os.path.exists(path):
                # 'progress' files are checkpoints, not data.
                files.extend([f for f in list_files(path, '.jsonl') if 'progress' not in f])
        merged = f"{self.processed_dir}/merged.jsonl"
        # sorted() (was list(set(...))) makes the merge order deterministic
        # across runs while still deduplicating paths.
        merge_jsonl_files(sorted(set(files)), merged)
        return merged

    def process_all(self) -> str:
        """Run the full processing chain: code filter -> dedup -> quality -> format.

        Returns:
            Path to the final formatted JSONL file.
        """
        from processors.processors import CodeFilter, Deduplicator, QualityFilter, InstructionFormatter
        merged = self.merge_collected_data()
        self._banner("CODE FILTERING")
        cf_out = f"{self.processed_dir}/code_filtered.jsonl"
        CodeFilter(self.config_path).filter_file(merged, cf_out)
        self._banner("DEDUPLICATING")
        dd_out = f"{self.processed_dir}/deduped.jsonl"
        Deduplicator(self.config_path).deduplicate_file(cf_out, dd_out)
        self._banner("QUALITY FILTERING")
        qf_out = f"{self.processed_dir}/quality.jsonl"
        QualityFilter(self.config_path).filter_file(dd_out, qf_out)
        self._banner("FORMATTING")
        fmt_out = f"{self.processed_dir}/formatted.jsonl"
        InstructionFormatter(self.config_path).process_file(qf_out, fmt_out)
        return fmt_out

    def finalize(self, input_file: Optional[str] = None) -> None:
        """Split the formatted data into train/validation and write metadata.

        Args:
            input_file: Formatted JSONL to finalize; defaults to the
                standard processed-dir location.
        """
        self._banner("FINALIZING")
        input_file = input_file or f"{self.processed_dir}/formatted.jsonl"
        from processors.processors import DatasetFinalizer
        stats = DatasetFinalizer(self.config_path).finalize(input_file, str(self.final_dir))
        save_json({'created': datetime.now().isoformat(), **stats}, f"{self.final_dir}/metadata.json")
        print(f"\nDataset ready at {self.final_dir}")
        print(f"Train: {stats['train']}, Validation: {stats['validation']}")

    def run_full(self, skip_collection=False):
        """Run the end-to-end pipeline (optionally skipping collection)."""
        self._banner("ML-SLM DATASET PIPELINE", width=60)
        if not skip_collection:
            self.collect_all()
        formatted = self.process_all()
        self.finalize(formatted)
def main():
    """Command-line entry point for the ML-SLM dataset pipeline.

    Supports `--full` for the whole pipeline, or one of the positional
    commands `collect` (with `--all` or `--source NAME`), `process`,
    `finalize`. Exits with status 1 on invalid collect arguments.
    """
    parser = argparse.ArgumentParser(description='ML-SLM Dataset Pipeline')
    parser.add_argument('--config', default='configs/config.yaml')
    parser.add_argument('--full', action='store_true', help='Run full pipeline')
    parser.add_argument('--skip-collection', action='store_true')
    parser.add_argument('command', nargs='?', choices=['collect', 'process', 'finalize'])
    parser.add_argument('--source', type=str)
    parser.add_argument('--all', action='store_true')
    args = parser.parse_args()
    p = Pipeline(args.config)
    if args.full:
        p.run_full(args.skip_collection)
    elif args.command == 'collect':
        valid_sources = ['arxiv', 'stackexchange', 'wikipedia', 'distill', 'huggingface']
        if args.all:
            p.collect_all()
        elif args.source:
            if args.source not in valid_sources:
                print(f"Unknown source: {args.source}")
                print(f"Valid sources: {', '.join(valid_sources)}")
                sys.exit(1)
            # Dispatch to the matching Pipeline.collect_<source>() method.
            getattr(p, f'collect_{args.source}')()
        else:
            # Bug fix: previously 'collect' with neither --all nor --source
            # fell through and silently did nothing.
            print("'collect' requires --all or --source NAME")
            print(f"Valid sources: {', '.join(valid_sources)}")
            sys.exit(1)
    elif args.command == 'process':
        p.process_all()
    elif args.command == 'finalize':
        p.finalize()
    else:
        parser.print_help()
if __name__ == '__main__':
    main()