Skip to content

Commit d26c6e5

Browse files
committed
Update convert.py to check for dependencies on startup
1 parent 73497c5 commit d26c6e5

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

llamacpp/convert.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,21 @@
1717
# and vocabulary.
1818
#
1919

20+
# Check if torch is installed and show an error and exit if not
2021
import sys
2122
import json
2223
import struct
23-
import numpy as np
24-
import torch
2524

26-
from sentencepiece import SentencePieceProcessor
25+
try:
26+
import torch
27+
import numpy as np
28+
from sentencepiece import SentencePieceProcessor
29+
except ImportError:
30+
print("Error: torch, sentencepiece and numpy are required to run this script.")
31+
print("Please install using the following command:")
32+
print(" pip install torch sentencepiece numpy")
33+
sys.exit(1)
34+
2735

2836
def main():
2937
if len(sys.argv) < 3:
@@ -35,7 +43,7 @@ def main():
3543
# output in the same directory as the model
3644
dir_model = sys.argv[1]
3745

38-
fname_hparams = sys.argv[1] + "/params.json"
46+
fname_hparams = sys.argv[1] + "/params.json"
3947
fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
4048

4149
def get_n_parts(dim):
@@ -76,35 +84,35 @@ def get_n_parts(dim):
7684
n_parts = get_n_parts(hparams["dim"])
7785

7886
print(hparams)
79-
print('n_parts = ', n_parts)
87+
print("n_parts = ", n_parts)
8088

8189
for p in range(n_parts):
82-
print('Processing part ', p)
90+
print("Processing part ", p)
8391

84-
#fname_model = sys.argv[1] + "/consolidated.00.pth"
92+
# fname_model = sys.argv[1] + "/consolidated.00.pth"
8593
fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
8694
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
87-
if (p > 0):
95+
if p > 0:
8896
fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
8997

9098
# weights_only requires torch 1.13.1, remove this param or update if you get an "invalid keyword argument" error
9199
model = torch.load(fname_model, map_location="cpu", weights_only=True)
92100

93101
fout = open(fname_out, "wb")
94102

95-
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
103+
fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex
96104
fout.write(struct.pack("i", hparams["vocab_size"]))
97105
fout.write(struct.pack("i", hparams["dim"]))
98106
fout.write(struct.pack("i", hparams["multiple_of"]))
99107
fout.write(struct.pack("i", hparams["n_heads"]))
100108
fout.write(struct.pack("i", hparams["n_layers"]))
101-
fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
109+
fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
102110
fout.write(struct.pack("i", ftype))
103111

104112
# Is this correct??
105113
for i in range(32000):
106114
# TODO: this is probably wrong - not sure how this tokenizer works
107-
text = tokenizer.decode([29889, i]).encode('utf-8')
115+
text = tokenizer.decode([29889, i]).encode("utf-8")
108116
# remove the first byte (it's always '.')
109117
text = text[1:]
110118
fout.write(struct.pack("i", len(text)))
@@ -120,16 +128,16 @@ def get_n_parts(dim):
120128

121129
print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
122130

123-
#data = tf.train.load_variable(dir_model, name).squeeze()
131+
# data = tf.train.load_variable(dir_model, name).squeeze()
124132
data = v.numpy().squeeze()
125-
n_dims = len(data.shape);
133+
n_dims = len(data.shape)
126134

127135
# for efficiency - transpose some matrices
128136
# "model/h.*/attn/c_attn/w"
129137
# "model/h.*/attn/c_proj/w"
130138
# "model/h.*/mlp/c_fc/w"
131139
# "model/h.*/mlp/c_proj/w"
132-
#if name[-14:] == "/attn/c_attn/w" or \
140+
# if name[-14:] == "/attn/c_attn/w" or \
133141
# name[-14:] == "/attn/c_proj/w" or \
134142
# name[-11:] == "/mlp/c_fc/w" or \
135143
# name[-13:] == "/mlp/c_proj/w":
@@ -146,11 +154,11 @@ def get_n_parts(dim):
146154
ftype_cur = 0
147155

148156
# header
149-
sname = name.encode('utf-8')
157+
sname = name.encode("utf-8")
150158
fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
151159
for i in range(n_dims):
152160
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
153-
fout.write(sname);
161+
fout.write(sname)
154162

155163
# data
156164
data.tofile(fout)
@@ -163,5 +171,6 @@ def get_n_parts(dim):
163171
print("Done. Output file: " + fname_out + ", (part ", p, ")")
164172
print("")
165173

166-
if __name__ == '__main__':
174+
175+
if __name__ == "__main__":
167176
main()

0 commit comments

Comments (0)