Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

All notable changes to this project will be documented in this file.

## [Unreleased]

### Added

* Added circuit-tracer CLT loading support for HuggingFace, local safetensors, and circuit-tracer cache sources in attribution workflows.
* Added conversion utilities for saving circuit-tracer attribution graphs and feature metadata in the existing CLT-Forge visual interface format.
* Added a notebook showing how to load open-source circuit-tracer CLTs and visualize them with the CLT-Forge interface.

## [0.1.0] - 2026-02-16

### Added
Expand Down
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,28 @@ graph = runner.run(
)
```

You can also run attribution with an open-source circuit-tracer CLT instead of
a CLT-Forge checkpoint:

``` python
from clt_forge.attribution.attribution import AttributionRunner

runner = AttributionRunner.from_circuit_tracer_hub(
hf_ref = "mntss/clt-gemma-2-2b-426k",
model_name = "google/gemma-2-2b",
)

graph = runner.run(
input_string = "The capital of France is",
folder_name = "where/to/save",
run_interventions = False,
)
```

See `notebooks/load_open_source_circuit_tracer_clt.ipynb` for an end-to-end
example that saves a CLT-Forge-compatible graph and opens the existing visual
interface.

------------------------------------------------------------------------

### 5. Start the Visual-Interface
Expand Down
192 changes: 192 additions & 0 deletions notebooks/load_open_source_circuit_tracer_clt.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load an open-source circuit-tracer CLT in CLT-Forge\n",
"\n",
"This notebook loads a trained circuit-tracer CLT from HuggingFace, runs attribution through CLT-Forge's attribution runner, saves a CLT-Forge-compatible graph artifact, optionally converts circuit-tracer feature metadata into CLT-Forge feature JSON files, and opens the existing CLT-Forge visual interface.\n",
"\n",
"It does not change the visual interface. The bridge happens in the Python library layer.\n",
"\n",
"Run it from an environment with CLT-Forge installed, or install the local checkout from the repository root with `python -m pip install -e .`. The first code cell checks for the required notebook dependencies before importing CLT-Forge."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import importlib.util\n",
"import sys\n",
"\n",
"def find_repo_root(start: Path) -> Path | None:\n",
" for candidate in (start, *start.parents):\n",
" if (candidate / \"src\" / \"clt_forge\").exists():\n",
" return candidate\n",
" return None\n",
"\n",
"repo_root = find_repo_root(Path.cwd().resolve())\n",
"if repo_root is not None:\n",
" sys.path.insert(0, str(repo_root / \"src\"))\n",
"\n",
"required_modules = {\n",
" \"clt_forge\": \"clt-forge\",\n",
" \"torch\": \"torch\",\n",
" \"transformers\": \"transformers\",\n",
" \"transformer_lens\": \"transformer-lens\",\n",
" \"huggingface_hub\": \"huggingface-hub\",\n",
" \"safetensors\": \"safetensors\",\n",
" \"dash\": \"dash\",\n",
" \"dash_cytoscape\": \"dash-cytoscape\",\n",
"}\n",
"missing_packages = [\n",
" package_name\n",
" for module_name, package_name in required_modules.items()\n",
" if importlib.util.find_spec(module_name) is None\n",
"]\n",
"if missing_packages:\n",
" raise RuntimeError(\n",
" \"Missing notebook dependencies. Install them from the CLT-Forge repository root with:\\n\\n\"\n",
" \"python -m pip install -e .\\n\"\n",
" f\"python -m pip install {' '.join(missing_packages)}\\n\\n\"\n",
" \"Then restart this notebook kernel.\"\n",
" )\n",
"\n",
"import torch\n",
"\n",
"from clt_forge.attribution.attribution import AttributionRunner\n",
"from clt_forge.attribution.circuit_tracer_features import (\n",
" download_clt_forge_feature_dicts_for_graph,\n",
")\n",
"from clt_forge.frontend.app import main\n",
"from clt_forge.frontend.config.settings import AppConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"dtype = torch.bfloat16 if device == \"cuda\" else torch.float32\n",
"\n",
"# circuit-tracer open-source CLT refs listed in the vendored circuit-tracer README:\n",
"# - mntss/clt-gemma-2-2b-426k\n",
"# - mntss/clt-gemma-2-2b-2.5M\n",
"# - mntss/clt-llama-3.2-1b-524k\n",
"model_name = \"google/gemma-2-2b\"\n",
"circuit_tracer_clt = \"mntss/clt-gemma-2-2b-426k\"\n",
"\n",
"output_base = repo_root if repo_root is not None else Path.cwd()\n",
"output_dir = output_base / \"outputs\" / \"circuit_tracer_gemma_demo\"\n",
"graph_path = output_dir / \"attribution_graph.pt\"\n",
"feature_dict_dir = output_dir / \"feature_dicts\"\n",
"output_dir.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"print(\n",
" f\"Loading circuit-tracer CLT {circuit_tracer_clt!r} \"\n",
" f\"with model {model_name!r} on {device}...\",\n",
" flush=True,\n",
")\n",
"start_time = time.time()\n",
"\n",
"runner = AttributionRunner.from_circuit_tracer_hub(\n",
" hf_ref=circuit_tracer_clt,\n",
" model_name=model_name,\n",
" device=device,\n",
" dtype=dtype,\n",
" backend=\"transformerlens\",\n",
" lazy_encoder=False,\n",
" lazy_decoder=True,\n",
" debug=False,\n",
")\n",
"\n",
"print(f\"Runner ready in {time.time() - start_time:.1f}s\", flush=True)\n",
"\n",
"print(\"Running attribution...\", flush=True)\n",
"start_time = time.time()\n",
"\n",
"result = runner.run(\n",
" input_string=\"The capital of France is\",\n",
" folder_name=str(output_dir),\n",
" graph_name=graph_path.name,\n",
" max_n_logits=5,\n",
" max_feature_nodes=4096,\n",
" batch_size=128,\n",
" offload=\"cpu\",\n",
" run_interventions=False,\n",
")\n",
"\n",
"print(f\"Attribution ready in {time.time() - start_time:.1f}s\", flush=True)\n",
"\n",
"graph_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: pull circuit-tracer feature examples for the graph's active features\n",
"# and convert them to the CLT-Forge frontend feature JSON layout.\n",
"# Start with a small max_features while exploring; downloading every active\n",
"# feature can be slow for large graphs.\n",
"written_feature_files = download_clt_forge_feature_dicts_for_graph(\n",
" graph_result=result,\n",
" scan=result.get(\"circuit_tracer_scan\", circuit_tracer_clt),\n",
" output_dir=feature_dict_dir,\n",
" max_features=50,\n",
" strict=False,\n",
")\n",
"\n",
"len(written_feature_files)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg = AppConfig(\n",
" attr_graph_path=str(graph_path),\n",
" dict_base_folder=str(feature_dict_dir),\n",
" clt_checkpoint=\"\",\n",
" model_name=model_name,\n",
" model_class_name=\"HookedTransformer\",\n",
" port=8106,\n",
")\n",
"\n",
"main(cfg)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading