|
12 | 12 | "pack_single=1,pack_complete=0,pack_buffer_size=50", |
13 | 13 | "ds_name": "RefactFIMCodeDataset" |
14 | 14 | } |
15 | | -_bigcode_tokenizer_mapping = { |
16 | | - "eot_idx": 0, |
17 | | - "padding_idx": 4, |
18 | | - "fim_prefix": 1, |
19 | | - "fim_middle": 2, |
20 | | - "fim_suffix": 3, |
21 | | - "escape": 14 |
22 | | -} |
23 | | -_starcoder_base = { |
24 | | - "lora_target_modules_mapping": { |
25 | | - "qkv": ["attn.q_attn", "attn.c_attn"], |
26 | | - "out": ["attn.c_proj"], |
27 | | - "backproj": ["attn.c_proj"], |
28 | | - "mlp": ["mlp.c_fc", "mlp.c_proj"], |
29 | | - }, |
30 | | - "freeze_exceptions_mapping": { |
31 | | - "wte": ["wte", "wpe"], |
32 | | - "lm_head": ["lm_head"], |
33 | | - "lora": ["lora"] |
34 | | - }, |
35 | | - "tokenizer": _bigcode_tokenizer_mapping, |
36 | | - "train_ds_pipeline": _fim_train_ds_pipeline, |
37 | | - "test_ds_pipeline": _fim_test_ds_pipeline, |
38 | | - "train_model_modifiers": [ |
39 | | - "flash_sa.apply_flash_mha_to_starcoder_model" |
40 | | - ], |
41 | | - "force_enable_checkpointing": False |
42 | | -} |
43 | | -_starcoder2_base = { |
44 | | - "lora_target_modules_mapping": { |
45 | | - "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], |
46 | | - "out": ["self_attn.o_proj"], |
47 | | - "backproj": ["self_attn.o_proj"], |
48 | | - "mlp": ["mlp.c_fc", "mlp.c_proj"], |
49 | | - }, |
50 | | - "freeze_exceptions_mapping": { |
51 | | - "wte": ["embed_tokens"], |
52 | | - "lm_head": ["lm_head"], |
53 | | - "lora": ["lora"] |
54 | | - }, |
55 | | - "tokenizer": _bigcode_tokenizer_mapping, |
56 | | - "train_ds_pipeline": _fim_train_ds_pipeline, |
57 | | - "test_ds_pipeline": _fim_test_ds_pipeline, |
58 | | - "train_model_modifiers": [ |
59 | | - "flash_sa.apply_flash_mha_to_starcoder2_model" |
60 | | - ], |
61 | | - "force_enable_checkpointing": True |
62 | | -} |
63 | | -_deepseek_base = { |
64 | | - "lora_target_modules_mapping": { |
65 | | - "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], |
66 | | - "out": ["self_attn.o_proj"], |
67 | | - "backproj": ["self_attn.o_proj"], |
68 | | - "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"], |
69 | | - }, |
70 | | - "freeze_exceptions_mapping": { |
71 | | - "wte": ["embed_tokens"], |
72 | | - "lm_head": ["lm_head"], |
73 | | - "lora": ["lora"] |
74 | | - }, |
75 | | - "tokenizer": { |
76 | | - "eot_idx": 32021, # `<|EOT|>` |
77 | | - "padding_idx": 32018, # `<pad>` |
78 | | - "fim_prefix": 32016, # `<|fim▁begin|>` |
79 | | - "fim_middle": 32017, # `<|fim▁end|>` |
80 | | - "fim_suffix": 32015, # `<|fim▁hole|>` |
81 | | - "escape": 32013, # using `<|begin▁of▁sentence|>` token for now |
82 | | - }, |
83 | | - "train_ds_pipeline": { |
84 | | - "ds_opts": f"{_fim_train_ds_pipeline['ds_opts']},spm_prob=0.0", |
85 | | - "ds_name": _fim_train_ds_pipeline["ds_name"] |
86 | | - }, |
87 | | - "test_ds_pipeline": _fim_test_ds_pipeline, |
88 | | - "train_model_modifiers": [ |
89 | | - "flash_sa.apply_flash_mha_to_codellama_model" |
90 | | - ], |
91 | | - "force_enable_checkpointing": False |
92 | | -} |
| 15 | + |
93 | 16 | _qwen_base = { |
94 | 17 | "lora_target_modules_mapping": { |
95 | 18 | "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], |
|
122 | 45 | } |
123 | 46 |
|
124 | 47 | config = { |
125 | | - "Refact/1.6B": { |
126 | | - "lora_target_modules_mapping": { |
127 | | - "qkv": ["attn.q", "attn.kv"], |
128 | | - "out": ["attn.c_proj"], |
129 | | - "backproj": ["attn.c_proj"], |
130 | | - "mlp": ["mlp.gate_up_proj", "mlp.c_proj"], |
131 | | - }, |
132 | | - "freeze_exceptions_mapping": { |
133 | | - "wte": ["wte"], |
134 | | - "lm_head": ["lm_head"], |
135 | | - "lora": ["lora"] |
136 | | - }, |
137 | | - "tokenizer": _bigcode_tokenizer_mapping, |
138 | | - "train_ds_pipeline": _fim_train_ds_pipeline, |
139 | | - "test_ds_pipeline": _fim_test_ds_pipeline, |
140 | | - "train_model_modifiers": [ |
141 | | - "flash_sa.apply_flash_mha_to_refact_model" |
142 | | - ], |
143 | | - "force_enable_checkpointing": False |
144 | | - }, |
145 | | - |
146 | | - "starcoder/1b/base": _starcoder_base, |
147 | | - |
148 | | - "starcoder/3b/base": _starcoder_base, |
149 | | - |
150 | | - "starcoder/7b/base": { |
151 | | - **_starcoder_base, |
152 | | - "force_enable_checkpointing": True |
153 | | - }, |
154 | | - |
155 | | - "starcoder2/3b/base": _starcoder2_base, |
156 | | - |
157 | | - "starcoder2/7b/base": { |
158 | | - **_starcoder2_base, |
159 | | - "force_enable_checkpointing": True |
160 | | - }, |
161 | | - |
162 | | - "starcoder2/15b/base": { |
163 | | - **_starcoder2_base, |
164 | | - "force_enable_checkpointing": True |
165 | | - }, |
166 | | - |
167 | | - "codellama/7b": { |
168 | | - "lora_target_modules_mapping": { |
169 | | - "qkv": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], |
170 | | - "out": ["self_attn.o_proj"], |
171 | | - "backproj": ["self_attn.o_proj"], |
172 | | - "mlp": ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"], |
173 | | - }, |
174 | | - "freeze_exceptions_mapping": { |
175 | | - "wte": ["embed_tokens"], |
176 | | - "lm_head": ["lm_head"], |
177 | | - "lora": ["lora"] |
178 | | - }, |
179 | | - "tokenizer": { |
180 | | - "eot_idx": 32010, |
181 | | - "padding_idx": 2, # there is no padding token, so instead using `eos` token as in `gpt2` |
182 | | - "fim_prefix": 32007, |
183 | | - "fim_middle": 32009, |
184 | | - "fim_suffix": 32008, |
185 | | - "escape": 0, # using <unk> token |
186 | | - "bos_idx": 1 |
187 | | - }, |
188 | | - "train_ds_pipeline": { |
189 | | - **_fim_train_ds_pipeline, |
190 | | - "ds_name": "CodeLLamaFIMDataset" |
191 | | - }, |
192 | | - "test_ds_pipeline": { |
193 | | - **_fim_test_ds_pipeline, |
194 | | - "ds_name": "CodeLLamaFIMDataset" |
195 | | - }, |
196 | | - "train_model_modifiers": [ |
197 | | - "flash_sa.apply_flash_mha_to_codellama_model" |
198 | | - ], |
199 | | - "force_enable_checkpointing": True |
200 | | - }, |
201 | | - |
202 | | - "deepseek-coder/1.3b/base": _deepseek_base, |
203 | | - |
204 | | - "deepseek-coder/5.7b/mqa-base": { |
205 | | - **_deepseek_base, |
206 | | - "force_enable_checkpointing": True |
207 | | - }, |
208 | | - |
209 | | - "deepseek-coder/6.7b/base": { |
210 | | - **_deepseek_base, |
211 | | - "force_enable_checkpointing": True |
212 | | - }, |
213 | | - |
214 | 48 | # qwen models |
215 | 49 | "qwen2.5/coder/32b/base": { |
216 | 50 | **_qwen_base, |
|
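For context, a minimal sketch (plain Python, names hypothetical) of the pattern both the removed entries and the surviving `_qwen_base` entries rely on: a family-level base dict is registered as-is for the smallest checkpoint and re-spread with `**` for larger ones, overriding only `force_enable_checkpointing`:

    # Hypothetical stand-in for a family base such as _qwen_base or the
    # removed _starcoder_base; only the keys needed for the sketch are kept.
    _family_base = {
        "train_model_modifiers": [],
        "force_enable_checkpointing": False,
    }

    config = {
        "family/small/base": _family_base,  # small model: base dict used directly
        "family/large/base": {
            **_family_base,                          # copy every key of the base...
            "force_enable_checkpointing": True,      # ...then override just this one
        },
    }

    assert config["family/small/base"]["force_enable_checkpointing"] is False
    assert config["family/large/base"]["force_enable_checkpointing"] is True

Note that `**` makes a shallow copy: nested dicts such as `lora_target_modules_mapping` stay shared between entries, which is harmless here because the config is only read, never mutated.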