Skip to content

Commit 9592e1f

Browse files
committed
fix: uint16→uint32 dtype detection for large-vocab tokenizers
Root cause: the Qwen tokenizer (vocab size ~151K > 65535) requires uint32 token storage, but the data loaders only switched dtype when tokenizer_type == 'tiktoken'. All data loaders now also check vocab_size > 65535 from meta.pkl.
1 parent 36ab1af commit 9592e1f

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

supergpt/training/distill.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,20 @@ def distill(
159159
print(f" Student: {student_params/1e6:.1f}M params "
160160
f"({teacher_params/student_params:.1f}× compression)")
161161

162-
# Load data
162+
# Load data — auto-detect dtype from meta.pkl
163163
block_size = teacher_config.block_size
164+
data_dtype = np.uint16 # default
165+
meta_path = os.path.join(data_dir, "meta.pkl")
166+
if os.path.exists(meta_path):
167+
import pickle
168+
with open(meta_path, "rb") as f:
169+
meta = pickle.load(f)
170+
if meta.get("vocab_size", 0) > 65535 or meta.get("tokenizer_type") == "tiktoken":
171+
data_dtype = np.uint32
164172
train_data = np.memmap(os.path.join(data_dir, "train.bin"),
165-
dtype=np.uint16, mode="r")
173+
dtype=data_dtype, mode="r")
166174
val_data = np.memmap(os.path.join(data_dir, "val.bin"),
167-
dtype=np.uint16, mode="r")
175+
dtype=data_dtype, mode="r")
168176

169177
def get_batch(split):
170178
data = train_data if split == "train" else val_data

supergpt/training/finetune.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def load_data(data_dir: str, split: str, block_size: int, batch_size: int, devic
4141
if os.path.exists(meta_path):
4242
with open(meta_path, "rb") as f:
4343
meta = pickle.load(f)
44-
if meta.get("tokenizer_type") == "tiktoken":
44+
vocab_size = meta.get("vocab_size", 0)
45+
if vocab_size > 65535 or meta.get("tokenizer_type") == "tiktoken":
4546
dtype = np.uint32
4647
else:
4748
dtype = np.uint16

supergpt/training/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,13 @@ def load_data(data_dir: str, split: str, block_size: int, batch_size: int, devic
343343
"""Load a batch of data from the memory-mapped binary file."""
344344
data_path = os.path.join(data_dir, f"{split}.bin")
345345

346-
# Detect dtype from meta.pkl
346+
# Detect dtype from meta.pkl — use uint32 if vocab > 65535
347347
meta_path = os.path.join(data_dir, "meta.pkl")
348348
if os.path.exists(meta_path):
349349
with open(meta_path, "rb") as f:
350350
meta = pickle.load(f)
351-
if meta.get("tokenizer_type") == "tiktoken":
351+
vocab_size = meta.get("vocab_size", 0)
352+
if vocab_size > 65535 or meta.get("tokenizer_type") == "tiktoken":
352353
dtype = np.uint32
353354
else:
354355
dtype = np.uint16

0 commit comments

Comments (0)