Skip to content

Commit 25fda45

Browse files
committed
housekeeping: various formatting tasks
1 parent afbbf4c commit 25fda45

File tree

7 files changed

+241
-132
lines changed

7 files changed

+241
-132
lines changed

.vscode/settings.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
{
22
"[python]": {
3-
"editor.defaultFormatter": "ms-python.autopep8"
3+
"editor.defaultFormatter": "ms-python.python",
4+
"editor.formatOnSave": true,
5+
"editor.formatOnSaveMode": "modifications",
6+
"editor.formatOnType": true,
7+
"editor.formatOnPaste": true,
8+
"editor.codeActionsOnSave": {
9+
"source.organizeImports": false
10+
}
411
}
512
}

app/models/improved_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def __init__(self, agent_id: int, num_tasks: int, num_features: int):
3838
self.policy_net: DQN = DQN(num_features).to(self.device)
3939
self.target_net.load_state_dict(self.policy_net.state_dict())
4040
self.target_net.eval()
41-
self.optimizer = optim.RMSprop(self.policy_net.parameters())
42-
self.memory = ReplayMemory(10000)
41+
self.optimizer: optim.RMSprop = optim.RMSprop(self.policy_net.parameters())
42+
self.memory: ReplayMemory = ReplayMemory(10000)
4343
self.previous_loss: List[float] = [0] * num_tasks
4444

4545
def observe(self, task: int) -> Tuple[float, float]:

app/models/llama/attention.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,16 @@ class Attention:
3333
wo (Linear): Linear layer for output projection.
3434
3535
Methods:
36-
__call__(self, x: Tensor, cache_k: Optional[Tensor], cache_v: Optional[Tensor], start_pos: int, freqs_cis: Tensor, mask: Optional[Tensor], jit_ctx: Optional[Dict[Variable, int]] = None) -> Tuple[Tensor, Tensor, Tensor]:
36+
__call__(
37+
self,
38+
x: Tensor,
39+
cache_k: Optional[Tensor],
40+
cache_v: Optional[Tensor],
41+
start_pos: int,
42+
freqs_cis: Tensor,
43+
mask: Optional[Tensor],
44+
jit_ctx: Optional[Dict[Variable, int]] = None
45+
) -> Tuple[Tensor, Tensor, Tensor]:
3746
Apply multi-head attention to the input sequence `x`.
3847
"""
3948

@@ -48,7 +57,16 @@ def __init__(self, dim, n_heads, n_kv_heads, linear=Linear):
4857
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
4958
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
5059

51-
def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor], start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor], jit_ctx:Optional[Dict[Variable,int]]=None) -> Tuple[Tensor, Tensor, Tensor]:
60+
def __call__(
61+
self,
62+
x: Tensor,
63+
cache_k: Optional[Tensor],
64+
cache_v: Optional[Tensor],
65+
start_pos: int,
66+
freqs_cis: Tensor,
67+
mask: Optional[Tensor],
68+
jit_ctx: Optional[Dict[Variable, int]] = None
69+
) -> Tuple[Tensor, Tensor, Tensor]:
5270
"""
5371
Apply multi-head attention to the input sequence `x`.
5472
@@ -85,5 +103,8 @@ def __call__(self, x:Tensor, cache_k:Optional[Tensor], cache_v:Optional[Tensor],
85103

86104
cache_k, cache_v = keys, values
87105
keys, values = repeat_kv(keys, self.n_rep).realize(), repeat_kv(values, self.n_rep).realize()
88-
attn = Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask).transpose(1, 2).reshape(bsz, seqlen, -1)
106+
attn = (Tensor.scaled_dot_product_attention(xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), mask)
107+
.transpose(1, 2)
108+
.reshape(bsz, seqlen, -1))
109+
89110
return self.wo(attn).realize(), cache_k.realize(), cache_v.realize()

app/models/llama/constants.py

Lines changed: 151 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,71 +5,206 @@
55
MODEL_PARAMS = {
66
"1": {
77
"7B": {
8-
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000},
8+
"args": {
9+
"dim": 4096,
10+
"multiple_of": 256,
11+
"n_heads": 32,
12+
"n_layers": 32,
13+
"norm_eps": 1e-06,
14+
"vocab_size": 32000
15+
},
916
"files": 1,
1017
},
1118
"13B": {
12-
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000},
19+
"args": {
20+
"dim": 5120,
21+
"multiple_of": 256,
22+
"n_heads": 40,
23+
"n_layers": 40,
24+
"norm_eps": 1e-06,
25+
"vocab_size": 32000
26+
},
1327
"files": 2,
1428
},
1529
"30B": {
16-
"args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000},
30+
"args": {
31+
"dim": 6656,
32+
"multiple_of": 256,
33+
"n_heads": 52,
34+
"n_layers": 60,
35+
"norm_eps": 1e-06,
36+
"vocab_size": 32000
37+
},
1738
"files": 4,
1839
},
1940
"65B": {
20-
"args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
41+
"args": {
42+
"dim": 8192,
43+
"multiple_of": 256,
44+
"n_heads": 64,
45+
"n_layers": 80,
46+
"norm_eps": 1e-05,
47+
"vocab_size": 32000
48+
},
2149
"files": 8,
2250
},
2351
},
2452
"2": {
2553
"7B": {
26-
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000},
54+
"args": {
55+
"dim": 4096,
56+
"multiple_of": 256,
57+
"n_heads": 32,
58+
"n_layers": 32,
59+
"norm_eps": 1e-05,
60+
"vocab_size": 32000
61+
},
2762
"files": 1,
2863
},
2964
"13B": {
30-
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000},
65+
"args": {
66+
"dim": 5120,
67+
"multiple_of": 256,
68+
"n_heads": 40,
69+
"n_layers": 40,
70+
"norm_eps": 1e-05,
71+
"vocab_size": 32000
72+
},
3173
"files": 2,
3274
},
3375
"70B": {
34-
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
76+
"args": {
77+
"dim": 8192,
78+
"multiple_of": 4096,
79+
"ffn_dim_multiplier": 1.3,
80+
"n_heads": 64,
81+
"n_kv_heads": 8,
82+
"n_layers": 80,
83+
"norm_eps": 1e-05,
84+
"vocab_size": 32000
85+
},
3586
"files": 8,
3687
},
3788
},
3889
"code": {
3990
"7B": {
40-
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
91+
"args": {
92+
"dim": 4096,
93+
"n_layers": 32,
94+
"n_heads": 32,
95+
"multiple_of": 256,
96+
"ffn_dim_multiplier": 1.0,
97+
"norm_eps": 1e-5,
98+
"rope_theta": 1000000,
99+
"vocab_size": 32016
100+
},
41101
"files": 1,
42102
},
43103
"7B-Python": {
44-
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
104+
"args": {
105+
"dim": 4096,
106+
"n_layers": 32,
107+
"n_heads": 32,
108+
"multiple_of": 256,
109+
"ffn_dim_multiplier": 1.0,
110+
"norm_eps": 1e-5,
111+
"rope_theta": 1000000,
112+
"vocab_size": 32000
113+
},
45114
"files": 1,
46115
},
47116
"7B-Instruct": {
48-
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
117+
"args": {
118+
"dim": 4096,
119+
"n_layers": 32,
120+
"n_heads": 32,
121+
"multiple_of": 256,
122+
"ffn_dim_multiplier": 1.0,
123+
"norm_eps": 1e-5,
124+
"rope_theta": 1000000,
125+
"vocab_size": 32016
126+
},
49127
"files": 1,
50128
},
51129
"13B": {
52-
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
130+
"args": {
131+
"dim": 5120,
132+
"n_layers": 40,
133+
"n_heads": 40,
134+
"multiple_of": 256,
135+
"ffn_dim_multiplier": 1.0,
136+
"norm_eps": 1e-5,
137+
"rope_theta": 1000000,
138+
"vocab_size": 32016
139+
},
53140
"files": 2,
54141
},
55142
"13B-Python": {
56-
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
143+
"args": {
144+
"dim": 5120,
145+
"n_layers": 40,
146+
"n_heads": 40,
147+
"multiple_of": 256,
148+
"ffn_dim_multiplier": 1.0,
149+
"norm_eps": 1e-5,
150+
"rope_theta": 1000000,
151+
"vocab_size": 32000
152+
},
57153
"files": 2,
58154
},
59155
"13B-Instruct": {
60-
"args": {"dim": 5120, "n_layers": 40, "n_headvocab_sizes": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
156+
"args": {
157+
"dim": 5120,
158+
"n_layers": 40,
159+
"n_headvocab_sizes": 40,
160+
"multiple_of": 256,
161+
"ffn_dim_multiplier": 1.0,
162+
"norm_eps": 1e-5,
163+
"rope_theta": 1000000,
164+
"vocab_size": 32000
165+
},
61166
"files": 2,
62167
},
63168
"34B": {
64-
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
169+
"args": {
170+
"dim": 8192,
171+
"n_layers": 48,
172+
"n_heads": 64,
173+
"n_kv_heads": 8,
174+
"multiple_of": 256,
175+
"ffn_dim_multiplier": 1.0,
176+
"norm_eps": 1e-5,
177+
"rope_theta": 1000000,
178+
"vocab_size": 32016
179+
},
65180
"files": 4,
66181
},
67182
"34B-Python": {
68-
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
183+
"args": {
184+
"dim": 8192,
185+
"n_layers": 48,
186+
"n_heads": 64,
187+
"n_kv_heads": 8,
188+
"multiple_of": 256,
189+
"ffn_dim_multiplier": 1.0,
190+
"norm_eps": 1e-5,
191+
"rope_theta": 1000000,
192+
"vocab_size": 32000
193+
},
69194
"files": 4,
70195
},
71196
"34B-Instruct": {
72-
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
197+
"args": {
198+
"dim": 8192,
199+
"n_layers": 48,
200+
"n_heads": 64,
201+
"n_kv_heads": 8,
202+
"multiple_of": 256,
203+
"ffn_dim_multiplier": 1.0,
204+
"norm_eps": 1e-5,
205+
"rope_theta": 1000000,
206+
"vocab_size": 32000
207+
},
73208
"files": 4,
74209
},
75210
}

0 commit comments

Comments
 (0)