Skip to content

Commit eb825c1

Browse files
authored
Fix #1474 - AssertionError:assert param_slice.shape == loaded_weight.shape (#1631)
1 parent 1b290ac commit eb825c1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vllm/model_executor/models/gpt_j.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def load_weights(self,
250250
if att_weight_name not in name:
251251
continue
252252
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
253-
shard_size = param.shape[1]
253+
shard_size = param.shape[0] // 3
254254
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
255255
(tp_rank + 1)]
256256
param_slice = param.data[shard_size * stride_id:shard_size *

0 commit comments

Comments
 (0)