12 changes: 12 additions & 0 deletions evaluation.py
@@ -253,6 +253,18 @@ def main(args):
parser.add_argument(
"--redis_batch_size", type=int, default=256, help="Batch size for Redis vector operations (default: 256)"
)
parser.add_argument(
"--cross_encoder_model",
type=str,
default=None,
help="Name of the cross-encoder model to use for reranking (default: None)",
)
parser.add_argument(
"--rerank_k",
type=int,
default=10,
help="Number of candidates to rerank (default: 10)",
)
args = parser.parse_args()

main(args)
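The new `--cross_encoder_model` and `--rerank_k` flags imply a reranking step inside evaluation.py's matching logic, which this diff does not show. A minimal sketch of what that step could look like, assuming `sentence_transformers.CrossEncoder` is available; the function name, the `candidates` structure, and the fallback behaviour are illustrative assumptions, not code from the repo:

```python
# Hypothetical sketch only: rescore the top rerank_k bi-encoder candidates with a cross-encoder.
# All names here are illustrative; the actual wiring in evaluation.py is not part of this diff.
from sentence_transformers import CrossEncoder

def rerank_candidates(query, candidates, cross_encoder_model=None, rerank_k=10):
    """candidates: list of (cached_sentence, bi_encoder_score), sorted by bi-encoder score desc."""
    if not cross_encoder_model:
        return candidates  # mirrors the default: no cross-encoder means bi-encoder only
    ce = CrossEncoder(cross_encoder_model)  # in practice, load once and reuse across queries
    top = candidates[:rerank_k]
    scores = ce.predict([(query, sent) for sent, _ in top])
    reranked = sorted(zip([sent for sent, _ in top], scores), key=lambda p: p[1], reverse=True)
    # Candidates beyond rerank_k keep their original bi-encoder ordering and scores.
    return [(sent, float(score)) for sent, score in reranked] + candidates[rerank_k:]
```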
160 changes: 94 additions & 66 deletions run_benchmark.py
@@ -56,6 +56,14 @@ def main():
parser.add_argument("--redis_index_name", type=str, default="idx_cache_match")
parser.add_argument("--redis_doc_prefix", type=str, default="cache:")
parser.add_argument("--redis_batch_size", type=int, default=256)
parser.add_argument(
"--cross_encoder_models",
type=str,
nargs="*",
default=None,
help="List of cross-encoder models (optional). If not provided, only bi-encoder is used.",
)
parser.add_argument("--rerank_k", type=int, default=10, help="Number of candidates to rerank.")

args = parser.parse_args()

@@ -101,76 +109,96 @@ def main():
for model_name in args.models:
print(f"\n Model: {model_name}")

# Sanitize model name for directory structure
safe_model_name = model_name.replace("/", "_")
# Prepare list of cross-encoders to iterate over (None = no reranking)
ce_models = args.cross_encoder_models if args.cross_encoder_models else [None]

for ce_model_name in ce_models:
if ce_model_name:
print(f" Cross-Encoder: {ce_model_name}")
else:
print(f" Cross-Encoder: None (Bi-Encoder only)")

# Sanitize model name for directory structure
safe_model_name = model_name.replace("/", "_")

for run_i in range(1, args.n_runs + 1):
print(f" Run {run_i}/{args.n_runs}...")

# 1. Bootstrapping Logic
# Sample 80% of the universe
run_universe = full_df.sample(
frac=args.sample_ratio, random_state=run_i
) # Use run_i as seed for reproducibility per run

for run_i in range(1, args.n_runs + 1):
print(f" Run {run_i}/{args.n_runs}...")
# Split into Queries (n_samples) and Cache (remainder)
if len(run_universe) <= args.n_samples:
print(
f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
)
continue

# 1. Bootstrapping Logic
# Sample 80% of the universe
run_universe = full_df.sample(
frac=args.sample_ratio, random_state=run_i
) # Use run_i as seed for reproducibility per run
queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
cache = run_universe.drop(queries.index)

# Split into Queries (n_samples) and Cache (remainder)
if len(run_universe) <= args.n_samples:
print(
f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
# Shuffle cache
cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
queries = queries.reset_index(drop=True)

# 2. Construct Output Path
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

# Include cross-encoder in output path if used
model_dir_name = safe_model_name
if ce_model_name:
safe_cross_encoder_name = ce_model_name.replace("/", "_")
model_dir_name = f"{safe_model_name}_rerank_{safe_cross_encoder_name}"

run_output_dir = os.path.join(
args.output_dir, dataset_name, model_dir_name, f"run_{run_i}", timestamp
)
continue

queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
cache = run_universe.drop(queries.index)

# Shuffle cache
cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
queries = queries.reset_index(drop=True)

# 2. Construct Output Path
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
run_output_dir = os.path.join(args.output_dir, dataset_name, safe_model_name, f"run_{run_i}", timestamp)
os.makedirs(run_output_dir, exist_ok=True)

# 3. Prepare Args for Evaluation
eval_args = BenchmarkArgs(
query_log_path=dataset_path, # Not strictly used by logic below but good for reference
sentence_column=args.sentence_column,
output_dir=run_output_dir,
n_samples=args.n_samples,
model_name=model_name,
cache_path=None,
full=args.full,
llm_name=args.llm_name,
llm_model=llm_classifier,
sweep_steps=200, # Default
use_redis=args.use_redis,
redis_url=args.redis_url,
redis_index_name=args.redis_index_name,
redis_doc_prefix=args.redis_doc_prefix,
redis_batch_size=args.redis_batch_size,
# device defaults to code logic
)

# 4. Run Evaluation
try:
print(" Matching...")
if args.use_redis:
queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
else:
queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)

print(" Evaluating...")
if args.full:
run_full_evaluation(queries_matched, eval_args)
else:
run_chr_analysis(queries_matched, eval_args)

except Exception as e:
print(f" Error in run {run_i}: {e}")
import traceback

traceback.print_exc()
os.makedirs(run_output_dir, exist_ok=True)

# 3. Prepare Args for Evaluation
eval_args = BenchmarkArgs(
query_log_path=dataset_path, # Not strictly used by logic below but good for reference
sentence_column=args.sentence_column,
output_dir=run_output_dir,
n_samples=args.n_samples,
model_name=model_name,
cache_path=None,
full=args.full,
llm_name=args.llm_name,
llm_model=llm_classifier,
sweep_steps=200, # Default
use_redis=args.use_redis,
redis_url=args.redis_url,
redis_index_name=args.redis_index_name,
redis_doc_prefix=args.redis_doc_prefix,
redis_batch_size=args.redis_batch_size,
cross_encoder_model=ce_model_name,
rerank_k=args.rerank_k,
# device defaults to code logic
)

# 4. Run Evaluation
try:
print(" Matching...")
if args.use_redis:
queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
else:
queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)

print(" Evaluating...")
if args.full:
run_full_evaluation(queries_matched, eval_args)
else:
run_chr_analysis(queries_matched, eval_args)

except Exception as e:
print(f" Error in run {run_i}: {e}")
import traceback

traceback.print_exc()

print("\nBenchmark completed.")

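run_benchmark.py now forwards `cross_encoder_model` and `rerank_k` into `BenchmarkArgs`, whose definition is not part of this diff. A minimal sketch of the two fields that call site assumes; the surrounding fields are elided and the defaults are guesses matching the CLI defaults:

```python
# Hypothetical sketch: the two fields the eval_args construction above assumes BenchmarkArgs to have.
# The real BenchmarkArgs lives elsewhere in the repo and is not shown in this diff.
from dataclasses import dataclass
from typing import Optional

@dataclass
class BenchmarkArgs:
    # ... existing fields such as model_name, output_dir, use_redis, etc. ...
    cross_encoder_model: Optional[str] = None  # None keeps the bi-encoder-only code path
    rerank_k: int = 10                         # number of top candidates the reranker rescores
```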
15 changes: 15 additions & 0 deletions run_benchmark.sh
@@ -0,0 +1,15 @@
# Example usage:
uv run run_benchmark.py \
--dataset_dir "dataset" \
--output_dir "cross_encoder_results" \
--models "Alibaba-NLP/gte-modernbert-base" "redis/langcache-embed-v1" "redis/langcache-embed-v3-small" \
--dataset_names "vizio_unique_medium.csv" "axis_bank_unique_sentences.csv" \
--sentence_column "sentence" \
--n_runs 10 \
--n_samples 10000 \
--sample_ratio 0.8 \
--llm_name "tensoropera/Fox-1-1.6B" \
--full \
--use_redis \
# --cross_encoder_models "redis/langcache-reranker-v1-softmnrl-triplet" "Alibaba-NLP/gte-reranker-modernbert-base" \
# --rerank_k 5
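When the commented-out reranker flags are enabled, the per-run output directory picks up a `_rerank_` suffix (see the path construction in run_benchmark.py above). A small illustration of the resulting layout; the dataset name and timestamp below are placeholders, not real output:

```python
import os

# Mirrors the model_dir_name logic from run_benchmark.py; values are placeholders.
safe_model_name = "redis/langcache-embed-v1".replace("/", "_")
safe_cross_encoder_name = "Alibaba-NLP/gte-reranker-modernbert-base".replace("/", "_")
model_dir_name = f"{safe_model_name}_rerank_{safe_cross_encoder_name}"
print(os.path.join("cross_encoder_results", "<dataset_name>", model_dir_name, "run_1", "<timestamp>"))
# cross_encoder_results/<dataset_name>/redis_langcache-embed-v1_rerank_Alibaba-NLP_gte-reranker-modernbert-base/run_1/<timestamp>
```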
83 changes: 69 additions & 14 deletions scripts/plot_multiple_precision_vs_cache_hit_ratio.py
@@ -26,7 +26,9 @@ def main():
dataset_full_path = os.path.join(base_dir, dataset_name)
if not os.path.exists(dataset_full_path):
continue
fig, ax = plt.subplots(figsize=(10, 7))

# CHANGED: Create two subplots: one for curves, one for the AUC bar chart
fig, (ax_main, ax_bar) = plt.subplots(1, 2, figsize=(18, 8), gridspec_kw={'width_ratios': [2, 1]})
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Get base rate from first valid run to compute theoretical curves
@@ -48,6 +50,9 @@
if base_rate is not None:
break

# Theoretical AUCs storage
theory_aucs = {}

# Plot theoretical curves
if base_rate is not None:
# Theoretical Perfect (Uniform Negatives)
@@ -56,15 +61,21 @@
x_uniform = np.concatenate(([0], x_uniform))
y_uniform = np.concatenate(([1], y_uniform))
auc_uniform = base_rate * (1 - np.log(base_rate))
ax.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}")
ax_main.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}")
theory_aucs['Uniform'] = auc_uniform

# Theoretical Perfect (Zero Negatives)
x_zeros = [0, base_rate, 1.0]
y_zeros = [1.0, 1.0, base_rate]
auc_zeros = base_rate + 0.5 * (1 - base_rate**2)
ax.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}")
ax_main.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}")
theory_aucs['ZeroNegs'] = auc_zeros

sorted_models = sorted(model_data.keys())

# Store data for the bar plot
auc_records = []

for i, model_name in enumerate(sorted_models):
run_paths = model_data[model_name]
precisions_interp = []
@@ -81,7 +92,7 @@

try:
df = pd.read_csv(csv_path)
# Remove the last row because it's it's always precision = 1.0
# Remove the last row because it's always precision = 1.0
df = df.iloc[:-1]

x_chr = df["cache_hit_ratio"].values
@@ -95,7 +106,13 @@

p_interp = np.interp(common_chr, x_chr, y_prec)
precisions_interp.append(p_interp)
aucs_pchr.append(np.trapezoid(p_interp, common_chr))
# Use numpy.trapezoid (NumPy 2.0) or numpy.trapz (older)
try:
auc_val = np.trapezoid(p_interp, common_chr)
except AttributeError:
auc_val = np.trapz(p_interp, common_chr)

aucs_pchr.append(auc_val)

valid_runs += 1
except Exception as e:
@@ -111,15 +128,53 @@

color = colors[i % len(colors)]

label_chr = f"{model_name}, AUC: {mean_auc_pchr:.3f} ± {std_auc_pchr:.3f}"
ax.plot(common_chr, mean_p_chr, label=label_chr, color=color)
label_chr = f"{model_name}, AUC: {mean_auc_pchr:.3f} ± {std_auc_pchr:.3f}" # Simplified label for main plot, AUC is in bar chart
ax_main.plot(common_chr, mean_p_chr, label=label_chr, color=color)
if valid_runs > 1:
ax.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2)
ax.set_xlabel("Cache Hit Ratio")
ax.set_ylabel("Precision")
ax.set_title("Precision vs Cache Hit Ratio")
ax.grid(True)
ax.legend()
ax_main.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2)

# Save data for bar chart
auc_records.append({
'name': model_name,
'mean': mean_auc_pchr,
'std': std_auc_pchr,
'color': color
})

# --- Configure Main Curve Plot ---
ax_main.set_xlabel("Cache Hit Ratio")
ax_main.set_ylabel("Precision")
ax_main.set_title("Precision vs Cache Hit Ratio")
ax_main.grid(True)
ax_main.legend()

# --- Configure Bar Chart ---
if auc_records:
# Sort by mean AUC (ascending so best is at top)
auc_records.sort(key=lambda x: x['mean'], reverse=False)

names = [r['name'] for r in auc_records]
means = [r['mean'] for r in auc_records]
stds = [r['std'] for r in auc_records]
bar_colors = [r['color'] for r in auc_records]
y_pos = np.arange(len(names))

ax_bar.barh(y_pos, means, xerr=stds, color=bar_colors, align='center', capsize=5, alpha=0.8)
ax_bar.set_yticks(y_pos)
ax_bar.set_yticklabels(names)
ax_bar.set_xlabel("AUC")
ax_bar.set_title("AUC Comparison")
ax_bar.grid(axis='x', linestyle='--', alpha=0.7)

# Add theoretical lines to bar chart
if 'Uniform' in theory_aucs:
ax_bar.axvline(theory_aucs['Uniform'], color='black', linestyle='--', alpha=0.7)
if 'ZeroNegs' in theory_aucs:
ax_bar.axvline(theory_aucs['ZeroNegs'], color='black', linestyle=':', alpha=0.7)

# Set x-limits to focus on relevant area if needed, or 0-1
# ax_bar.set_xlim(0, 1.05)

fig.suptitle(f"Performance on {dataset_name.split('_')[0]}")
plt.tight_layout()
output_path = os.path.join(dataset_full_path, "precision_vs_cache_hit_ratio.png")
@@ -129,4 +184,4 @@


if __name__ == "__main__":
main()
main()
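For reference, the two closed-form AUCs hard-coded for the theoretical curves (`auc_uniform = base_rate * (1 - np.log(base_rate))` and `auc_zeros = base_rate + 0.5 * (1 - base_rate**2)`) follow from integrating the ideal precision-vs-cache-hit-ratio curves. Writing b for the base rate, and assuming the ideal curves are precision 1 up to CHR = b, then b/c for the uniform-negatives case and a linear drop from 1 to b for the zero-negatives case (consistent with the plotted points and AUC expressions in the code):

```latex
\begin{aligned}
\mathrm{AUC}_{\text{uniform}} &= \int_0^{b} 1 \, dc + \int_{b}^{1} \frac{b}{c} \, dc
  = b - b \ln b = b\,(1 - \ln b), \\
\mathrm{AUC}_{\text{zero negs}} &= \int_0^{b} 1 \, dc + \underbrace{\frac{(1-b)(1+b)}{2}}_{\text{trapezoid from } (b,\,1) \text{ to } (1,\,b)}
  = b + \tfrac{1}{2}\,\bigl(1 - b^{2}\bigr).
\end{aligned}
```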