
feat: Add top-k and groupmax activation functions with bias initialization #14

Open
psycoplankton wants to merge 2 commits into LLM-Interp:master from psycoplankton:master



psycoplankton commented Jun 22, 2026


[FEATURE] Add top-k and groupmax activation functions with bias initialization #13

This PR adds top-$k$ and groupmax activation functions, implemented with standard PyTorch autograd, as alternative sparsity bottlenecks to the default JumpReLU in CLT-Forge.

Key Changes

1. Configuration (src/clt_forge/config/)

  • Added activation_fn (choice of "jumprelu", "topk", or "groupmax") and k fields to both CLTConfig and CLTTrainingRunnerConfig.
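As an illustration of how the two new fields might be validated together, here is a minimal sketch. `ActivationSettings` and `d_latent` are hypothetical names used only for this example; the real fields live on CLTConfig and CLTTrainingRunnerConfig, and the divisibility check for groupmax is an assumption about how groups are formed, not something stated in this PR.

```python
from dataclasses import dataclass
from typing import Optional

VALID_ACTIVATIONS = ("jumprelu", "topk", "groupmax")


@dataclass
class ActivationSettings:
    """Hypothetical stand-in for the new config fields (not the real CLTConfig)."""

    activation_fn: str = "jumprelu"
    k: Optional[int] = None
    d_latent: int = 1024  # hypothetical: number of features per layer

    def __post_init__(self) -> None:
        if self.activation_fn not in VALID_ACTIVATIONS:
            raise ValueError(
                f"activation_fn must be one of {VALID_ACTIVATIONS}, "
                f"got {self.activation_fn!r}"
            )
        if self.activation_fn in ("topk", "groupmax"):
            # k is only meaningful for the two fixed-sparsity activations.
            if self.k is None or self.k <= 0:
                raise ValueError("k must be a positive integer for topk/groupmax")
            if self.k > self.d_latent:
                raise ValueError("k cannot exceed d_latent")
            # Assumption: groupmax splits features into k equal-size groups.
            if self.activation_fn == "groupmax" and self.d_latent % self.k != 0:
                raise ValueError("d_latent must be divisible by k for groupmax")


cfg = ActivationSettings(activation_fn="topk", k=32)
```

Validating at construction time surfaces a bad `k` before any training starts, rather than inside `encode`.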

2. Model Architecture (src/clt_forge/clt.py)

  • Implemented standard autograd-based topk and groupmax activations in the encode method.
  • Set requires_grad = False for the log_threshold parameter when using top-$k$ or groupmax activations (since threshold learning is not used).
  • Configured the loss method to bypass $L_0$ sparsity penalties and dead feature penalties (setting them to 0.0) under top-$k$ / groupmax configs.
  • Made _initialize_b_enc more robust by clamping the percentile index with min(..., B - 1), which prevents an IndexError on small batches or small feature dimensions.
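The two activations can be sketched in plain PyTorch as follows. This is not the code from clt.py: the function names are illustrative, no ReLU is applied to the selected values, and the groupmax layout (k contiguous, equal-size groups along the feature dimension) is an assumption. The PR itself only states that exactly one feature per group stays active.

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per token and zero the rest."""
    values, indices = torch.topk(pre_acts, k, dim=-1)
    # scatter is differentiable with respect to `values`, so gradients
    # flow only through the k selected features of each token.
    return torch.zeros_like(pre_acts).scatter(-1, indices, values)


def groupmax_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Split features into k contiguous groups and keep each group's maximum."""
    d = pre_acts.shape[-1]
    if d % k != 0:
        raise ValueError("feature dimension must be divisible by k")
    grouped = pre_acts.unflatten(-1, (k, d // k))        # [..., k, d // k]
    values, local_idx = grouped.max(dim=-1, keepdim=True)
    sparse = torch.zeros_like(grouped).scatter(-1, local_idx, values)
    return sparse.flatten(-2)                            # back to [..., d]


x = torch.randn(4, 12)
acts_topk = topk_activation(x, k=3)       # 3 selected features per row
acts_group = groupmax_activation(x, k=4)  # 1 selected feature in each of 4 groups
```

Both functions return a dense tensor with the same shape as the input, so they can replace a JumpReLU call without changing the decoder path. A custom sparse kernel would avoid materializing the zeros.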

3. Training & Optimization (src/clt_forge/training/clt_trainer.py)

  • Enabled _initialize_b_enc for all training procedures (including top-$k$/groupmax).
  • Filtered parameter groups passed to the Adam optimizer to exclude any parameters that do not require gradients (such as log_threshold during top-$k$ runs).
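A minimal sketch of the optimizer filtering described above, using a hypothetical `TinyModel` in place of the real CLT module. Only the parameter name `log_threshold` comes from this PR; the shapes and learning rate are placeholders.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in with one frozen parameter, mimicking log_threshold."""

    def __init__(self) -> None:
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(8, 16) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(16))
        self.log_threshold = nn.Parameter(torch.zeros(16))


model = TinyModel()
# Under topk / groupmax the threshold is not learned, so it is frozen.
model.log_threshold.requires_grad = False

# Pass only trainable parameters to Adam so the frozen tensor never
# appears in a parameter group or in the optimizer state.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
```

Filtering by `requires_grad` keeps the logic independent of the activation choice: under JumpReLU the threshold stays trainable and is passed through unchanged.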

4. Verification & Testing (tests/test_topk.py)

  • Created a new unit test suite containing 8 test cases validating:
    • Sparsity guarantees (exactly $k$ active features per token; exactly $1$ per group for groupmax).
    • Correct gradient propagation through encoder/decoder weights.
    • $L_0$ and dead feature loss terms evaluating to zero.
    • Correct behavior of _initialize_b_enc under both topk and groupmax.
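The small-batch edge case that the _initialize_b_enc fix targets can be illustrated with a standalone helper. This is a sketch under assumptions: the PR shows only the min(..., B - 1) clamp, so the way the raw index is derived here (`int(quantile * batch_size)`) and the name `percentile_index` are hypothetical.

```python
def percentile_index(batch_size: int, quantile: float) -> int:
    """Return a valid index into a sorted array of length `batch_size`."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if not 0.0 <= quantile <= 1.0:
        raise ValueError("quantile must lie in [0, 1]")
    raw = int(quantile * batch_size)  # assumed form of the unclamped index
    # Without the clamp, quantile == 1.0 gives raw == batch_size,
    # which is one past the last valid position.
    return min(raw, batch_size - 1)


sorted_vals = sorted([0.3, -1.2, 0.9, 0.1])
idx = percentile_index(len(sorted_vals), 1.0)
largest = sorted_vals[idx]  # safe: idx is clamped to the last element
```

A test for this behavior would typically use a batch of size one or a quantile of 1.0, since those are the inputs where the unclamped index runs off the end.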

Verification Results

All new test cases passed successfully:

$ PYTHONPATH=src pytest tests/test_topk.py
============================= test session starts ==============================
collected 8 items

tests/test_topk.py ........                                              [100%]
======================== 8 passed, 4 warnings in 2.97s =========================

Future Additions

A custom autograd function for efficient matrix multiplication, as done in EleutherAI's Sparsify, still needs to be added.
