Commit 68cc760
committed
Tune python scripts
diff --git a/scripts/tune/tune.py b/scripts/tune/tune.py
new file mode 100644
index 000000000..eff17d3
--- /dev/null
+++ b/scripts/tune/tune.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+"""
+Optimize runtime parameters for llama-simple binary using eval time measurements.
+Usage: python tune_tps.py --model /path/to/model.gguf
+"""
+import os
+import time
+import argparse
+from functools import partial
+
+import numpy as np
+# pip install scikit-optimize
+from skopt import gp_minimize, expected_minimum
+from skopt.plots import plot_objective, plot_convergence
+from skopt.space import Categorical
+import matplotlib.pyplot as plt
+import json
+
+BAD_CONFIGURATIONS = []
+
+# Progress tracking global variables
+progress_start_time = None
+progress_current_call = 0
+progress_total_calls = 0
+progress_best_score = float('inf')
+
+def display_progress():
+ """Display current optimization progress with time estimates."""
+ global progress_start_time, progress_current_call, progress_total_calls, progress_best_score
+
+ if progress_start_time is None:
+ return
+
+ elapsed_time = time.time() - progress_start_time
+    if progress_current_call <= 0:
+        return
+
+    avg_time_per_call = elapsed_time / progress_current_call
+    remaining_calls = progress_total_calls - progress_current_call
+    estimated_remaining_time = avg_time_per_call * remaining_calls
+
+    progress_percent = (progress_current_call / progress_total_calls) * 100
+
+ print(f"\n{'='*60}")
+ print(f"OPTIMIZATION PROGRESS")
+ print(f"{'='*60}")
+ print(f"Iteration: {progress_current_call}/{progress_total_calls} ({progress_percent:.1f}%)")
+ print(f"Elapsed time: {elapsed_time:.1f}s")
+ print(f"Est. remaining time: {estimated_remaining_time:.1f}s")
+ print(f"Best metric so far: {progress_best_score:.4f}")
+ print(f"{'='*60}\n")
+
+def run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, binary_path="./build/bin/llama-cli", iterations=1):
+ """Run llama-siple with specified options and return eval time."""
+ try:
+ run_options_str = get_opts_fn(run_options, model_path, binary_path)
+ print(run_options_str)
+
+ results = []
+
+ # Run the test (can increase iterations for more stable results)
+ for _ in range(iterations):
+ results.append(run_binary_fn(run_options_str))
+
+        # Return the mean objective across iterations (we want to minimize this)
+ return np.mean(results)
+
+ except Exception as e:
+ BAD_CONFIGURATIONS.append(run_options)
+ print("ERROR:", e, run_options)
+ print("BAD_CONFIGURATIONS:", BAD_CONFIGURATIONS)
+ return 1000 # High penalty for failed runs
+
+
+def optimize_runtime_with_progress(x, get_opts_fn, run_binary_fn, run_options_list, model_path, llama_simple_path):
+ """Objective function for optimization with progress tracking."""
+ global progress_current_call, progress_best_score
+
+ progress_current_call += 1
+
+    # skopt passes back one chosen value per dimension; map them to their parameter names
+    run_options = {
+        name: value
+        for (name, _options), value in zip(run_options_list, x)
+    }
+
+ result = run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, llama_simple_path)
+
+ # Update best score
+ if result < progress_best_score:
+ progress_best_score = result
+
+ # Display progress every call
+ display_progress()
+
+ return result
+
+
+def load_cache(cache_filename):
+ """Load cached optimization results."""
+ try:
+ with open(cache_filename, "r") as cache_file:
+ cache_data = json.load(cache_file)
+ return cache_data["x0"], cache_data["y0"]
+    except (FileNotFoundError, json.JSONDecodeError, KeyError):
+        pass
+    return None, None
+
+
+def save_cache(cache_filename, x0, y0):
+ """Save optimization results to cache."""
+ # Convert numpy int64 objects to Python int objects
+ x0 = [[int(item) if isinstance(item, np.int64) else item for item in sublist] for sublist in x0]
+ y0 = [int(item) if isinstance(item, np.int64) else item for item in y0]
+
+ cache_data = {"x0": x0, "y0": y0}
+ with open(cache_filename, "w") as cache_file:
+ json.dump(cache_data, cache_file)
+
+
+def plot_iterations(result):
+ """Plot optimization iterations."""
+ search_space = result.space
+ x_iters = result.x_iters
+ func_vals = result.func_vals
+ search_space_names = [dim.name for dim in search_space]
+ opts = search_space_names + ["objective_r"]
+
+ num_params = len(opts) + 1
+ fig, axs = plt.subplots(num_params, figsize=(8, num_params * 8), sharex=True)
+ iterations = list(range(1, len(x_iters) + 1))
+
+ for i, param in enumerate(opts):
+ if param == "objective_r":
+ param_values = func_vals
+ else:
+ param_index = search_space_names.index(param)
+ param_values = [x[param_index] for x in x_iters]
+
+ axs[i].scatter(iterations, param_values)
+ axs[i].set_xlabel("Iteration")
+ axs[i].set_ylabel(param)
+
+ plot_convergence(result, true_minimum=0, ax=axs[-1])
+ return axs
+
+def parse_args(default_bin):
+    parser = argparse.ArgumentParser(description='Optimize llama.cpp runtime parameters')
+ parser.add_argument('--model', '-m', required=True, help='Path to the GGUF model file')
+ parser.add_argument('--ngl', type=int, required=True, help='Max number of GPU layers')
+    parser.add_argument('--llama-binary', default=default_bin,
+                        help=f'Path to the llama binary (default: {default_bin})')
+    parser.add_argument('--n-calls', type=int, default=50,
+                        help='Number of optimization calls (default: 50)')
+ parser.add_argument('--cache', default='cache_simple.json',
+ help='Cache file name (default: cache_simple.json)')
+ parser.add_argument('--single-execution', type=str,
+ help='Run single execution with specified options (format: "--param1=value1 --param2=value2")')
+
+ args = parser.parse_args()
+ return args
+
+def main(args, get_opts_fn, run_binary_fn, run_options_list):
+
+    # Check that the llama binary exists
+    if not os.path.exists(args.llama_binary):
+        print(f"Error: llama binary not found at {args.llama_binary}")
+        print("Please build llama.cpp first or specify the correct path with --llama-binary")
+ return
+
+ # Check if model exists
+ if not os.path.exists(args.model):
+ print(f"Error: Model file not found at {args.model}")
+ return
+
+ # Handle single execution mode
+ if args.single_execution:
+ try:
+ print("Single execution")
+ run_options = args.single_execution
+ run_iterations(get_opts_fn, run_binary_fn, run_options, args.model, args.llama_binary)
+ return
+ except ValueError as e:
+ print(f"Error parsing single execution options: {e}")
+ return
+
+ # Initialize progress tracking
+ global progress_start_time, progress_total_calls
+ progress_start_time = time.time()
+ progress_total_calls = args.n_calls
+
+    # Create optimization dimensions (one categorical dimension per tunable option)
+    dimensions = [Categorical(options, name=name) for name, options in run_options_list]
+
+ # Load cache
+ x0, y0 = load_cache(args.cache)
+
+ # Create objective function
+ objective_function = partial(optimize_runtime_with_progress,
+ get_opts_fn=get_opts_fn,
+ run_binary_fn=run_binary_fn,
+ run_options_list=run_options_list,
+ model_path=args.model,
+ llama_simple_path=args.llama_binary)
+
+ print(f"Starting optimization with {args.n_calls} calls and {args.ngl} gpu layers...")
+ print(f"Using model: {args.model}")
+ print(f"Cache file: {args.cache}")
+
+ # Run optimization
+ result = gp_minimize(objective_function, dimensions,
+ n_calls=args.n_calls,
+ n_initial_points=min(10, args.n_calls),
+ random_state=42,
+ x0=x0, y0=y0,
+ initial_point_generator="lhs")
+
+ # Save results
+ save_cache(args.cache, result.x_iters, result.func_vals)
+
+ # Print results
+ print(f"\nBest options found: {result.x}")
+ print(f"Minimum eval time: {result.fun:.4f} seconds")
+
+    # Convert result.x back to a human-readable configuration
+ best_options = {}
+ for i, (name, options) in enumerate(run_options_list):
+        # result.x[i] should already be one of the original option strings
+ value = result.x[i]
+ if value in options:
+ best_options[name] = value
+ else:
+ # Fallback: use the first option if value not found
+ print(f"Warning: Value '{value}' not found in options for {name}, using first option")
+ best_options[name] = options[0]
+
+ print("\nBest configuration:")
+ for name, value in best_options.items():
+ print(f" {name}: {value}")
+
+    # expected_minimum may not support purely categorical search spaces in some skopt versions
+    try:
+        min_x, _ = expected_minimum(result)
+        print(f"Expected minimum: {min_x}")
+    except (ValueError, TypeError) as e:
+        print(f"Expected minimum unavailable: {e}")
+
+ if BAD_CONFIGURATIONS:
+ print(f"\nBAD_CONFIGURATIONS: {len(BAD_CONFIGURATIONS)}")
+
+ # Plot results
+ try:
+ plot_iterations(result)
+ plot_objective(result)
+ # Might need PyQt6
+ plt.show()
+ except Exception as e:
+ print(f"Plotting failed: {e}")
diff --git a/scripts/tune/tune_quality.py b/scripts/tune/tune_quality.py
new file mode 100644
index 000000000..ffae255
--- /dev/null
+++ b/scripts/tune/tune_quality.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+"""
+BERTScore-based translation quality optimization for llama.cpp models.
+Uses BERTScore to evaluate translation quality instead of HellaSwag accuracy.
+"""
+import subprocess
+import sys
+import os
+import re
+import json
+import hashlib
+import numpy as np
+from typing import Dict, List, Tuple, Any, Optional
+from collections import Counter
+
+# Import bert_score for translation quality evaluation
+import bert_score
+
+# Import language_tool_python for grammar checking
+import language_tool_python
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+# Configuration
+BERTSCORE_MODEL = 'microsoft/deberta-v3-base'
+
+# Translation benchmarks for quality evaluation
+# Tiny subset of https://openslr.org/100
+TRANSLATION_BENCHMARKS = [
+ {
+ "prompt": "Translate the following English text to French:\n\nEnglish: As you can see, it does not look like a slam lesson, it is a language lesson, a language which allows to give orders to machines and computers the language of the 21st century: the computer code.\nFrench:",
+ "ground_truth": "Comme vous pouvez le constater, il ne s'agit pas d'un cours de slam, il s'agit d'un cours de langue, une langue qui permet de donner des ordres à des machines et à des ordinateurs, la langue du 21e siècle : le code informatique.",
+ "tool": "fr-FR"
+ },
+ {
+ "prompt": "Translate the following English text to Spanish:\n\nEnglish: Some years ago, when I was diving in the Lombok Strait, in Indonesia, 98 feet below the water, with that feeling of weightlessness, surrounded by a great biodiversity of reefs, corals, sea turtles, ocean sunfishes and fishes of all colors, I had an intense feeling of connection with nature.\nSpanish:",
+ "ground_truth": "Hace unos años, cuando me encontraba buceando en el estrecho de Lombok, en Indonesia, a 30 metros debajo del agua, con esa sensación de ingravidez, rodeado de una gran biodiversidad, de arrecifes, de corales, de tortugas, de peces mola mola y de peces de todos los colores, tuve una intensa sensación de estar conectado con la naturaleza.",
+ "tool": "es-ES"
+ },
+ {
+ "prompt": "Translate the following English text to Portuguese:\n\nEnglish: Have you ever stopped to think about clothes for disabled people?\nPortuguese:",
+ "ground_truth": "Vocês já pararam pra pensar como é o vestuário das pessoas com deficiência?",
+ "tool": "pt-PT"
+ }
+]
+
+def get_metrics(metrics_filepath: str, ground_truth: str, prediction: str, tool: str) -> Dict[str, Any]:
+ """
+ Calculate BERTScore and other quality metrics for translation evaluation.
+ Caches results to avoid recomputation.
+ """
+ print(f"Calculating metrics: {metrics_filepath}")
+
+ metrics = {
+ 'bertscore_model': None,
+ 'bertscore_P': None,
+ 'bertscore_R': None,
+ 'bertscore_F1': None,
+ 'grammar_errors': None,
+ 'repetition_score': None,
+ 'objective_r': None
+ }
+
+ # Load cached scores
+ try:
+ with open(metrics_filepath, 'r', encoding='utf-8') as f:
+ metrics.update(json.load(f))
+ except FileNotFoundError:
+ pass
+
+    # Calculate BERTScore if not cached or if the cached score used a different model
+    if (metrics["bertscore_P"] is None or metrics["bertscore_R"] is None or
+            metrics["bertscore_F1"] is None or metrics["bertscore_model"] != BERTSCORE_MODEL):
+ try:
+ metrics["bertscore_model"] = BERTSCORE_MODEL
+ score = bert_score.score([prediction], [ground_truth], model_type=BERTSCORE_MODEL)
+ metrics["bertscore_P"], metrics["bertscore_R"], metrics["bertscore_F1"] = (
+ score[0].item(), score[1].item(), score[2].item()
+ )
+ except Exception as e:
+ print(f"Warning: BERTScore calculation failed: {e}")
+ metrics["bertscore_P"] = metrics["bertscore_R"] = metrics["bertscore_F1"] = 0.0
+
+    # Calculate grammar errors if not cached
+    if metrics["grammar_errors"] is None:
+        metrics["grammar_errors"] = 0.0
+        try:
+            # Constructing LanguageTool can itself fail (e.g. missing Java), so keep it inside the try
+            language_tool = language_tool_python.LanguageTool(tool)
+            matches = language_tool.check(prediction)
+            metrics["grammar_errors"] = len(matches) / max(len(prediction.split()), 1)
+            language_tool.close()
+        except Exception as e:
+            print(f"Warning: Grammar checking failed: {e}")
+            metrics["grammar_errors"] = 0.0
+
+ # Calculate repetition score if not cached
+ if metrics["repetition_score"] is None:
+ try:
+ words = prediction.split()
+ if len(words) > 0:
+ word_counts = Counter(words)
+ repeated_words = sum(count - 1 for count in word_counts.values() if count > 1)
+ metrics["repetition_score"] = repeated_words / len(words)
+ else:
+ metrics["repetition_score"] = 0.0
+ except Exception as e:
+ print(f"Warning: Repetition calculation failed: {e}")
+ metrics["repetition_score"] = 0.0
+
+ # Calculate objective score (we want to minimize this)
+ # Higher BERTScore Recall = better translation quality = lower objective value
+ # Add penalties for grammar errors and repetitions
+ if metrics["bertscore_R"] is not None:
+ grammar_penalty = metrics["grammar_errors"] * 0.1 # Small penalty for grammar errors
+ repetition_penalty = metrics["repetition_score"] * 0.05 # Small penalty for repetitions
+ metrics["objective_r"] = -(metrics["bertscore_R"] - grammar_penalty - repetition_penalty)
+ else:
+ metrics["objective_r"] = 1.0 # Bad score if BERTScore failed
+
+ # Save metrics to cache
+ try:
+ with open(metrics_filepath, 'w', encoding='utf-8') as f:
+ json.dump(metrics, f, indent=2, ensure_ascii=False)
+ except Exception as e:
+ print(f"Warning: Failed to save metrics: {e}")
+
+ return metrics
+
+def run_binary(run_options_str):
+ """Run the binary and evaluate translation quality using BERTScore."""
+ try:
+ # Parse the command to extract parameters
+ parts = run_options_str.split()
+ model_path = None
+ binary_path = None
+
+ # Find model path and binary path
+ for i, part in enumerate(parts):
+ if part == "-m" and i + 1 < len(parts):
+ model_path = parts[i + 1]
+ elif part.endswith("llama-cli") or part.endswith("main"):
+ binary_path = part
+
+ if not model_path or not binary_path:
+ print("Error: Could not parse model path or binary path from command")
+ return 100.0
+
+ # Create output directory for this run
+ run_hash = hashlib.md5(run_options_str.encode()).hexdigest()[:8]
+ output_dir = f"translation_eval_{run_hash}"
+ os.makedirs(output_dir, exist_ok=True)
+
+ all_scores = []
+
+ # Run translation benchmarks
+ for i, benchmark in enumerate(TRANSLATION_BENCHMARKS):
+ print(f"Running benchmark {i+1}/{len(TRANSLATION_BENCHMARKS)}")
+
+ # Build command for this benchmark - use the base command and add benchmark-specific params
+ benchmark_cmd = run_options_str.split()
+
+ # Add benchmark-specific parameters
+ benchmark_cmd.extend(["--prompt", benchmark["prompt"]])
+
+ # Run the command
+ try:
+ process = subprocess.run(benchmark_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ timeout=120, # 2 minute timeout per benchmark
+ check=False)
+
+ if process.returncode != 0:
+ print(f"Warning: Benchmark {i+1} failed with return code {process.returncode}")
+ print(f"STDERR: {process.stderr.decode()}")
+ all_scores.append(1.0) # Bad score for failed runs
+ continue
+
+ # Extract prediction from output
+ output = process.stdout.decode()
+ prediction = output.strip()
+
+ # Remove the prompt from prediction if it's included
+ if benchmark["prompt"] in prediction:
+ prediction = prediction.split(benchmark["prompt"])[-1].strip()
+
+ # Calculate metrics
+ metrics_filepath = os.path.join(output_dir, f"benchmark_{i}_metrics.json")
+ metrics = get_metrics(metrics_filepath,
+ benchmark["ground_truth"], prediction, benchmark["tool"])
+
+ objective_score = metrics.get("objective_r", 1.0)
+ all_scores.append(objective_score)
+
+ print(f"Benchmark {i+1} - BERTScore R: {metrics.get('bertscore_R', 0):.4f}, "
+ f"Objective: {objective_score:.4f}")
+
+ except subprocess.TimeoutExpired:
+ print(f"Warning: Benchmark {i+1} timed out")
+ all_scores.append(1.0) # Bad score for timeouts
+ except Exception as e:
+ print(f"Error running benchmark {i+1}: {e}")
+ all_scores.append(1.0) # Bad score for errors
+
+ # Calculate average score across all benchmarks
+ if all_scores:
+ avg_score = np.mean(all_scores)
+ print(f"Average translation quality objective score: {avg_score:.4f}")
+ return avg_score
+ else:
+ print("Warning: No successful benchmarks")
+ return 100.0 # Bad score if no benchmarks succeeded
+
+ except Exception as e:
+ print(f"Error in run_binary: {e}")
+ return 100.0 # Bad score for any other errors
+
+if __name__ == "__main__":
+ args = parse_args(default_bin='./build/bin/llama-cli')
+
+ # Define quality-focused sampling parameters for optimization
+ run_options_list = [
+ # Core Sampling Parameters (Most Critical for Quality)
+
+ # 1. Temperature - Controls randomness vs determinism
+ ("--temp", [
+ "--temp 0.1", # Very focused, deterministic
+ "--temp 0.3", # Focused, good for factual tasks
+ "--temp 0.5", # Moderate creativity
+ "--temp 0.7", # Balanced (recommended default)
+ "--temp 0.8", # Good balance
+ "--temp 0.9", # More creative
+ "--temp 1.0", # Creative but coherent
+ "--temp 1.2" # More creative, potentially less coherent
+ ]),
+
+ # 2. Top-p (Nucleus Sampling) - Controls diversity while maintaining quality
+ ("--top-p", [
+ "--top-p 0.5", # Very focused
+ "--top-p 0.7", # Focused, higher quality
+ "--top-p 0.8", # Good balance
+ "--top-p 0.85", # Balanced
+ "--top-p 0.9", # Good balance (recommended)
+ "--top-p 0.95", # Standard default
+ "--top-p 0.98", # More diverse
+ "--top-p 1.0" # No nucleus filtering
+ ]),
+
+ # 3. Top-k - Limits token selection to most probable candidates
+ ("--top-k", [
+ "--top-k 10", # Very focused
+ "--top-k 20", # More focused, higher quality
+ "--top-k 30", # Balanced
+ "--top-k 40", # Good balance (default)
+ "--top-k 50", # Balanced, more diverse
+ "--top-k 60", # More diverse
+ "--top-k 80", # Very diverse
+ "--top-k 100" # Most diverse
+ ]),
+
+ # 4. Min-p - Filters out low-probability tokens
+ ("--min-p", [
+ "--min-p 0.01", # Very permissive
+ "--min-p 0.02", # Permissive
+ "--min-p 0.05", # Good default
+ "--min-p 0.08", # More restrictive
+ "--min-p 0.1", # Restrictive, higher quality
+ "--min-p 0.15", # Very restrictive
+ "--min-p 0.2" # Extremely restrictive
+ ]),
+
+ # Repetition Control (Critical for Coherence)
+
+ # 5. Repeat Penalty - Prevents repetitive text
+ ("--repeat-penalty", [
+ "--repeat-penalty 1.0", # Disabled
+ "--repeat-penalty 1.02", # Very light penalty
+ "--repeat-penalty 1.05", # Light penalty (recommended)
+ "--repeat-penalty 1.1", # Moderate penalty (recommended)
+ "--repeat-penalty 1.15", # Moderate-strong penalty
+ "--repeat-penalty 1.2", # Strong penalty
+ "--repeat-penalty 1.25", # Very strong penalty
+ "--repeat-penalty 1.3" # Extreme penalty
+ ]),
+
+ # 6. Repeat Last N - How far back to look for repetitions
+ ("--repeat-last-n", [
+ "--repeat-last-n 16", # Short context
+ "--repeat-last-n 32", # Short-medium context
+ "--repeat-last-n 64", # Balanced default
+ "--repeat-last-n 96", # Medium-large context
+ "--repeat-last-n 128", # Large context
+ "--repeat-last-n 192", # Very large context
+ "--repeat-last-n 256" # Maximum context
+ ]),
+
+ # Advanced Quality Parameters
+
+ # 7. Typical-p - Promotes contextually coherent tokens
+ ("--typical", [
+ "--typical 1.0", # Disabled
+ "--typical 0.95", # Light filtering
+ "--typical 0.9", # Recommended for quality
+ "--typical 0.85", # Moderate filtering
+ "--typical 0.8", # Strong filtering
+ "--typical 0.75", # Very strong filtering
+ "--typical 0.7" # Extreme filtering
+ ]),
+
+ # 8. Mirostat - Adaptive sampling for consistent quality
+ ("--mirostat", [
+ "--mirostat 0", # Disabled (default)
+ "--mirostat 1", # Mirostat v1
+ "--mirostat 2" # Mirostat v2 (often better quality)
+ ]),
+
+ # Keep seed constant for reproducible results
+ ("--seed", ["-s 42"]),
+ ]
+
+ def run_str(run_options, model_path, binary_path):
+ """Build command string for llama-cli with translation evaluation."""
+ if isinstance(run_options, dict):
+ run_options = " ".join(run_options.values())
+ # Use the main binary for translation evaluation
+ return f"{binary_path} -m {model_path} --threads 8 -ngl {args.ngl} {run_options}"
+
+ main(args, run_str, run_binary, run_options_list)
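The objective that gp_minimize sees from this script is the negated BERTScore recall with small penalties for grammar errors and repetition, so better translations map to lower (more negative) values. A quick worked example with made-up metric values:

# Sketch of the composite objective computed in get_metrics, with made-up numbers
metrics = {"bertscore_R": 0.87, "grammar_errors": 0.04, "repetition_score": 0.10}
objective_r = -(metrics["bertscore_R"]
                - 0.1 * metrics["grammar_errors"]      # grammar penalty
                - 0.05 * metrics["repetition_score"])  # repetition penalty
print(round(objective_r, 3))  # -0.861; lower is better for gp_minimize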
diff --git a/scripts/tune/tune_quality_swag.py b/scripts/tune/tune_quality_swag.py
new file mode 100644
index 000000000..1eaedad
--- /dev/null
+++ b/scripts/tune/tune_quality_swag.py
@@ -0,0 +1,172 @@
+import subprocess
+import sys
+import os
+import re
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+def run_binary(run_options_str):
+ """Run the binary and parse HellaSwag accuracy score."""
+ try:
+ process = subprocess.run(run_options_str,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ check=False, # Don't raise exception on non-zero exit
+ timeout=300 # 5 minute timeout
+ )
+
+ if process.returncode != 0:
+ print(f"Warning: Process returned non-zero exit code: {process.returncode}")
+ print(f"STDERR: {process.stderr.decode()}")
+ return 100.0 # Return bad score for failed runs
+
+ # Parse HellaSwag accuracy from stdout
+ stdout_text = process.stdout.decode()
+ stderr_text = process.stderr.decode()
+
+        # Look for the HellaSwag accuracy line in the output
+        # Expected tab-separated format: "20\t75.00000000%\t[53.1299%, 88.8138%]"
+        # (the leading 20 matches the --hellaswag-tasks value used below)
+        accuracy_patterns = [
+            r"20\t([\d.]+)%\t\[",
+        ]
+
+ accuracy = None
+ for pattern in accuracy_patterns:
+ match = re.search(pattern, stdout_text, re.IGNORECASE)
+ if match:
+ accuracy = float(match.group(1))
+ # Convert percentage to decimal if needed (values > 1.0 are likely percentages)
+ if accuracy > 1.0:
+ accuracy = accuracy / 100.0
+ break
+
+ if accuracy is None:
+ print("Warning: Could not parse HellaSwag accuracy from output")
+ print("STDOUT:", stdout_text[:500]) # Show first 500 chars
+ print("STDERR:", stderr_text[:500])
+ return 100.0 # Return bad score for unparseable results
+ else:
+ print(f"HellaSwag accuracy: {accuracy:.4f}")
+
+ # Return negative accuracy since we want to MINIMIZE the objective function
+ # (higher accuracy = lower objective value = better)
+ return -accuracy
+
+ except subprocess.TimeoutExpired:
+ print("Warning: Process timed out")
+ return 100.0 # Return bad score for timeouts
+ except Exception as e:
+ print(f"Error running command: {e}")
+ return 100.0 # Return bad score for other errors
+
+if __name__ == "__main__":
+ args = parse_args(default_bin='./build/bin/llama-perplexity')
+
+ # Define quality-focused sampling parameters for optimization
+ run_options_list = [
+ # Core Sampling Parameters (Most Critical for Quality)
+
+ # 1. Temperature - Controls randomness vs determinism
+ ("--temp", [
+ "--temp 0.1", # Very focused, deterministic
+ "--temp 0.3", # Focused, good for factual tasks
+ "--temp 0.5", # Moderate creativity
+ "--temp 0.7", # Balanced (recommended default)
+ "--temp 0.8", # Good balance
+ "--temp 0.9", # More creative
+ "--temp 1.0", # Creative but coherent
+ "--temp 1.2" # More creative, potentially less coherent
+ ]),
+
+ # 2. Top-p (Nucleus Sampling) - Controls diversity while maintaining quality
+ ("--top-p", [
+ "--top-p 0.5", # Very focused
+ "--top-p 0.7", # Focused, higher quality
+ "--top-p 0.8", # Good balance
+ "--top-p 0.85", # Balanced
+ "--top-p 0.9", # Good balance (recommended)
+ "--top-p 0.95", # Standard default
+ "--top-p 0.98", # More diverse
+ "--top-p 1.0" # No nucleus filtering
+ ]),
+
+ # 3. Top-k - Limits token selection to most probable candidates
+ ("--top-k", [
+ "--top-k 10", # Very focused
+ "--top-k 20", # More focused, higher quality
+ "--top-k 30", # Balanced
+ "--top-k 40", # Good balance (default)
+ "--top-k 50", # Balanced, more diverse
+ "--top-k 60", # More diverse
+ "--top-k 80", # Very diverse
+ "--top-k 100" # Most diverse
+ ]),
+
+ # 4. Min-p - Filters out low-probability tokens
+ ("--min-p", [
+ "--min-p 0.01", # Very permissive
+ "--min-p 0.02", # Permissive
+ "--min-p 0.05", # Good default
+ "--min-p 0.08", # More restrictive
+ "--min-p 0.1", # Restrictive, higher quality
+ "--min-p 0.15", # Very restrictive
+ "--min-p 0.2" # Extremely restrictive
+ ]),
+
+ # Repetition Control (Critical for Coherence)
+
+ # 5. Repeat Penalty - Prevents repetitive text
+ ("--repeat-penalty", [
+ "--repeat-penalty 1.0", # Disabled
+ "--repeat-penalty 1.02", # Very light penalty
+ "--repeat-penalty 1.05", # Light penalty (recommended)
+ "--repeat-penalty 1.1", # Moderate penalty (recommended)
+ "--repeat-penalty 1.15", # Moderate-strong penalty
+ "--repeat-penalty 1.2", # Strong penalty
+ "--repeat-penalty 1.25", # Very strong penalty
+ "--repeat-penalty 1.3" # Extreme penalty
+ ]),
+
+ # 6. Repeat Last N - How far back to look for repetitions
+ ("--repeat-last-n", [
+ "--repeat-last-n 16", # Short context
+ "--repeat-last-n 32", # Short-medium context
+ "--repeat-last-n 64", # Balanced default
+ "--repeat-last-n 96", # Medium-large context
+ "--repeat-last-n 128", # Large context
+ "--repeat-last-n 192", # Very large context
+ "--repeat-last-n 256" # Maximum context
+ ]),
+
+ # Advanced Quality Parameters
+
+ # 7. Typical-p - Promotes contextually coherent tokens
+ ("--typical", [
+ "--typical 1.0", # Disabled
+ "--typical 0.95", # Light filtering
+ "--typical 0.9", # Recommended for quality
+ "--typical 0.85", # Moderate filtering
+ "--typical 0.8", # Strong filtering
+ "--typical 0.75", # Very strong filtering
+ "--typical 0.7" # Extreme filtering
+ ]),
+
+ # 8. Mirostat - Adaptive sampling for consistent quality
+ ("--mirostat", [
+ "--mirostat 0", # Disabled (default)
+ "--mirostat 1", # Mirostat v1
+ "--mirostat 2" # Mirostat v2 (often better quality)
+ ]),
+
+ # Keep seed constant for reproducible results
+ ("--seed", ["-s 42"]),
+ ]
+ def run_str(run_options, model_path, binary_path):
+ """Build command string for llama-perplexity with hellaswag evaluation."""
+ run_opts = " ".join(run_options.values())
+ # Use the perplexity command with hellaswag evaluation as specified
+ return f"{binary_path} -m {model_path} -f hellaswag_val_full.txt --hellaswag-tasks 20 --hellaswag -ngl {args.ngl} {run_opts}"
+ main(args, run_str, run_binary, run_options_list)
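Since gp_minimize minimizes, run_binary returns the negated HellaSwag accuracy, so higher accuracy yields a lower objective. A self-contained check of the parsing logic, assuming the tab-separated output format described in the comment above (the exact llama-perplexity output may differ between versions):

import re

sample = "20\t75.00000000%\t[53.1299%, 88.8138%]"  # hypothetical output line
match = re.search(r"20\t([\d.]+)%\t\[", sample)
accuracy = float(match.group(1)) / 100.0  # 0.75
print(-accuracy)  # -0.75 is the value handed back to gp_minimize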
diff --git a/scripts/tune/tune_requirements.txt b/scripts/tune/tune_requirements.txt
new file mode 100644
index 000000000..50cb56b
--- /dev/null
+++ b/scripts/tune/tune_requirements.txt
@@ -0,0 +1,3 @@
+language_tool_python
+bert_score
+scikit-optimize
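These are the extra packages the scripts need; they can be installed with pip install -r scripts/tune/tune_requirements.txt. Note that tune.py also imports numpy and matplotlib: numpy comes in as a dependency of scikit-optimize, but matplotlib (and, as noted in tune.py, possibly PyQt6 for the interactive plot window) may need to be installed separately if plotting is wanted.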
diff --git a/scripts/tune/tune_tps.py b/scripts/tune/tune_tps.py
new file mode 100644
index 000000000..8584713
--- /dev/null
+++ b/scripts/tune/tune_tps.py
@@ -0,0 +1,80 @@
+import subprocess
+import sys
+import os
+import re
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+def run_str(run_options, model_path, binary_path):
+ run_opts = " ".join(run_options.values())
+ return f"{binary_path} -m {model_path} -p 'Hello, how are you?' -n 1 {run_opts}"
+
+def run_binary(run_options_str):
+    process = subprocess.run(run_options_str,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True,
+                             check=False,  # handle failures ourselves so stderr is included in the error
+                             )
+    if process.returncode != 0:
+        raise Exception(f"Error running: '{run_options_str}':\n{process.stderr.decode()}")
+
+ # Parse timing information from stderr
+ stderr_text = process.stderr.decode()
+
+    # Regex patterns for the llama.cpp perf summary printed to stderr
+    prompt_eval_time_pattern = r"prompt eval time\s*=\s*([\d.]+)\s*ms"
+    # negative lookbehind so this does not also match the "prompt eval time" line
+    eval_time_pattern = r"(?<!prompt )eval time\s*=\s*([\d.]+)\s*ms"
+
+ prompt_match = re.search(prompt_eval_time_pattern, stderr_text)
+ eval_match = re.search(eval_time_pattern, stderr_text)
+
+ if prompt_match and eval_match:
+ prompt_eval_time = float(prompt_match.group(1)) / 1000 # Convert to seconds
+ eval_time = float(eval_match.group(1)) / 1000 # Convert to seconds
+ else:
+ # Fallback: look for any timing patterns
+ print("Warning: Could not parse timing info, using fallback")
+ print("STDERR:", stderr_text)
+ return 1000 # High penalty for failed parsing
+
+ print("prompt eval time:", prompt_eval_time)
+ print("eval time:", eval_time)
+
+ return eval_time
+
+if __name__ == "__main__":
+ args = parse_args(default_bin='./build/bin/llama-cli')
+ # Define runtime options to optimize - Core Performance Parameters
+ run_options_list = [
+ # 1. Batch Processing Parameters (most critical for throughput)
+ ("--batch-size", ["--batch-size 31", "--batch-size 64", "--batch-size 128", "--batch-size 256", "--batch-size 512", "--batch-size 1024", "--batch-size 2048"]),
+ ("--ubatch-size", ["--ubatch-size 32", "--ubatch-size 64", "--ubatch-size 128", "--ubatch-size 256", "--ubatch-size 512"]),
+
+ # 2. Context and Memory Parameters
+ ("--ctx-size", ["-c 512", "-c 1024", "-c 2048", "-c 4096", "-c 8192"]),
+ ("--defrag-thold", ["--defrag-thold -1", "--defrag-thold 0.1", "--defrag-thold 0.2", "--defrag-thold 0.5"]),
+
+ # 3. GPU Offloading Parameters (critical for GPU performance)
+ # Set range to a value that makes sense for your model
+ ("--n-gpu-layers", [f"--n-gpu-layers {i}" for i in range(args.ngl)]),
+
+ # 4. CPU Optimization Parameters
+ ("--threads", ["-t 4", "-t 8", "-t 12", "-t 16"]),
+ # ("--prio", ["--prio 0", "--prio 1", "--prio 2"]),
+
+ # 5. Memory and Caching Parameters
+ # ("--use-mmap", ["", "--no-mmap"]),
+ ("--use-mlock", ["--mlock", ""]),
+ ("--kv-unified", ["--kv-unified", ""]),
+
+ # 6. Advanced Performance Features
+ ("--flash-attn", ["--flash-attn", ""]),
+ # ("--no-kv-offload", ["--no-kv-offload", ""]), # Empty string means don't use the flag
+
+ # Keep seed constant for reproducible results
+ ("--seed", ["-s 42"]),
+ ]
+    main(args, run_str, run_binary, run_options_list)
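All three drivers share the command-line interface defined in tune.py, so they are invoked the same way from the repository root (the paths, layer count, and cache file names below are placeholders):

python scripts/tune/tune_tps.py          --model /path/to/model.gguf --ngl 32 --n-calls 50 --cache cache_tps.json
python scripts/tune/tune_quality.py      --model /path/to/model.gguf --ngl 32 --cache cache_quality.json
python scripts/tune/tune_quality_swag.py --model /path/to/model.gguf --ngl 32   # expects hellaswag_val_full.txt in the working directory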
5 files changed: +838 −0 (scripts/tune)