
Commit 1680717 ("update")
Signed-off-by: guangli.bao <[email protected]>
1 parent: d863ca8

1 file changed (+31 lines, -9 lines)

src/guidellm/utils/preprocessing_sharegpt_data.py: 31 additions, 9 deletions
@@ -3,22 +3,43 @@
 import os
 import re
 from pathlib import Path
+from typing import Callable, Optional

 import numpy as np
 from datasets import load_dataset
 from transformers import AutoTokenizer

 MIN_CHAR = 10
 MAX_CHAR = 1000
-_tokenizer = None


-def estimate_num_tokens(text: str) -> int:
-    _tokenizer: AutoTokenizer
-    if _tokenizer is None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-        _tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-    return len(_tokenizer(text, return_tensors=None))
+def create_token_estimator(
+    model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
+) -> Callable[[str], int]:
+    _tokenizer: Optional[AutoTokenizer] = None
+
+    def initialize() -> None:
+        nonlocal _tokenizer
+        if _tokenizer is None:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            try:
+                _tokenizer = AutoTokenizer.from_pretrained(model_name)
+            except (OSError, ImportError, ValueError) as e:
+                raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e
+
+    def estimate_num_tokens(text: str) -> int:
+        initialize()
+
+        if _tokenizer is None:
+            return 0
+
+        try:
+            encoding = _tokenizer(text, return_tensors=None)
+            return len(encoding["input_ids"])
+        except (AttributeError, TypeError, RuntimeError) as e:
+            raise ValueError(f"Error processing text: {e}") from e
+
+    return estimate_num_tokens


 def extract_and_save_with_filtering(file):
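
Usage of the new factory, sketched under a couple of assumptions: that the module is importable as guidellm.utils.preprocessing_sharegpt_data and that the default mistralai/Mistral-7B-Instruct-v0.2 tokenizer is cached locally or downloadable; the sample strings are illustrative only.

```python
# Build the estimator once and reuse the returned closure; the tokenizer is
# loaded lazily on the first call rather than at import time.
from guidellm.utils.preprocessing_sharegpt_data import create_token_estimator

estimate_tokens = create_token_estimator()  # default: mistralai/Mistral-7B-Instruct-v0.2

for text in ["Hello, how are you today?", "Summarize this conversation in one line."]:
    print(f"{len(text)} chars -> {estimate_tokens(text)} tokens")
```

Passing a different model_name yields an estimator bound to that tokenizer, and each call to create_token_estimator keeps its own cached tokenizer inside the closure instead of relying on a module-level global.
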
@@ -72,6 +93,7 @@ def extract_and_save_with_filtering(file):
     with Path(sharegpt_file).open("r", encoding="utf-8") as file:
         data = json.load(file)

+    estimate_tokens = create_token_estimator()
     num_of_ids = len(data)
     data = data[: int(num_of_ids * args.parse)]
     for d in data:
@@ -80,9 +102,9 @@ def extract_and_save_with_filtering(file):
         gpt_tokens = []
         for conv in d["conversations"]:
             if conv["from"] == "human":
-                human_tokens.append(estimate_num_tokens(conv["value"]))
+                human_tokens.append(estimate_tokens(conv["value"]))
             if conv["from"] == "gpt":
-                token_number = estimate_num_tokens(conv["value"])
+                token_number = estimate_tokens(conv["value"])
                 conv["num_tokens"] = token_number
                 gpt_tokens.append(token_number)
             if len(human_tokens) == 0:
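
The key change is replacing the module-level _tokenizer global with a closure that owns its tokenizer and initializes it only on first use. A dependency-free sketch of the same pattern, with a purely hypothetical FakeTokenizer standing in for AutoTokenizer so it runs without downloading a model:

```python
from typing import Callable, Optional


class FakeTokenizer:
    """Illustrative stand-in for AutoTokenizer: whitespace 'tokenization' only."""

    def __call__(self, text: str, return_tensors=None) -> dict:
        return {"input_ids": text.split()}


def create_counter() -> Callable[[str], int]:
    _tok: Optional[FakeTokenizer] = None  # captured by the closure, not a module global

    def count(text: str) -> int:
        nonlocal _tok
        if _tok is None:  # initialized once, on the first call
            _tok = FakeTokenizer()
        return len(_tok(text)["input_ids"])

    return count


count = create_counter()
print(count("hello world"))         # 2
print(count("one two three four"))  # 4, reusing the same FakeTokenizer instance
```
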
