#!/bin/bash
#SBATCH --job-name=embed-optimus-kg
#SBATCH --account=kempner_mzitnik_lab
#SBATCH --partition=kempner_h100
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=0-03:00:00
#SBATCH --output=/n/holylfs06/LABS/mzitnik_lab/Users/rshamji/rshamji/simple-evals/embedding_logs/embed_%j.out
#SBATCH --error=/n/holylfs06/LABS/mzitnik_lab/Users/rshamji/rshamji/simple-evals/embedding_logs/embed_%j.err
#SBATCH --gres=gpu:1

################################################################################
# EMBED OPTIMUS KG - One-time embedding of OptimusKG nodes using KaLM
#
# This script embeds all OptimusKG nodes with KaLM embeddings for use in
# embedding-based or hybrid (RRF fusion) search modes.
#
# Purpose:
#   - Generate dense vector embeddings for all 192,682 OptimusKG nodes
#   - Save embeddings to embeddings_kalm.npy (for embedding/hybrid search)
#   - One-time operation (~5-10 minutes on GPU)
#
# Usage:
#   bash embed_optimus_kg.sh           # Embed OptimusKG with default settings
#   bash embed_optimus_kg.sh optimus   # Explicitly specify graph (default)
#
# Prerequisites:
#   - ARK repo with OptimusKG graph data
#   - simple-evals/.venv with sentence-transformers, torch, numpy, pandas installed
#   - GPU node (recommended; CPU will be very slow)
#
# Output:
#   - embeddings_kalm.npy saved to: ark/benchmarks/stark/data/graphs/optimus/
#   - Size: ~2.5GB (192,682 nodes × 3,840 dims × 4 bytes)
#
# After completion:
#   - run_5q_optimus_test.sh can use --search-mode embedding or hybrid
#   - Full SLURM jobs can use embedding/hybrid search across 5000 questions
################################################################################

| 41 | +set -e |
| 42 | + |
| 43 | +# Colors for output |
| 44 | +RED='\033[0;31m' |
| 45 | +GREEN='\033[0;32m' |
| 46 | +YELLOW='\033[1;33m' |
| 47 | +BLUE='\033[0;34m' |
| 48 | +NC='\033[0m' # No Color |
| 49 | + |
| 50 | +# Configuration |
| 51 | +GRAPH_NAME=${1:-optimus} |
| 52 | +PROJECT_ROOT="/n/holylfs06/LABS/mzitnik_lab/Users/rshamji/rshamji" |
| 53 | +ARK_DIR="${PROJECT_ROOT}/ark" |
| 54 | +GRAPH_PATH="${ARK_DIR}/benchmarks/stark/data/graphs/${GRAPH_NAME}" |
| 55 | +NODES_PARQUET="${GRAPH_PATH}/nodes.parquet" |
| 56 | +EMBEDDINGS_OUTPUT="${GRAPH_PATH}/embeddings_kalm.npy" |
| 57 | + |
| 58 | +################################################################################ |
| 59 | +# STEP 0: VALIDATION |
| 60 | +################################################################################ |
| 61 | + |
| 62 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 63 | +echo -e "${BLUE}EMBED ${GRAPH_NAME^^} KG - KaLM Node Embeddings${NC}" |
| 64 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 65 | + |
| 66 | +# Check ARK repo exists |
| 67 | +if [ ! -d "$ARK_DIR" ]; then |
| 68 | + echo -e "${RED}✗ ERROR: ARK repo not found at ${ARK_DIR}${NC}" |
| 69 | + exit 1 |
| 70 | +fi |
| 71 | +echo -e "${GREEN}✓ ARK repo found at ${ARK_DIR}${NC}" |
| 72 | + |
| 73 | +# Check graph exists |
| 74 | +if [ ! -d "$GRAPH_PATH" ]; then |
| 75 | + echo -e "${RED}✗ ERROR: Graph '${GRAPH_NAME}' not found at ${GRAPH_PATH}${NC}" |
| 76 | + echo " Available graphs:" |
| 77 | + ls -1 "${ARK_DIR}/benchmarks/stark/data/graphs/" |
| 78 | + exit 1 |
| 79 | +fi |
| 80 | +echo -e "${GREEN}✓ Graph '${GRAPH_NAME}' found at ${GRAPH_PATH}${NC}" |
| 81 | + |
| 82 | +# Check nodes.parquet exists |
| 83 | +if [ ! -f "$NODES_PARQUET" ]; then |
| 84 | + echo -e "${RED}✗ ERROR: nodes.parquet not found at ${NODES_PARQUET}${NC}" |
| 85 | + exit 1 |
| 86 | +fi |
| 87 | + |
| 88 | +# Get node count (approximate from file size since parquet is binary) |
| 89 | +NODE_COUNT=$(python3 -c "import pandas; df = pandas.read_parquet('$NODES_PARQUET'); print(len(df))" 2>/dev/null || echo "?") |
| 90 | +echo -e "${GREEN}✓ Nodes file found: ${NODE_COUNT} nodes${NC}" |
| 91 | + |
| 92 | +# Check if embeddings already exist |
| 93 | +if [ -f "$EMBEDDINGS_OUTPUT" ]; then |
| 94 | + EMBED_SIZE=$(ls -lh "$EMBEDDINGS_OUTPUT" | awk '{print $5}') |
| 95 | + echo -e "${YELLOW}⚠ Embeddings already exist at: ${EMBEDDINGS_OUTPUT} (${EMBED_SIZE})${NC}" |
| 96 | + read -p " Overwrite existing embeddings? (y/n): " -n 1 -r |
| 97 | + echo |
| 98 | + if [[ ! $REPLY =~ ^[Yy]$ ]]; then |
| 99 | + echo -e "${GREEN}✓ Using existing embeddings (skipping embedding step)${NC}" |
| 100 | + exit 0 |
| 101 | + fi |
| 102 | +fi |
| 103 | + |
| 104 | +################################################################################ |
| 105 | +# STEP 1: CHECK VENV AND DEPENDENCIES |
| 106 | +################################################################################ |
| 107 | + |
| 108 | +echo "" |
| 109 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 110 | +echo -e "${BLUE}STEP 1: CHECK DEPENDENCIES${NC}" |
| 111 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 112 | + |
| 113 | +VENV_PYTHON="${PROJECT_ROOT}/simple-evals/.venv/bin/python" |
| 114 | + |
| 115 | +if [ ! -f "$VENV_PYTHON" ]; then |
| 116 | + echo -e "${RED}✗ ERROR: Venv Python not found at ${VENV_PYTHON}${NC}" |
| 117 | + echo " Create with: cd simple-evals && python -m venv .venv && .venv/bin/pip install sentence-transformers torch" |
| 118 | + exit 1 |
| 119 | +fi |
| 120 | + |
| 121 | +# Check embedding dependencies |
| 122 | +echo "Checking embedding dependencies..." |
| 123 | +$VENV_PYTHON -c " |
| 124 | +import sys |
| 125 | +missing = [] |
| 126 | +
|
| 127 | +try: |
| 128 | + from sentence_transformers import SentenceTransformer |
| 129 | + print(' ✓ sentence-transformers') |
| 130 | +except ImportError: |
| 131 | + missing.append('sentence-transformers') |
| 132 | + print(' ✗ sentence-transformers') |
| 133 | +
|
| 134 | +try: |
| 135 | + import torch |
| 136 | + print(' ✓ torch') |
| 137 | +except ImportError: |
| 138 | + missing.append('torch') |
| 139 | + print(' ✗ torch') |
| 140 | +
|
| 141 | +try: |
| 142 | + import numpy |
| 143 | + print(' ✓ numpy') |
| 144 | +except ImportError: |
| 145 | + missing.append('numpy') |
| 146 | + print(' ✗ numpy') |
| 147 | +
|
| 148 | +try: |
| 149 | + import pandas |
| 150 | + print(' ✓ pandas') |
| 151 | +except ImportError: |
| 152 | + missing.append('pandas') |
| 153 | + print(' ✗ pandas') |
| 154 | +
|
| 155 | +if missing: |
| 156 | + print(f'\nMissing: {missing}') |
| 157 | + print('Install with: pip install sentence-transformers torch pandas numpy') |
| 158 | + sys.exit(1) |
| 159 | +" |
| 160 | + |
| 161 | +if [ $? -ne 0 ]; then |
| 162 | + echo -e "${RED}✗ ERROR: Missing embedding dependencies${NC}" |
| 163 | + exit 1 |
| 164 | +fi |
| 165 | + |
| 166 | +echo -e "${GREEN}✓ All dependencies present${NC}" |
| 167 | + |
| 168 | +################################################################################ |
| 169 | +# STEP 2: EMBED ALL NODES |
| 170 | +################################################################################ |
| 171 | + |
| 172 | +echo "" |
| 173 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 174 | +echo -e "${BLUE}STEP 2: EMBEDDING ALL NODES WITH KaLM${NC}" |
| 175 | +echo -e "${BLUE}Model: tencent/KaLM-Embedding-Gemma3-12B-2511${NC}" |
| 176 | +echo -e "${BLUE}Batch size: 256 | GPU enabled: auto-detect${NC}" |
| 177 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 178 | + |
| 179 | +echo "" |
| 180 | +echo "Embedding process started..." |
| 181 | +echo " This will take 10-20 minutes depending on GPU availability" |
| 182 | +echo " Embeddings will be saved to: ${EMBEDDINGS_OUTPUT}" |
| 183 | +echo "" |
| 184 | + |
| 185 | +# Run embedding via embed_kg.py |
| 186 | +cd "$PROJECT_ROOT" |
| 187 | + |
| 188 | +$VENV_PYTHON -m simple_evals.embed_kg \ |
| 189 | + --graph-path "$GRAPH_PATH" \ |
| 190 | + --output-path "$EMBEDDINGS_OUTPUT" \ |
| 191 | + --batch-size 256 |
| 192 | + |
| 193 | +if [ $? -ne 0 ]; then |
| 194 | + echo -e "${RED}✗ Embedding FAILED${NC}" |
| 195 | + exit 1 |
| 196 | +fi |
| 197 | + |
| 198 | +################################################################################ |
| 199 | +# STEP 3: VERIFY EMBEDDINGS |
| 200 | +################################################################################ |
| 201 | + |
| 202 | +echo "" |
| 203 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 204 | +echo -e "${BLUE}STEP 3: VERIFY EMBEDDINGS${NC}" |
| 205 | +echo -e "${BLUE}═══════════════════════════════════════════════════════════════${NC}" |
| 206 | + |
| 207 | +if [ ! -f "$EMBEDDINGS_OUTPUT" ]; then |
| 208 | + echo -e "${RED}✗ ERROR: Embeddings file not created at ${EMBEDDINGS_OUTPUT}${NC}" |
| 209 | + exit 1 |
| 210 | +fi |
| 211 | + |
| 212 | +# Check file size |
| 213 | +EMBED_SIZE=$(ls -lh "$EMBEDDINGS_OUTPUT" | awk '{print $5}') |
| 214 | +EMBED_SIZE_BYTES=$(stat -c%s "$EMBEDDINGS_OUTPUT") |
| 215 | +EXPECTED_SIZE=$((192682 * 3840 * 4)) # nodes × dims × float32 |
| 216 | + |
| 217 | +echo -e "${GREEN}✓ Embeddings file exists${NC}" |
| 218 | +echo " Path: ${EMBEDDINGS_OUTPUT}" |
| 219 | +echo " Size: ${EMBED_SIZE} (${EMBED_SIZE_BYTES} bytes)" |
| 220 | + |
| 221 | +# Quick validation |
| 222 | +$VENV_PYTHON "$EMBEDDINGS_OUTPUT" << 'VALIDATE_EMBEDDINGS' |
| 223 | +import numpy as np |
| 224 | +import sys |
| 225 | +from pathlib import Path |
| 226 | +
|
| 227 | +embeddings_path = Path(sys.argv[1]) |
| 228 | +
|
| 229 | +try: |
| 230 | + embeddings = np.load(embeddings_path, allow_pickle=False) |
| 231 | + print(f"\n✓ Embeddings validated") |
| 232 | + print(f" Shape: {embeddings.shape}") |
| 233 | + print(f" Dtype: {embeddings.dtype}") |
| 234 | + print(f" Min value: {embeddings.min():.6f}") |
| 235 | + print(f" Max value: {embeddings.max():.6f}") |
| 236 | + print(f" Mean value: {embeddings.mean():.6f}") |
| 237 | +
|
| 238 | + # Check for NaN/inf |
| 239 | + nan_count = np.isnan(embeddings).sum() |
| 240 | + inf_count = np.isinf(embeddings).sum() |
| 241 | + if nan_count == 0 and inf_count == 0: |
| 242 | + print(f"✓ No NaN or Inf values detected") |
| 243 | + else: |
| 244 | + print(f"⚠ Warning: {nan_count} NaN, {inf_count} Inf values detected") |
| 245 | +
|
| 246 | +except Exception as e: |
| 247 | + print(f"✗ ERROR: Failed to load embeddings: {e}") |
| 248 | + sys.exit(1) |
| 249 | +
|
| 250 | +VALIDATE_EMBEDDINGS |
| 251 | + |
| 252 | +if [ $? -ne 0 ]; then |
| 253 | + echo -e "${RED}✗ Embedding validation failed${NC}" |
| 254 | + exit 1 |
| 255 | +fi |
| 256 | + |
| 257 | +################################################################################ |
| 258 | +# COMPLETION |
| 259 | +################################################################################ |
| 260 | + |
| 261 | +echo "" |
| 262 | +echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" |
| 263 | +echo -e "${GREEN}✓ EMBEDDING COMPLETE - ${GRAPH_NAME^^} KG embeddings ready${NC}" |
| 264 | +echo -e "${GREEN}═══════════════════════════════════════════════════════════════${NC}" |
| 265 | + |
| 266 | +echo "" |
| 267 | +echo "Next steps:" |
| 268 | +echo " 1. Run tests with embedding/hybrid search:" |
| 269 | +echo " bash simple-evals/run_5q_optimus_test.sh 5 ${GRAPH_NAME} embedding" |
| 270 | +echo " bash simple-evals/run_5q_optimus_test.sh 5 ${GRAPH_NAME} hybrid" |
| 271 | +echo "" |
| 272 | +echo " 2. Run full SLURM job with embedding/hybrid search:" |
| 273 | +echo " edit simple-evals/run_ark_healthbench_kg_full.slurm (set SEARCH_MODE=embedding or hybrid)" |
| 274 | +echo " sbatch simple-evals/run_ark_healthbench_kg_full.slurm" |
| 275 | +echo "" |
| 276 | +echo " 3. Embeddings location: ${EMBEDDINGS_OUTPUT}" |
| 277 | +echo "" |