feat: Add top-k and groupmax activation functions with bias initialization #14
Open
psycoplankton wants to merge 2 commits into
Open
feat: Add top-k and groupmax activation functions with bias initialization #14psycoplankton wants to merge 2 commits into
psycoplankton wants to merge 2 commits into
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
[FEATURE] Add top-k and groupmax activation functions with bias initialization #13
This PR adds support for top-$k$ and groupmax feature activation methodologies (standard PyTorch autograd versions) as alternative sparsity bottlenecks to the default JumpReLU in
CLT-Forge.Key Changes
1. Configuration (
src/clt_forge/config/)activation_fn(choice of"jumprelu","topk", or"groupmax") andkfields to bothCLTConfigandCLTTrainingRunnerConfig.2. Model Architecture (
src/clt_forge/clt.py)topkandgroupmaxactivations in theencodemethod.requires_grad = Falsefor thelog_thresholdparameter when using top-$k$ or groupmax activations (since threshold learning is not used).lossmethod to bypass0.0) under top-$k$ / groupmax configs._initialize_b_encby clipping the percentile index withmin(..., B - 1)to preventIndexErroron small batches or small feature dimensions.3. Training & Optimization (
src/clt_forge/training/clt_trainer.py)_initialize_b_encfor all training procedures (including top-$k$/groupmax).Adamoptimizer to exclude any parameters that do not require gradients (such aslog_thresholdduring top-$k$ runs).4. Verification & Testing (
tests/test_topk.py)_initialize_b_encunder bothtopkandgroupmax.Verification Results
All new test cases passed successfully: