#!/usr/bin/env python3
"""
Optimize runtime parameters for llama-simple binary using eval time measurements.
Usage: python tune_tps.py --model /path/to/model.gguf
"""
import os
import time
import json
import argparse
from functools import partial

import numpy as np
# pip install scikit-optimize
from skopt import gp_minimize, expected_minimum
from skopt.plots import plot_objective, plot_convergence
from skopt.space import Categorical
import matplotlib.pyplot as plt

BAD_CONFIGURATIONS = []

# Progress tracking global variables
progress_start_time = None
progress_current_call = 0
progress_total_calls = 0
progress_best_score = float('inf')

def display_progress():
    """Display current optimization progress with time estimates."""
    if progress_start_time is None or progress_current_call == 0:
        return

    elapsed_time = time.time() - progress_start_time
    avg_time_per_call = elapsed_time / progress_current_call
    remaining_calls = progress_total_calls - progress_current_call
    estimated_remaining_time = avg_time_per_call * remaining_calls
    progress_percent = (progress_current_call / progress_total_calls) * 100

    print(f"\n{'='*60}")
    print("OPTIMIZATION PROGRESS")
    print(f"{'='*60}")
    print(f"Iteration: {progress_current_call}/{progress_total_calls} ({progress_percent:.1f}%)")
    print(f"Elapsed time: {elapsed_time:.1f}s")
    print(f"Est. remaining time: {estimated_remaining_time:.1f}s")
    print(f"Best metric so far: {progress_best_score:.4f}")
    print(f"{'='*60}\n")

def run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, binary_path="./build/bin/llama-cli", iterations=1):
    """Run llama-simple with the specified options and return the mean eval time."""
    try:
        run_options_str = get_opts_fn(run_options, model_path, binary_path)
        print(run_options_str)

        results = []

        # Run the test (can increase iterations for more stable results)
        for _ in range(iterations):
            results.append(run_binary_fn(run_options_str))

        # Return eval time as the objective (we want to minimize this)
        return np.mean(results)

    except Exception as e:
        BAD_CONFIGURATIONS.append(run_options)
        print("ERROR:", e, run_options)
        print("BAD_CONFIGURATIONS:", BAD_CONFIGURATIONS)
        return 1000  # High penalty for failed runs

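
# --- Hedged example helpers (not part of the original flow) ------------------
# main() below expects the caller to supply get_opts_fn and run_binary_fn.
# The two functions here are only an illustrative sketch of that contract:
# the "flag value" command-line format, the "-m <model>" flag, and the use of
# wall-clock time as the metric are assumptions about your binary, not guarantees.

def example_get_opts_fn(run_options, model_path, binary_path):
    """Hypothetical helper: turn a {flag: value} dict (or a raw option string) into a command."""
    if isinstance(run_options, str):
        # Single-execution mode passes a pre-built option string
        opts = run_options
    else:
        opts = " ".join(f"{flag} {value}" for flag, value in run_options.items())
    return f"{binary_path} -m {model_path} {opts}"


def example_run_binary_fn(command):
    """Hypothetical helper: run the command and return elapsed wall-clock seconds."""
    import subprocess
    start = time.time()
    # check=True makes failed runs raise, which run_iterations() records as a bad configuration
    subprocess.run(command, shell=True, check=True, capture_output=True)
    return time.time() - start
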
def optimize_runtime_with_progress(x, get_opts_fn, run_binary_fn, run_options_list, model_path, llama_simple_path):
    """Objective function for optimization with progress tracking."""
    global progress_current_call, progress_best_score

    progress_current_call += 1

    # Map the optimizer's point x back to a {flag: value} dict
    run_options = {
        name: x[i] for i, (name, _values) in enumerate(run_options_list)
    }

    result = run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, llama_simple_path)

    # Update best score
    if result < progress_best_score:
        progress_best_score = result

    # Display progress every call
    display_progress()

    return result

def load_cache(cache_filename):
    """Load cached optimization results."""
    try:
        with open(cache_filename, "r") as cache_file:
            cache_data = json.load(cache_file)
        return cache_data["x0"], cache_data["y0"]
    except (OSError, json.JSONDecodeError, KeyError):
        return None, None


def save_cache(cache_filename, x0, y0):
    """Save optimization results to cache."""
    # Convert numpy scalars (int64, float64, ...) to plain Python types so json can serialize them
    x0 = [[item.item() if isinstance(item, np.generic) else item for item in sublist] for sublist in x0]
    y0 = [item.item() if isinstance(item, np.generic) else item for item in y0]

    cache_data = {"x0": x0, "y0": y0}
    with open(cache_filename, "w") as cache_file:
        json.dump(cache_data, cache_file)

def plot_iterations(result):
    """Plot parameter values and the objective across optimization iterations."""
    search_space = result.space
    x_iters = result.x_iters
    func_vals = result.func_vals
    search_space_names = [dim.name for dim in search_space.dimensions]
    opts = search_space_names + ["objective_r"]

    num_params = len(opts) + 1  # one extra axis for the convergence plot
    fig, axs = plt.subplots(num_params, figsize=(8, num_params * 8), sharex=True)
    iterations = list(range(1, len(x_iters) + 1))

    for i, param in enumerate(opts):
        if param == "objective_r":
            param_values = func_vals
        else:
            param_index = search_space_names.index(param)
            param_values = [x[param_index] for x in x_iters]

        axs[i].scatter(iterations, param_values)
        axs[i].set_xlabel("Iteration")
        axs[i].set_ylabel(param)

    plot_convergence(result, true_minimum=0, ax=axs[-1])
    return axs

def parse_args(default_bin):
    parser = argparse.ArgumentParser(description='Optimize llama-simple runtime parameters')
    parser.add_argument('--model', '-m', required=True, help='Path to the GGUF model file')
    parser.add_argument('--ngl', type=int, required=True, help='Max number of GPU layers')
    parser.add_argument('--llama-binary', default=default_bin,
                        help='Path to the llama binary (default: %(default)s)')
    parser.add_argument('--n-calls', type=int, default=50,
                        help='Number of optimization calls (default: %(default)s)')
    parser.add_argument('--cache', default='cache_simple.json',
                        help='Cache file name (default: %(default)s)')
    parser.add_argument('--single-execution', type=str,
                        help='Run a single execution with the specified options (format: "--param1=value1 --param2=value2")')

    args = parser.parse_args()
    return args

def main(args, get_opts_fn, run_binary_fn, run_options_list):

    # Check if the llama-simple binary exists
    if not os.path.exists(args.llama_binary):
        print(f"Error: llama-simple binary not found at {args.llama_binary}")
        print("Please build llama.cpp first or specify the correct path with --llama-binary")
        return

    # Check if the model exists
    if not os.path.exists(args.model):
        print(f"Error: Model file not found at {args.model}")
        return

    # Handle single execution mode
    if args.single_execution:
        try:
            print("Single execution")
            run_options = args.single_execution
            run_iterations(get_opts_fn, run_binary_fn, run_options, args.model, args.llama_binary)
            return
        except ValueError as e:
            print(f"Error parsing single execution options: {e}")
            return

    # Initialize progress tracking
    global progress_start_time, progress_total_calls
    progress_start_time = time.time()
    progress_total_calls = args.n_calls

    # Create optimization dimensions (one Categorical dimension per tunable option)
    dimensions = [Categorical(values, name=name) for name, values in run_options_list]

    # Load cache
    x0, y0 = load_cache(args.cache)

    # Create objective function
    objective_function = partial(optimize_runtime_with_progress,
                                 get_opts_fn=get_opts_fn,
                                 run_binary_fn=run_binary_fn,
                                 run_options_list=run_options_list,
                                 model_path=args.model,
                                 llama_simple_path=args.llama_binary)

    print(f"Starting optimization with {args.n_calls} calls and {args.ngl} gpu layers...")
    print(f"Using model: {args.model}")
    print(f"Cache file: {args.cache}")

    # Run optimization
    result = gp_minimize(objective_function, dimensions,
                         n_calls=args.n_calls,
                         n_initial_points=min(10, args.n_calls),
                         random_state=42,
                         x0=x0, y0=y0,
                         initial_point_generator="lhs")

    # Save results
    save_cache(args.cache, result.x_iters, result.func_vals)

    # Print results
    print(f"\nBest options found: {result.x}")
    print(f"Minimum eval time: {result.fun:.4f} seconds")
    # Map result.x back to a human-readable configuration
    best_options = {}
    for i, (name, options) in enumerate(run_options_list):
        value = result.x[i]
        if value in options:
            best_options[name] = value
        else:
            # Fallback: use the first option if the value is not found
            print(f"Warning: Value '{value}' not found in options for {name}, using first option")
            best_options[name] = options[0]

    print("\nBest configuration:")
    for name, value in best_options.items():
        print(f"  {name}: {value}")

    # expected_minimum() does not support purely categorical search spaces, so guard the call
    try:
        min_x, _ = expected_minimum(result)
        print(f"Expected minimum: {min_x}")
    except ValueError as e:
        print(f"Could not estimate expected minimum: {e}")

    if BAD_CONFIGURATIONS:
        print(f"\nBAD_CONFIGURATIONS: {len(BAD_CONFIGURATIONS)}")

    # Plot results
    try:
        plot_iterations(result)
        plot_objective(result)
        # Might need PyQt6 for the interactive backend
        plt.show()
    except Exception as e:
        print(f"Plotting failed: {e}")
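
# A hedged sketch of how this module could be wired up as a standalone script.
# The flags and value grids below are illustrative assumptions only: replace them
# with the options you actually want to sweep, and pass get_opts_fn / run_binary_fn
# implementations that match your binary (see example_get_opts_fn / example_run_binary_fn above).
if __name__ == "__main__":
    # Each entry is (flag, list of candidate values); one Categorical dimension is built per entry.
    example_run_options_list = [
        ("--threads", [2, 4, 8]),           # assumed flag; adjust to your binary
        ("--batch-size", [128, 256, 512]),  # assumed flag; adjust to your binary
    ]
    cli_args = parse_args(default_bin="./build/bin/llama-simple")
    main(cli_args, example_get_opts_fn, example_run_binary_fn, example_run_options_list)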