@@ -3,22 +3,43 @@
 import os
 import re
 from pathlib import Path
+from typing import Callable, Optional

 import numpy as np
 from datasets import load_dataset
 from transformers import AutoTokenizer

 MIN_CHAR = 10
 MAX_CHAR = 1000
-_tokenizer = None


-def estimate_num_tokens(text: str) -> int:
-    _tokenizer: AutoTokenizer
-    if _tokenizer is None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-        _tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-    return len(_tokenizer(text, return_tensors=None))
+def create_token_estimator(
+    model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
+) -> Callable[[str], int]:
+    _tokenizer: Optional[AutoTokenizer] = None
+
+    def initialize() -> None:
+        nonlocal _tokenizer
+        if _tokenizer is None:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+            try:
+                _tokenizer = AutoTokenizer.from_pretrained(model_name)
+            except (OSError, ImportError, ValueError) as e:
+                raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e
+
+    def estimate_num_tokens(text: str) -> int:
+        initialize()
+
+        if _tokenizer is None:
+            return 0
+
+        try:
+            encoding = _tokenizer(text, return_tensors=None)
+            return len(encoding["input_ids"])
+        except (AttributeError, TypeError, RuntimeError) as e:
+            raise ValueError(f"Error processing text: {e}") from e
+
+    return estimate_num_tokens


 def extract_and_save_with_filtering(file):
@@ -72,6 +93,7 @@ def extract_and_save_with_filtering(file):
     with Path(sharegpt_file).open("r", encoding="utf-8") as file:
         data = json.load(file)

+    estimate_tokens = create_token_estimator()
     num_of_ids = len(data)
     data = data[: int(num_of_ids * args.parse)]
     for d in data:
@@ -80,9 +102,9 @@ def extract_and_save_with_filtering(file):
         gpt_tokens = []
         for conv in d["conversations"]:
             if conv["from"] == "human":
-                human_tokens.append(estimate_num_tokens(conv["value"]))
+                human_tokens.append(estimate_tokens(conv["value"]))
             if conv["from"] == "gpt":
-                token_number = estimate_num_tokens(conv["value"])
+                token_number = estimate_tokens(conv["value"])
                 conv["num_tokens"] = token_number
                 gpt_tokens.append(token_number)
         if len(human_tokens) == 0:
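A minimal usage sketch of the refactored helper, assuming the hunks above are applied. The model name is just the factory's default from the diff; loading it from the Hugging Face Hub may require network access and credentials, and any other tokenizer repo id could be passed instead.

```python
# Hypothetical usage of create_token_estimator() from the diff above.
estimate_tokens = create_token_estimator()

# The tokenizer is loaded lazily on the first call and cached in the
# closure, so repeated calls pay the initialization cost only once.
n = estimate_tokens("How many tokens is this sentence?")
print(n)
```

Keeping the tokenizer in a closure rather than a module-level global also lets callers build independent estimators for different models.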