Commit 646fdc5

Merge pull request #16 from jesusmb1995/jmb/tune_scripts2
Quality and Speed tuning scripts
2 parents: 1473dce + ae1d001

File tree: 6 files changed, +670 −0 lines


.flake8

Lines changed: 1 addition & 0 deletions
@@ -15,4 +15,5 @@ exclude =
     build,
     # This contains builds that we don't want to check
     dist # This is generated with `python build .` for package releases
+    scripts/tune
 # max-complexity = 10

pyrightconfig.json

Lines changed: 3 additions & 0 deletions
@@ -19,4 +19,7 @@
       "pythonVersion": "3.10",
     },
   ],
+  "exclude": [
+    "scripts/tune"
+  ]
 }
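
The placement is easy to misread in the hunk above: the new top-level "exclude" key is inserted after the closing "]," of a preceding array whose opening lies outside this hunk. The tail of pyrightconfig.json therefore ends up reading roughly as follows (earlier keys elided, indentation approximate):

      "pythonVersion": "3.10",
    },
  ],
  "exclude": [
    "scripts/tune"
  ]
}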

scripts/tune/tune.py

Lines changed: 253 additions & 0 deletions

#!/usr/bin/env python3
"""
Optimize runtime parameters for the llama-simple binary using eval time measurements.
Usage: python tune.py --model /path/to/model.gguf
"""
import os
import time
import argparse
from functools import partial

import numpy as np
# pip install scikit-optimize
from skopt import gp_minimize, expected_minimum
from skopt.plots import plot_objective, plot_convergence
from skopt.space import Categorical
import matplotlib.pyplot as plt
import json

BAD_CONFIGURATIONS = []

# Progress tracking global variables
progress_start_time = None
progress_current_call = 0
progress_total_calls = 0
progress_best_score = float('inf')


def display_progress():
    """Display current optimization progress with time estimates."""
    global progress_start_time, progress_current_call, progress_total_calls, progress_best_score

    if progress_start_time is None:
        return

    elapsed_time = time.time() - progress_start_time
    if progress_current_call > 0:
        avg_time_per_call = elapsed_time / progress_current_call
        remaining_calls = progress_total_calls - progress_current_call
        estimated_remaining_time = avg_time_per_call * remaining_calls
    else:
        estimated_remaining_time = 0.0

    progress_percent = (progress_current_call / progress_total_calls) * 100

    print(f"\n{'='*60}")
    print("OPTIMIZATION PROGRESS")
    print(f"{'='*60}")
    print(f"Iteration: {progress_current_call}/{progress_total_calls} ({progress_percent:.1f}%)")
    print(f"Elapsed time: {elapsed_time:.1f}s")
    print(f"Est. remaining time: {estimated_remaining_time:.1f}s")
    print(f"Best metric so far: {progress_best_score:.4f}")
    print(f"{'='*60}\n")


def run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, binary_path="./build/bin/llama-cli", iterations=1):
    """Run the llama binary with the specified options and return the mean eval time."""
    try:
        run_options_str = get_opts_fn(run_options, model_path, binary_path)
        print(run_options_str)

        results = []

        # Run the test (can increase iterations for more stable results)
        for _ in range(iterations):
            results.append(run_binary_fn(run_options_str))

        # Return eval time as the objective (we want to minimize this)
        return np.mean(results)

    except Exception as e:
        BAD_CONFIGURATIONS.append(run_options)
        print("ERROR:", e, run_options)
        print("BAD_CONFIGURATIONS:", BAD_CONFIGURATIONS)
        return 1000  # High penalty for failed runs


def optimize_runtime_with_progress(x, get_opts_fn, run_binary_fn, run_options_list, model_path, llama_simple_path):
    """Objective function for optimization with progress tracking."""
    global progress_current_call, progress_best_score

    progress_current_call += 1

    run_options = {
        run_options_list[i][0]: run_options_list[i][1][run_options_list[i][1].index(x[i])]
        for i in range(len(run_options_list))
    }

    result = run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, llama_simple_path)

    # Update best score
    if result < progress_best_score:
        progress_best_score = result

    # Display progress every call
    display_progress()

    return result


def load_cache(cache_filename):
    """Load cached optimization results."""
    try:
        with open(cache_filename, "r") as cache_file:
            cache_data = json.load(cache_file)
            return cache_data["x0"], cache_data["y0"]
    except (OSError, KeyError, json.JSONDecodeError):
        pass
    return None, None


def save_cache(cache_filename, x0, y0):
    """Save optimization results to cache."""
    # Convert numpy int64 objects to Python int objects
    x0 = [[int(item) if isinstance(item, np.int64) else item for item in sublist] for sublist in x0]
    y0 = [int(item) if isinstance(item, np.int64) else item for item in y0]

    cache_data = {"x0": x0, "y0": y0}
    with open(cache_filename, "w") as cache_file:
        json.dump(cache_data, cache_file)


def plot_iterations(result):
    """Plot optimization iterations."""
    search_space = result.space
    x_iters = result.x_iters
    func_vals = result.func_vals
    search_space_names = [dim.name for dim in search_space.dimensions]
    opts = search_space_names + ["objective_r"]

    num_params = len(opts) + 1
    fig, axs = plt.subplots(num_params, figsize=(8, num_params * 8), sharex=True)
    iterations = list(range(1, len(x_iters) + 1))

    for i, param in enumerate(opts):
        if param == "objective_r":
            param_values = func_vals
        else:
            param_index = search_space_names.index(param)
            param_values = [x[param_index] for x in x_iters]

        axs[i].scatter(iterations, param_values)
        axs[i].set_xlabel("Iteration")
        axs[i].set_ylabel(param)

    plot_convergence(result, true_minimum=0, ax=axs[-1])
    return axs


def parse_args(default_bin):
    parser = argparse.ArgumentParser(description='Optimize llama-simple runtime parameters')
    parser.add_argument('--model', '-m', required=True, help='Path to the GGUF model file')
    parser.add_argument('--ngl', type=int, required=True, help='Max number of GPU layers')
    parser.add_argument('--llama-binary', default=default_bin,
                        help='Path to llama-simple binary (default: ./build/bin/llama-simple)')
    parser.add_argument('--n-calls', type=int, default=50,
                        help='Number of optimization calls (default: 50)')
    parser.add_argument('--cache', default='cache_simple.json',
                        help='Cache file name (default: cache_simple.json)')
    parser.add_argument('--single-execution', type=str,
                        help='Run single execution with specified options (format: "--param1=value1 --param2=value2")')

    args = parser.parse_args()
    return args


def main(args, get_opts_fn, run_binary_fn, run_options_list):

    # Check that the llama binary exists
    if not os.path.exists(args.llama_binary):
        print(f"Error: llama binary not found at {args.llama_binary}")
        print("Please build llama.cpp first or specify the correct path with --llama-binary")
        return

    # Check that the model exists
    if not os.path.exists(args.model):
        print(f"Error: Model file not found at {args.model}")
        return

    # Handle single execution mode
    if args.single_execution:
        try:
            print("Single execution")
            run_options = args.single_execution
            run_iterations(get_opts_fn, run_binary_fn, run_options, args.model, args.llama_binary)
            return
        except ValueError as e:
            print(f"Error parsing single execution options: {e}")
            return

    # Initialize progress tracking
    global progress_start_time, progress_total_calls
    progress_start_time = time.time()
    progress_total_calls = args.n_calls

    # Create optimization dimensions
    dimensions = [Categorical(opt[1]) for opt in run_options_list]
    for i, opt in enumerate(run_options_list):
        dimensions[i].name = opt[0]

    # Load cache
    x0, y0 = load_cache(args.cache)

    # Create objective function
    objective_function = partial(optimize_runtime_with_progress,
                                 get_opts_fn=get_opts_fn,
                                 run_binary_fn=run_binary_fn,
                                 run_options_list=run_options_list,
                                 model_path=args.model,
                                 llama_simple_path=args.llama_binary)

    print(f"Starting optimization with {args.n_calls} calls and {args.ngl} gpu layers...")
    print(f"Using model: {args.model}")
    print(f"Cache file: {args.cache}")

    # Run optimization
    result = gp_minimize(objective_function, dimensions,
                         n_calls=args.n_calls,
                         n_initial_points=min(10, args.n_calls),
                         random_state=42,
                         x0=x0, y0=y0,
                         initial_point_generator="lhs")

    # Save results
    save_cache(args.cache, result.x_iters, result.func_vals)

    # Print results
    print(f"\nBest options found: {result.x}")
    print(f"Minimum eval time: {result.fun:.4f} seconds")

    # Convert result.x back to a human-readable format
    best_options = {}
    for i, (name, options) in enumerate(run_options_list):
        # Look up the chosen value in the options list
        value = result.x[i]
        if value in options:
            best_options[name] = value
        else:
            # Fallback: use the first option if the value is not found
            print(f"Warning: Value '{value}' not found in options for {name}, using first option")
            best_options[name] = options[0]

    print("\nBest configuration:")
    for name, value in best_options.items():
        print(f"  {name}: {value}")

    min_x, _ = expected_minimum(result)
    print(f"Expected minimum: {min_x}")

    if BAD_CONFIGURATIONS:
        print(f"\nBAD_CONFIGURATIONS: {len(BAD_CONFIGURATIONS)}")

    # Plot results
    try:
        plot_iterations(result)
        plot_objective(result)
        # Might need PyQt6
        plt.show()
    except Exception as e:
        print(f"Plotting failed: {e}")
