forked from xjdr-alt/entropix
-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmlx_main.py
More file actions
77 lines (71 loc) · 3.65 KB
/
mlx_main.py
File metadata and controls
77 lines (71 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import time
import mlx.core as mx
import pathlib
from typing import Optional, List, Tuple, Union
from mlx_model import load_entropix_model
from mlx_lm.utils import load
from mlx_lm.tokenizer_utils import load_tokenizer
from mlx_lm.utils import generate as generate_mlx_lm
from mlx_generate import generate
from mlx_sampler import SamplerConfig
from mlx_prompts import create_prompts_from_csv, prompt1, prompt2, prompt3, prompt4, prompt5, thinking_prompt, o1_claude_prompt
from pathlib import Path
def main():
    """CLI entry point.

    Generates text with either the Entropix sampler (default) or a plain
    mlx_lm model (``--normal``), from one of three prompt sources:
    a single ``--input`` prompt, ``--prompt_csv`` (data/prompts.csv), or
    the predefined ``--prompts`` list.
    """
    parser = argparse.ArgumentParser(description = "Generate text using Entropy based sampling based on input prompts")
    parser.add_argument("--prompts", action="store_true", help = "Use predefined prompts from mlx_entropix.prompts")
    parser.add_argument("--prompt_csv", action="store_true", help = "Use prompts from data/prompts.csv")
    parser.add_argument("--input", type = str, help = "Input prompt to generate text")
    parser.add_argument("--weights_dir", type = str, help = "Directory containing the model weights", default = "weights/Llama-3.2-1B-Instruct")
    # NOTE(review): --entropix defaults to True and is never read below; the
    # Entropix path is selected simply by NOT passing --normal. The flag is
    # kept only for CLI backward compatibility.
    parser.add_argument("--entropix", action="store_true", default=True, help="Use Entropix model for generation")
    parser.add_argument("--normal", action="store_true", help="Use normal model for generation")
    args = parser.parse_args()

    if not (args.prompts or args.input or args.prompt_csv):
        print("No input provided. Use --prompts to use predefined prompts from mlx_entropix.prompts or provide a custom prompt using --input")
        print("Exiting...")
        return

    # Load the model/tokenizer pair for the selected backend.
    if args.normal:
        model, tokenizer = load(args.weights_dir)
        # Fix: the original bound SamplerConfig() to `sample_config_kwargs`
        # and immediately overwrote that same name with a dict built from it.
        # Use two distinct names so the config object is not shadowed.
        config = SamplerConfig()
        sample_config_kwargs = {
            "temp": config.temp,
            "top_p": config.top_p,
            "min_p": config.min_p,
        }
    else:
        weights_path = Path(args.weights_dir)
        tokenizer = load_tokenizer(weights_path)
        model = load_entropix_model(weights_path)
        # Entropix generate() takes no sampler kwargs; keep an empty dict so
        # the dispatch helper below is uniform.
        sample_config_kwargs = {}

    max_tokens = 4096

    def _run(prompt):
        # Dispatch one prompt to the backend chosen on the command line.
        if args.normal:
            return generate_mlx_lm(model, tokenizer, prompt=prompt, verbose=True, max_tokens=max_tokens, **sample_config_kwargs)
        return generate(model, tokenizer, prompt=prompt, verbose=True, max_tokens=max_tokens)

    def _chat_format(raw_prompt):
        # Wrap a raw prompt in the tokenizer's chat template.
        # NOTE(review): add_generation_prompt=False mirrors the original code,
        # but chat models usually expect True here — confirm intent.
        messages = [{"role": "user", "content": raw_prompt}]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )

    if args.input:
        # NOTE(review): unlike the CSV/predefined paths, --input is wrapped in
        # thinking_prompt rather than the chat template — preserved as-is.
        _run(thinking_prompt.format(query = args.input))
    elif args.prompt_csv:
        for prompt in create_prompts_from_csv("data/prompts.csv"):
            _run(_chat_format(prompt))
    elif args.prompts:
        for prompt in [prompt1, prompt2, prompt3, prompt4, prompt5]:
            _run(_chat_format(prompt))
# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()