
Conversation


@jesusmb1995 jesusmb1995 commented Sep 17, 2025

These scripts automatically discover the best settings in terms of generative text quality and speed, inspired by llama.cpp discussions but modernized for more re-usability/configuration. After a set of possible configurations has been defined, they rely on Bayesian optimization (Gaussian processes) to guide the search:

The main objective was to quickly discover settings to tune our models, so the scripts rely heavily on existing Python tools (optimization, plots, etc.). I imagine we could have something similar in the QVAC JS repository, but it would require more effort and it's currently out of scope.

The scripts are provided in a separate folder with their own Python requirements; they do not affect the behavior of llama.cpp, but merging them is beneficial for future usage/reference.
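For reference, the core loop in the scripts looks roughly like the sketch below (a minimal illustration; the parameter values and the placeholder objective are illustrative, the real scripts benchmark the llama binary instead). Dependencies are listed in scripts/tune/tune_requirements.txt.

```python
# Minimal sketch of the approach (illustrative values, placeholder objective).
from skopt import gp_minimize
from skopt.space import Categorical

# Each entry: (option name, list of candidate flag strings)
run_options_list = [
    ("--temp",  ["--temp 0.3", "--temp 0.7", "--temp 1.0"]),
    ("--top-k", ["--top-k 20", "--top-k 40", "--top-k 80"]),
]

def objective(x):
    # x holds one chosen flag string per option; the real scripts build a
    # command from it, run the llama binary, and return a value to minimize
    # (eval time, or a negated quality score).
    flags = " ".join(x)
    print("would benchmark:", flags)
    return float(len(flags))  # placeholder metric for this sketch

dimensions = [Categorical(values, name=name) for name, values in run_options_list]
result = gp_minimize(objective, dimensions, n_calls=15, random_state=42)
print("best flags:", result.x)
```

Each entry point (tune_tps.py, tune_quality.py, tune_quality_swag.py) is invoked with at least --model and --ngl, e.g. `python scripts/tune/tune_tps.py --model model.gguf --ngl 32` from the llama.cpp root.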

Example logs and plots


@jesusmb1995 jesusmb1995 marked this pull request as ready for review September 17, 2025 09:55
@jesusmb1995
Author

An unrelated test is failing due to unmaintained CI scripts.

Please review/approve
@gianni-cor
@olyasir
@jpgaribotti
@yuranich

@jesusmb1995
Author

/review

1 similar comment
@jesusmb1995
Author

/review

@jesusmb1995 jesusmb1995 changed the base branch from master to temp-latest September 22, 2025 13:02
@jesusmb1995 jesusmb1995 marked this pull request as draft September 22, 2025 13:02
@jesusmb1995 jesusmb1995 marked this pull request as ready for review September 22, 2025 13:06
@jesusmb1995 jesusmb1995 marked this pull request as draft September 22, 2025 13:12
diff --git a/scripts/tune/tune.py b/scripts/tune/tune.py
new file mode 100644
index 000000000..eff17d3
--- /dev/null
+++ b/scripts/tune/tune.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+"""
+Shared helpers for tuning llama.cpp runtime parameters with Bayesian optimization.
+Used by the tune_*.py entry points, e.g.: python tune_tps.py --model /path/to/model.gguf --ngl 32
+"""
+import os
+import time
+import argparse
+from functools import partial
+
+import numpy as np
+# pip install scikit-optimize
+from skopt import gp_minimize, expected_minimum
+from skopt.plots import plot_objective, plot_convergence
+from skopt.space import Categorical
+import matplotlib.pyplot as plt
+import json
+
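+# Configurations that raised an exception during evaluation (reported at the end of the run)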
+BAD_CONFIGURATIONS = []
+
+# Progress tracking global variables
+progress_start_time = None
+progress_current_call = 0
+progress_total_calls = 0
+progress_best_score = float('inf')
+
+def display_progress():
+    """Display current optimization progress with time estimates."""
+    global progress_start_time, progress_current_call, progress_total_calls, progress_best_score
+
+    if progress_start_time is None:
+        return
+
+    elapsed_time = time.time() - progress_start_time
+    if progress_current_call > 0:
+        avg_time_per_call = elapsed_time / progress_current_call
+        remaining_calls = progress_total_calls - progress_current_call
+        estimated_remaining_time = avg_time_per_call * remaining_calls
+
+        progress_percent = (progress_current_call / progress_total_calls) * 100
+
+        print(f"\n{'='*60}")
+        print(f"OPTIMIZATION PROGRESS")
+        print(f"{'='*60}")
+        print(f"Iteration: {progress_current_call}/{progress_total_calls} ({progress_percent:.1f}%)")
+        print(f"Elapsed time: {elapsed_time:.1f}s")
+        print(f"Est. remaining time: {estimated_remaining_time:.1f}s")
+        print(f"Best metric so far: {progress_best_score:.4f}")
+        print(f"{'='*60}\n")
+
+def run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, binary_path="./build/bin/llama-cli", iterations=1):
+    """Run llama-siple with specified options and return eval time."""
+    try:
+        run_options_str = get_opts_fn(run_options, model_path, binary_path)
+        print(run_options_str)
+
+        results = []
+
+        # Run the test (can increase iterations for more stable results)
+        for _ in range(iterations):
+            results.append(run_binary_fn(run_options_str))
+
+        # Return the mean objective value (we want to minimize this)
+        return np.mean(results)
+
+    except Exception as e:
+        BAD_CONFIGURATIONS.append(run_options)
+        print("ERROR:", e, run_options)
+        print("BAD_CONFIGURATIONS:", BAD_CONFIGURATIONS)
+        return 1000  # High penalty for failed runs
+
+
+def optimize_runtime_with_progress(x, get_opts_fn, run_binary_fn, run_options_list, model_path, llama_simple_path):
+    """Objective function for optimization with progress tracking."""
+    global progress_current_call, progress_best_score
+
+    progress_current_call += 1
+
+    # x holds one chosen value per option; map it back to {option name: value}
+    run_options = {name: x[i] for i, (name, _values) in enumerate(run_options_list)}
+
+    result = run_iterations(get_opts_fn, run_binary_fn, run_options, model_path, llama_simple_path)
+
+    # Update best score
+    if result < progress_best_score:
+        progress_best_score = result
+
+    # Display progress every call
+    display_progress()
+
+    return result
+
+
+def load_cache(cache_filename):
+    """Load cached optimization results."""
+    try:
+        with open(cache_filename, "r") as cache_file:
+            cache_data = json.load(cache_file)
+            return cache_data["x0"], cache_data["y0"]
+    except (OSError, ValueError, KeyError):
+        pass
+    return None, None
+
+
+def save_cache(cache_filename, x0, y0):
+    """Save optimization results to cache."""
+    # Convert numpy int64 objects to Python int objects
+    x0 = [[int(item) if isinstance(item, np.int64) else item for item in sublist] for sublist in x0]
+    y0 = [int(item) if isinstance(item, np.int64) else item for item in y0]
+
+    cache_data = {"x0": x0, "y0": y0}
+    with open(cache_filename, "w") as cache_file:
+        json.dump(cache_data, cache_file)
+
+
+def plot_iterations(result):
+    """Plot optimization iterations."""
+    search_space = result.space
+    x_iters = result.x_iters
+    func_vals = result.func_vals
+    search_space_names = [dim.name for dim in search_space]
+    opts = search_space_names + ["objective_r"]
+
+    num_params = len(opts) + 1
+    fig, axs = plt.subplots(num_params, figsize=(8, num_params * 8), sharex=True)
+    iterations = list(range(1, len(x_iters) + 1))
+
+    for i, param in enumerate(opts):
+        if param == "objective_r":
+            param_values = func_vals
+        else:
+            param_index = search_space_names.index(param)
+            param_values = [x[param_index] for x in x_iters]
+
+        axs[i].scatter(iterations, param_values)
+        axs[i].set_xlabel("Iteration")
+        axs[i].set_ylabel(param)
+
+    plot_convergence(result, true_minimum=0, ax=axs[-1])
+    return axs
+
+def parse_args(default_bin):
+    parser = argparse.ArgumentParser(description='Optimize llama.cpp runtime parameters')
+    parser.add_argument('--model', '-m', required=True, help='Path to the GGUF model file')
+    parser.add_argument('--ngl', type=int, required=True, help='Max number of GPU layers')
+    parser.add_argument('--llama-binary', default=default_bin,
+                       help='Path to the llama binary (default: %(default)s)')
+    parser.add_argument('--n-calls', type=int, default=50,
+                       help='Number of optimization calls (default: %(default)s)')
+    parser.add_argument('--cache', default='cache_simple.json',
+                       help='Cache file name (default: cache_simple.json)')
+    parser.add_argument('--single-execution', type=str,
+                       help='Run single execution with specified options (format: "--param1=value1 --param2=value2")')
+
+    args = parser.parse_args()
+    return args
+
+def main(args, get_opts_fn, run_binary_fn, run_options_list):
+
+    # Check if the llama binary exists
+    if not os.path.exists(args.llama_binary):
+        print(f"Error: llama binary not found at {args.llama_binary}")
+        print("Please build llama.cpp first or specify correct path with --llama-binary")
+        return
+
+    # Check if model exists
+    if not os.path.exists(args.model):
+        print(f"Error: Model file not found at {args.model}")
+        return
+
+    # Handle single execution mode
+    if args.single_execution:
+        try:
+            print("Single execution")
+            run_options = args.single_execution
+            run_iterations(get_opts_fn, run_binary_fn, run_options, args.model, args.llama_binary)
+            return
+        except ValueError as e:
+            print(f"Error parsing single execution options: {e}")
+            return
+
+    # Initialize progress tracking
+    global progress_start_time, progress_total_calls
+    progress_start_time = time.time()
+    progress_total_calls = args.n_calls
+
+    # Create optimization dimensions
+    dimensions = [Categorical(opt[1]) for opt in run_options_list]
+    for i, opt in enumerate(run_options_list):
+        dimensions[i].name = opt[0]
+
+    # Load cache
+    x0, y0 = load_cache(args.cache)
+
+    # Create objective function
+    objective_function = partial(optimize_runtime_with_progress,
+                               get_opts_fn=get_opts_fn,
+                               run_binary_fn=run_binary_fn,
+                               run_options_list=run_options_list,
+                               model_path=args.model,
+                               llama_simple_path=args.llama_binary)
+
+    print(f"Starting optimization with {args.n_calls} calls and {args.ngl} gpu layers...")
+    print(f"Using model: {args.model}")
+    print(f"Cache file: {args.cache}")
+
+    # Run optimization
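+    # Warm-start from any cached evaluations (x0, y0) and draw initial points with Latin hypercube sampling ("lhs")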
+    result = gp_minimize(objective_function, dimensions,
+                        n_calls=args.n_calls,
+                        n_initial_points=min(10, args.n_calls),
+                        random_state=42,
+                        x0=x0, y0=y0,
+                        initial_point_generator="lhs")
+
+    # Save results
+    save_cache(args.cache, result.x_iters, result.func_vals)
+
+    # Print results
+    print(f"\nBest options found: {result.x}")
+    print(f"Minimum eval time: {result.fun:.4f} seconds")
+
+    # Convert result.x back to a human-readable mapping of option name -> chosen value
+    best_options = {}
+    for i, (name, options) in enumerate(run_options_list):
+        # Find the value in result.x[i] and locate its index in the options list
+        value = result.x[i]
+        if value in options:
+            best_options[name] = value
+        else:
+            # Fallback: use the first option if value not found
+            print(f"Warning: Value '{value}' not found in options for {name}, using first option")
+            best_options[name] = options[0]
+
+    print("\nBest configuration:")
+    for name, value in best_options.items():
+        print(f"  {name}: {value}")
+
+    min_x, _ = expected_minimum(result)
+    print(f"Expected minimum: {min_x}")
+
+    if BAD_CONFIGURATIONS:
+        print(f"\nBAD_CONFIGURATIONS: {len(BAD_CONFIGURATIONS)}")
+
+    # Plot results
+    try:
+        plot_iterations(result)
+        plot_objective(result)
+        # Might need PyQt6
+        plt.show()
+    except Exception as e:
+        print(f"Plotting failed: {e}")
diff --git a/scripts/tune/tune_quality.py b/scripts/tune/tune_quality.py
new file mode 100644
index 000000000..ffae255
--- /dev/null
+++ b/scripts/tune/tune_quality.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+"""
+BERTScore-based translation quality optimization for llama.cpp models.
+Uses BERTScore to evaluate translation quality instead of HellaSwag accuracy.
+"""
+import subprocess
+import sys
+import os
+import re
+import json
+import hashlib
+import numpy as np
+from typing import Dict, List, Tuple, Any, Optional
+from collections import Counter
+
+# Import bert_score for translation quality evaluation
+import bert_score
+
+# Import language_tool_python for grammar checking
+import language_tool_python
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+# Configuration
+BERTSCORE_MODEL = 'microsoft/deberta-v3-base'
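+# Any Hugging Face model id accepted by bert_score's model_type argument can be substituted here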
+
+# Translation benchmarks for quality evaluation
+# Tiny subset of https://openslr.org/100
+TRANSLATION_BENCHMARKS = [
+    {
+        "prompt": "Translate the following English text to French:\n\nEnglish: As you can see, it does not look like a slam lesson, it is a language lesson, a language which allows to give orders to machines and computers the language of the 21st century: the computer code.\nFrench:",
+        "ground_truth": "Comme vous pouvez le constater, il ne s'agit pas d'un cours de slam, il s'agit d'un cours de langue, une langue qui permet de donner des ordres à des machines et à des ordinateurs, la langue du 21e siècle : le code informatique.",
+        "tool": "fr-FR"
+    },
+    {
+        "prompt": "Translate the following English text to Spanish:\n\nEnglish: Some years ago, when I was diving in the Lombok Strait, in Indonesia, 98 feet below the water, with that feeling of weightlessness, surrounded by a great biodiversity of reefs, corals, sea turtles, ocean sunfishes and fishes of all colors, I had an intense feeling of connection with nature.\nSpanish:",
+        "ground_truth": "Hace unos años, cuando me encontraba buceando en el estrecho de Lombok, en Indonesia, a 30 metros debajo del agua, con esa sensación de ingravidez, rodeado de una gran biodiversidad, de arrecifes, de corales, de tortugas, de peces mola mola y de peces de todos los colores, tuve una intensa sensación de estar conectado con la naturaleza.",
+        "tool": "es-ES"
+    },
+    {
+        "prompt": "Translate the following English text to Portuguese:\n\nEnglish: Have you ever stopped to think about clothes for disabled people?\nPortuguese:",
+        "ground_truth": "Vocês já pararam pra pensar como é o vestuário das pessoas com deficiência?",
+        "tool": "pt-PT"
+    }
+]
+
+def get_metrics(metrics_filepath: str, ground_truth: str, prediction: str, tool: str) -> Dict[str, Any]:
+    """
+    Calculate BERTScore and other quality metrics for translation evaluation.
+    Caches results to avoid recomputation.
+    """
+    print(f"Calculating metrics: {metrics_filepath}")
+
+    metrics = {
+        'bertscore_model': None,
+        'bertscore_P': None,
+        'bertscore_R': None,
+        'bertscore_F1': None,
+        'grammar_errors': None,
+        'repetition_score': None,
+        'objective_r': None
+    }
+
+    # Load cached scores
+    try:
+        with open(metrics_filepath, 'r', encoding='utf-8') as f:
+            metrics.update(json.load(f))
+    except FileNotFoundError:
+        pass
+
+    # Calculate BERTScore if not cached or model changed
+    if (not metrics["bertscore_P"] or not metrics["bertscore_R"] or
+        not metrics["bertscore_F1"] or metrics["bertscore_model"] != BERTSCORE_MODEL):
+        try:
+            metrics["bertscore_model"] = BERTSCORE_MODEL
+            score = bert_score.score([prediction], [ground_truth], model_type=BERTSCORE_MODEL)
+            metrics["bertscore_P"], metrics["bertscore_R"], metrics["bertscore_F1"] = (
+                score[0].item(), score[1].item(), score[2].item()
+            )
+        except Exception as e:
+            print(f"Warning: BERTScore calculation failed: {e}")
+            metrics["bertscore_P"] = metrics["bertscore_R"] = metrics["bertscore_F1"] = 0.0
+
+    # Calculate grammar errors if not cached
+    if metrics["grammar_errors"] is None:
+        language_tool = language_tool_python.LanguageTool(tool)
+        try:
+            matches = language_tool.check(prediction)
+            metrics["grammar_errors"] = len(matches) / max(len(prediction.split()), 1)
+        except Exception as e:
+            print(f"Warning: Grammar checking failed: {e}")
+            metrics["grammar_errors"] = 0.0
+
+    # Calculate repetition score if not cached
+    if metrics["repetition_score"] is None:
+        try:
+            words = prediction.split()
+            if len(words) > 0:
+                word_counts = Counter(words)
+                repeated_words = sum(count - 1 for count in word_counts.values() if count > 1)
+                metrics["repetition_score"] = repeated_words / len(words)
+            else:
+                metrics["repetition_score"] = 0.0
+        except Exception as e:
+            print(f"Warning: Repetition calculation failed: {e}")
+            metrics["repetition_score"] = 0.0
+
+    # Calculate objective score (we want to minimize this)
+    # Higher BERTScore Recall = better translation quality = lower objective value
+    # Add penalties for grammar errors and repetitions
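+    # e.g. bertscore_R=0.85, grammar_errors=0.10, repetition_score=0.0 -> objective_r = -(0.85 - 0.01 - 0.0) = -0.84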
+    if metrics["bertscore_R"] is not None:
+        grammar_penalty = metrics["grammar_errors"] * 0.1  # Small penalty for grammar errors
+        repetition_penalty = metrics["repetition_score"] * 0.05  # Small penalty for repetitions
+        metrics["objective_r"] = -(metrics["bertscore_R"] - grammar_penalty - repetition_penalty)
+    else:
+        metrics["objective_r"] = 1.0  # Bad score if BERTScore failed
+
+    # Save metrics to cache
+    try:
+        with open(metrics_filepath, 'w', encoding='utf-8') as f:
+            json.dump(metrics, f, indent=2, ensure_ascii=False)
+    except Exception as e:
+        print(f"Warning: Failed to save metrics: {e}")
+
+    return metrics
+
+def run_binary(run_options_str):
+    """Run the binary and evaluate translation quality using BERTScore."""
+    try:
+        # Parse the command to extract parameters
+        parts = run_options_str.split()
+        model_path = None
+        binary_path = None
+
+        # Find model path and binary path
+        for i, part in enumerate(parts):
+            if part == "-m" and i + 1 < len(parts):
+                model_path = parts[i + 1]
+            elif part.endswith("llama-cli") or part.endswith("main"):
+                binary_path = part
+
+        if not model_path or not binary_path:
+            print("Error: Could not parse model path or binary path from command")
+            return 100.0
+
+        # Create output directory for this run
+        run_hash = hashlib.md5(run_options_str.encode()).hexdigest()[:8]
+        output_dir = f"translation_eval_{run_hash}"
+        os.makedirs(output_dir, exist_ok=True)
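+        # Metrics for each benchmark are cached per configuration hash, so re-evaluating the same settings skips recomputing BERTScore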
+
+        all_scores = []
+
+        # Run translation benchmarks
+        for i, benchmark in enumerate(TRANSLATION_BENCHMARKS):
+            print(f"Running benchmark {i+1}/{len(TRANSLATION_BENCHMARKS)}")
+
+            # Build command for this benchmark - use the base command and add benchmark-specific params
+            benchmark_cmd = run_options_str.split()
+
+            # Add benchmark-specific parameters
+            benchmark_cmd.extend(["--prompt", benchmark["prompt"]])
+
+            # Run the command
+            try:
+                process = subprocess.run(benchmark_cmd,
+                                       stdout=subprocess.PIPE,
+                                       stderr=subprocess.PIPE,
+                                       timeout=120,  # 2 minute timeout per benchmark
+                                       check=False)
+
+                if process.returncode != 0:
+                    print(f"Warning: Benchmark {i+1} failed with return code {process.returncode}")
+                    print(f"STDERR: {process.stderr.decode()}")
+                    all_scores.append(1.0)  # Bad score for failed runs
+                    continue
+
+                # Extract prediction from output
+                output = process.stdout.decode()
+                prediction = output.strip()
+
+                # Remove the prompt from prediction if it's included
+                if benchmark["prompt"] in prediction:
+                    prediction = prediction.split(benchmark["prompt"])[-1].strip()
+
+                # Calculate metrics
+                metrics_filepath = os.path.join(output_dir, f"benchmark_{i}_metrics.json")
+                metrics = get_metrics(metrics_filepath,
+                                    benchmark["ground_truth"], prediction, benchmark["tool"])
+
+                objective_score = metrics.get("objective_r", 1.0)
+                all_scores.append(objective_score)
+
+                print(f"Benchmark {i+1} - BERTScore R: {metrics.get('bertscore_R', 0):.4f}, "
+                      f"Objective: {objective_score:.4f}")
+
+            except subprocess.TimeoutExpired:
+                print(f"Warning: Benchmark {i+1} timed out")
+                all_scores.append(1.0)  # Bad score for timeouts
+            except Exception as e:
+                print(f"Error running benchmark {i+1}: {e}")
+                all_scores.append(1.0)  # Bad score for errors
+
+        # Calculate average score across all benchmarks
+        if all_scores:
+            avg_score = np.mean(all_scores)
+            print(f"Average translation quality objective score: {avg_score:.4f}")
+            return avg_score
+        else:
+            print("Warning: No successful benchmarks")
+            return 100.0  # Bad score if no benchmarks succeeded
+
+    except Exception as e:
+        print(f"Error in run_binary: {e}")
+        return 100.0  # Bad score for any other errors
+
+if __name__ == "__main__":
+    args = parse_args(default_bin='./build/bin/llama-cli')
+
+    # Define quality-focused sampling parameters for optimization
+    run_options_list = [
+        # Core Sampling Parameters (Most Critical for Quality)
+
+        # 1. Temperature - Controls randomness vs determinism
+        ("--temp", [
+            "--temp 0.1",   # Very focused, deterministic
+            "--temp 0.3",   # Focused, good for factual tasks
+            "--temp 0.5",   # Moderate creativity
+            "--temp 0.7",   # Balanced (recommended default)
+            "--temp 0.8",   # Good balance
+            "--temp 0.9",   # More creative
+            "--temp 1.0",   # Creative but coherent
+            "--temp 1.2"    # More creative, potentially less coherent
+        ]),
+
+        # 2. Top-p (Nucleus Sampling) - Controls diversity while maintaining quality
+        ("--top-p", [
+            "--top-p 0.5",   # Very focused
+            "--top-p 0.7",   # Focused, higher quality
+            "--top-p 0.8",   # Good balance
+            "--top-p 0.85",  # Balanced
+            "--top-p 0.9",   # Good balance (recommended)
+            "--top-p 0.95",  # Standard default
+            "--top-p 0.98",  # More diverse
+            "--top-p 1.0"    # No nucleus filtering
+        ]),
+
+        # 3. Top-k - Limits token selection to most probable candidates
+        ("--top-k", [
+            "--top-k 10",   # Very focused
+            "--top-k 20",   # More focused, higher quality
+            "--top-k 30",   # Balanced
+            "--top-k 40",   # Good balance (default)
+            "--top-k 50",   # Balanced, more diverse
+            "--top-k 60",   # More diverse
+            "--top-k 80",   # Very diverse
+            "--top-k 100"   # Most diverse
+        ]),
+
+        # 4. Min-p - Filters out low-probability tokens
+        ("--min-p", [
+            "--min-p 0.01",  # Very permissive
+            "--min-p 0.02",  # Permissive
+            "--min-p 0.05",  # Good default
+            "--min-p 0.08",  # More restrictive
+            "--min-p 0.1",   # Restrictive, higher quality
+            "--min-p 0.15",  # Very restrictive
+            "--min-p 0.2"    # Extremely restrictive
+        ]),
+
+        # Repetition Control (Critical for Coherence)
+
+        # 5. Repeat Penalty - Prevents repetitive text
+        ("--repeat-penalty", [
+            "--repeat-penalty 1.0",   # Disabled
+            "--repeat-penalty 1.02",  # Very light penalty
+            "--repeat-penalty 1.05",  # Light penalty (recommended)
+            "--repeat-penalty 1.1",   # Moderate penalty (recommended)
+            "--repeat-penalty 1.15",  # Moderate-strong penalty
+            "--repeat-penalty 1.2",   # Strong penalty
+            "--repeat-penalty 1.25",  # Very strong penalty
+            "--repeat-penalty 1.3"    # Extreme penalty
+        ]),
+
+        # 6. Repeat Last N - How far back to look for repetitions
+        ("--repeat-last-n", [
+            "--repeat-last-n 16",   # Short context
+            "--repeat-last-n 32",   # Short-medium context
+            "--repeat-last-n 64",   # Balanced default
+            "--repeat-last-n 96",   # Medium-large context
+            "--repeat-last-n 128",  # Large context
+            "--repeat-last-n 192",  # Very large context
+            "--repeat-last-n 256"   # Maximum context
+        ]),
+
+        # Advanced Quality Parameters
+
+        # 7. Typical-p - Promotes contextually coherent tokens
+        ("--typical", [
+            "--typical 1.0",   # Disabled
+            "--typical 0.95",  # Light filtering
+            "--typical 0.9",   # Recommended for quality
+            "--typical 0.85",  # Moderate filtering
+            "--typical 0.8",   # Strong filtering
+            "--typical 0.75",  # Very strong filtering
+            "--typical 0.7"    # Extreme filtering
+        ]),
+
+        # 8. Mirostat - Adaptive sampling for consistent quality
+        ("--mirostat", [
+            "--mirostat 0",  # Disabled (default)
+            "--mirostat 1",  # Mirostat v1
+            "--mirostat 2"   # Mirostat v2 (often better quality)
+        ]),
+
+        # Keep seed constant for reproducible results
+        ("--seed", ["-s 42"]),
+    ]
+
+    def run_str(run_options, model_path, binary_path):
+        """Build command string for llama-cli with translation evaluation."""
+        if isinstance(run_options, dict):
+            run_options = " ".join(run_options.values())
+        # Use the main binary for translation evaluation
+        return f"{binary_path} -m {model_path} --threads 8 -ngl {args.ngl} {run_options}"
+
+    main(args, run_str, run_binary, run_options_list)
diff --git a/scripts/tune/tune_quality_swag.py b/scripts/tune/tune_quality_swag.py
new file mode 100644
index 000000000..1eaedad
--- /dev/null
+++ b/scripts/tune/tune_quality_swag.py
@@ -0,0 +1,172 @@
+import subprocess
+import sys
+import os
+import re
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+def run_binary(run_options_str):
+    """Run the binary and parse HellaSwag accuracy score."""
+    try:
+        process = subprocess.run(run_options_str,
+                                 stdout=subprocess.PIPE,
+                                 stderr=subprocess.PIPE,
+                                 shell=True,
+                                 check=False,  # Don't raise exception on non-zero exit
+                                 timeout=300   # 5 minute timeout
+                                 )
+
+        if process.returncode != 0:
+            print(f"Warning: Process returned non-zero exit code: {process.returncode}")
+            print(f"STDERR: {process.stderr.decode()}")
+            return 100.0  # Return bad score for failed runs
+
+        # Parse HellaSwag accuracy from stdout
+        stdout_text = process.stdout.decode()
+        stderr_text = process.stderr.decode()
+
+        # Look for HellaSwag accuracy patterns in output
+        # Pattern for format: "20      75.00000000%    [53.1299%, 88.8138%]"
+        accuracy_patterns = [
+            r"20\t([\d.]+)%\t\[",
+        ]
+
+        accuracy = None
+        for pattern in accuracy_patterns:
+            match = re.search(pattern, stdout_text, re.IGNORECASE)
+            if match:
+                accuracy = float(match.group(1))
+                # Convert percentage to decimal if needed (values > 1.0 are likely percentages)
+                if accuracy > 1.0:
+                    accuracy = accuracy / 100.0
+                break
+
+        if accuracy is None:
+            print("Warning: Could not parse HellaSwag accuracy from output")
+            print("STDOUT:", stdout_text[:500])  # Show first 500 chars
+            print("STDERR:", stderr_text[:500])
+            return 100.0  # Return bad score for unparseable results
+        else:
+            print(f"HellaSwag accuracy: {accuracy:.4f}")
+
+        # Return negative accuracy since we want to MINIMIZE the objective function
+        # (higher accuracy = lower objective value = better)
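+        # e.g. an accuracy of 0.75 is returned as an objective value of -0.75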
+        return -accuracy
+
+    except subprocess.TimeoutExpired:
+        print("Warning: Process timed out")
+        return 100.0  # Return bad score for timeouts
+    except Exception as e:
+        print(f"Error running command: {e}")
+        return 100.0  # Return bad score for other errors
+
+if __name__ == "__main__":
+    args = parse_args(default_bin='./build/bin/llama-perplexity')
+
+    # Define quality-focused sampling parameters for optimization
+    run_options_list = [
+        # Core Sampling Parameters (Most Critical for Quality)
+
+        # 1. Temperature - Controls randomness vs determinism
+        ("--temp", [
+            "--temp 0.1",   # Very focused, deterministic
+            "--temp 0.3",   # Focused, good for factual tasks
+            "--temp 0.5",   # Moderate creativity
+            "--temp 0.7",   # Balanced (recommended default)
+            "--temp 0.8",   # Good balance
+            "--temp 0.9",   # More creative
+            "--temp 1.0",   # Creative but coherent
+            "--temp 1.2"    # More creative, potentially less coherent
+        ]),
+
+        # 2. Top-p (Nucleus Sampling) - Controls diversity while maintaining quality
+        ("--top-p", [
+            "--top-p 0.5",   # Very focused
+            "--top-p 0.7",   # Focused, higher quality
+            "--top-p 0.8",   # Good balance
+            "--top-p 0.85",  # Balanced
+            "--top-p 0.9",   # Good balance (recommended)
+            "--top-p 0.95",  # Standard default
+            "--top-p 0.98",  # More diverse
+            "--top-p 1.0"    # No nucleus filtering
+        ]),
+
+        # 3. Top-k - Limits token selection to most probable candidates
+        ("--top-k", [
+            "--top-k 10",   # Very focused
+            "--top-k 20",   # More focused, higher quality
+            "--top-k 30",   # Balanced
+            "--top-k 40",   # Good balance (default)
+            "--top-k 50",   # Balanced, more diverse
+            "--top-k 60",   # More diverse
+            "--top-k 80",   # Very diverse
+            "--top-k 100"   # Most diverse
+        ]),
+
+        # 4. Min-p - Filters out low-probability tokens
+        ("--min-p", [
+            "--min-p 0.01",  # Very permissive
+            "--min-p 0.02",  # Permissive
+            "--min-p 0.05",  # Good default
+            "--min-p 0.08",  # More restrictive
+            "--min-p 0.1",   # Restrictive, higher quality
+            "--min-p 0.15",  # Very restrictive
+            "--min-p 0.2"    # Extremely restrictive
+        ]),
+
+        # Repetition Control (Critical for Coherence)
+
+        # 5. Repeat Penalty - Prevents repetitive text
+        ("--repeat-penalty", [
+            "--repeat-penalty 1.0",   # Disabled
+            "--repeat-penalty 1.02",  # Very light penalty
+            "--repeat-penalty 1.05",  # Light penalty (recommended)
+            "--repeat-penalty 1.1",   # Moderate penalty (recommended)
+            "--repeat-penalty 1.15",  # Moderate-strong penalty
+            "--repeat-penalty 1.2",   # Strong penalty
+            "--repeat-penalty 1.25",  # Very strong penalty
+            "--repeat-penalty 1.3"    # Extreme penalty
+        ]),
+
+        # 6. Repeat Last N - How far back to look for repetitions
+        ("--repeat-last-n", [
+            "--repeat-last-n 16",   # Short context
+            "--repeat-last-n 32",   # Short-medium context
+            "--repeat-last-n 64",   # Balanced default
+            "--repeat-last-n 96",   # Medium-large context
+            "--repeat-last-n 128",  # Large context
+            "--repeat-last-n 192",  # Very large context
+            "--repeat-last-n 256"   # Maximum context
+        ]),
+
+        # Advanced Quality Parameters
+
+        # 7. Typical-p - Promotes contextually coherent tokens
+        ("--typical", [
+            "--typical 1.0",   # Disabled
+            "--typical 0.95",  # Light filtering
+            "--typical 0.9",   # Recommended for quality
+            "--typical 0.85",  # Moderate filtering
+            "--typical 0.8",   # Strong filtering
+            "--typical 0.75",  # Very strong filtering
+            "--typical 0.7"    # Extreme filtering
+        ]),
+
+        # 8. Mirostat - Adaptive sampling for consistent quality
+        ("--mirostat", [
+            "--mirostat 0",  # Disabled (default)
+            "--mirostat 1",  # Mirostat v1
+            "--mirostat 2"   # Mirostat v2 (often better quality)
+        ]),
+
+        # Keep seed constant for reproducible results
+        ("--seed", ["-s 42"]),
+    ]
+    def run_str(run_options, model_path, binary_path):
+        """Build command string for llama-perplexity with hellaswag evaluation."""
+        run_opts = " ".join(run_options.values())
+        # Use the perplexity command with hellaswag evaluation as specified
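+        # Assumes hellaswag_val_full.txt is present in the current working directory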
+        return f"{binary_path} -m {model_path} -f hellaswag_val_full.txt --hellaswag-tasks 20 --hellaswag -ngl {args.ngl} {run_opts}"
+    main(args, run_str, run_binary, run_options_list)
diff --git a/scripts/tune/tune_requirements.txt b/scripts/tune/tune_requirements.txt
new file mode 100644
index 000000000..50cb56b
--- /dev/null
+++ b/scripts/tune/tune_requirements.txt
@@ -0,0 +1,3 @@
+language_tool_python
+bert_score
+scikit-optimize
diff --git a/scripts/tune/tune_tps.py b/scripts/tune/tune_tps.py
new file mode 100644
index 000000000..8584713
--- /dev/null
+++ b/scripts/tune/tune_tps.py
@@ -0,0 +1,80 @@
+import subprocess
+import sys
+import os
+import re
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, script_dir)
+from tune import parse_args, main
+
+def run_str(run_options, model_path, binary_path):
+    """Build the command string for a short llama-cli generation benchmark."""
+    run_opts = " ".join(run_options.values())
+    return f"{binary_path} -m {model_path} -p 'Hello, how are you?' -n 1 {run_opts}"
+
+def run_binary(run_options_str):
+    """Run the binary and return the generation eval time (seconds) parsed from its timing output."""
+    process = subprocess.run(run_options_str,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True,
+                             check=False,
+                             )
+    if process.returncode != 0:
+        raise Exception(f"Error running: '{run_options_str}':\n{process.stderr.decode()}")
+
+    # Parse timing information from stderr
+    stderr_text = process.stderr.decode()
+
+    # Timing patterns from the llama.cpp performance summary; the negative lookbehind
+    # keeps the second pattern from matching inside the "prompt eval time" line
+    prompt_eval_time_pattern = r"prompt eval time\s*=\s*([\d.]+)\s*ms"
+    eval_time_pattern = r"(?<!prompt )eval time\s*=\s*([\d.]+)\s*ms"
+
+    prompt_match = re.search(prompt_eval_time_pattern, stderr_text)
+    eval_match = re.search(eval_time_pattern, stderr_text)
+
+    if prompt_match and eval_match:
+        prompt_eval_time = float(prompt_match.group(1)) / 1000  # Convert to seconds
+        eval_time = float(eval_match.group(1)) / 1000  # Convert to seconds
+    else:
+        # Timing lines not found in the output
+        print("Warning: Could not parse timing info, using fallback")
+        print("STDERR:", stderr_text)
+        return 1000  # High penalty for failed parsing
+
+    print("prompt eval time:", prompt_eval_time)
+    print("eval time:", eval_time)
+
+    return eval_time
+
+if __name__ == "__main__":
+    args = parse_args(default_bin='./build/bin/llama-cli')
+    # Define runtime options to optimize - Core Performance Parameters
+    run_options_list = [
+        # 1. Batch Processing Parameters (most critical for throughput)
+        ("--batch-size", ["--batch-size 31", "--batch-size 64", "--batch-size 128", "--batch-size 256", "--batch-size 512", "--batch-size 1024", "--batch-size 2048"]),
+        ("--ubatch-size", ["--ubatch-size 32", "--ubatch-size 64", "--ubatch-size 128", "--ubatch-size 256", "--ubatch-size 512"]),
+
+        # 2. Context and Memory Parameters
+        ("--ctx-size", ["-c 512", "-c 1024", "-c 2048", "-c 4096", "-c 8192"]),
+        ("--defrag-thold", ["--defrag-thold -1", "--defrag-thold 0.1", "--defrag-thold 0.2", "--defrag-thold 0.5"]),
+
+        # 3. GPU Offloading Parameters (critical for GPU performance)
+        # Set range to a value that makes sense for your model
+        ("--n-gpu-layers", [f"--n-gpu-layers {i}" for i in range(args.ngl)]),
+
+        # 4. CPU Optimization Parameters
+        ("--threads", ["-t 4", "-t 8", "-t 12", "-t 16"]),
+        # ("--prio", ["--prio 0", "--prio 1", "--prio 2"]),
+
+        # 5. Memory and Caching Parameters
+        # ("--use-mmap", ["", "--no-mmap"]),
+        ("--use-mlock", ["--mlock", ""]),
+        ("--kv-unified", ["--kv-unified", ""]),
+
+        # 6. Advanced Performance Features
+        ("--flash-attn", ["--flash-attn", ""]),
+        # ("--no-kv-offload", ["--no-kv-offload", ""]),  # Empty string means don't use the flag
+
+        # Keep seed constant for reproducible results
+        ("--seed", ["-s 42"]),
+    ]
+    main(args, run_str, run_binary, run_options_list)
@jesusmb1995 jesusmb1995 marked this pull request as ready for review September 22, 2025 13:21
@jesusmb1995
Author

Fixed the PR base branch to temp-latest instead of master.

@jpgaribotti jpgaribotti merged commit 646fdc5 into tetherto:temp-latest Sep 22, 2025
4 of 5 checks passed
