Skip to content

Commit 728b0c4

Browse files
committed
Updated DenseLayer, added docs
1 parent dde1de3 commit 728b0c4

File tree

6 files changed

+113
-130
lines changed

6 files changed

+113
-130
lines changed

docs/about.rst

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@ Research
1515

1616
If you have used jaxKAN in your research, we'd love to hear from you! Below, you can find a list of academic publications that have used jaxKAN.
1717

18-
- Rigas, S., Anagnostopoulos, F., Papachristou, M., & Alexandridis, G. (2026). Training deep physics-informed Kolmogorov–Arnold networks. Computer Methods in Applied Mechanics and Engineering, 452, 118761. https://doi.org/10.1016/j.cma.2026.118761
18+
- Rigas, S., Verma, D., Alexandridis, G., & Wang, Y. (2026). Initialization schemes for Kolmogorov–Arnold networks: An empirical study. The Fourteenth International Conference on Learning Representations (ICLR 2026). https://openreview.net/forum?id=dwNXKkiP51 | `GitHub Reference <https://github.com/srigas/KAN_Initialization_Schemes>`_
1919

20-
- Cerardi, N., Tolley, E., & Mishra, A. (2026). Solving the cosmological Vlasov–Poisson equations with physics-informed Kolmogorov–Arnold networks. Monthly Notices of the Royal Astronomical Society, 545, staf2241. https://doi.org/10.1093/mnras/staf2241 | `GitHub Reference <https://github.com/nicolas-cerardi/cdm-pikan>`_
20+
- Daniels, M., & Rigollet, P. (2026). Splat regression models. The Fourteenth International Conference on Learning Representations (ICLR 2026). https://openreview.net/forum?id=rubeJmT1XM
21+
22+
- Rigas, S., Papaioannou, T., Trakadas, P., & Alexandridis, G. (2026). A Dynamic Framework for Grid Adaptation in Kolmogorov-Arnold Networks. arXiv. https://doi.org/10.48550/arXiv.2601.18672 | `GitHub Reference <https://github.com/srigas/kan_grid>`_
2123

22-
- Daniels, M., & Rigollet, P. (2025). Splat regression models (No. arXiv:2511.14042). arXiv. https://doi.org/10.48550/arXiv.2511.14042
24+
- Leiva, F., Canales, C., Valenzuela, M., & Ruiz-del-Solar, J. (2026). Data-driven control of hydraulic impact hammers under strict operational and control constraints. arXiv. https://doi.org/10.48550/arXiv.2601.07813
2325

24-
- Rigas, S., Verma, D., Alexandridis, G., & Wang, Y. (2025). Initialization schemes for Kolmogorov-Arnold networks: An empirical study. arXiv. https://doi.org/10.48550/ARXIV.2509.03417 | `GitHub Reference <https://github.com/srigas/KAN_Initialization_Schemes>`_
26+
- Rigas, S., Anagnostopoulos, F., Papachristou, M., & Alexandridis, G. (2026). Training deep physics-informed Kolmogorov–Arnold networks. Computer Methods in Applied Mechanics and Engineering, 452, 118761. https://doi.org/10.1016/j.cma.2026.118761
27+
28+
- Cerardi, N., Tolley, E., & Mishra, A. (2026). Solving the cosmological Vlasov–Poisson equations with physics-informed Kolmogorov–Arnold networks. Monthly Notices of the Royal Astronomical Society, 545, staf2241. https://doi.org/10.1093/mnras/staf2241 | `GitHub Reference <https://github.com/nicolas-cerardi/cdm-pikan>`_
2529

2630
- Howard, A. A., Jacob, B., & Stinis, P. (2025). Multifidelity Kolmogorov–Arnold networks. Machine Learning: Science and Technology, 6(3), 035038. https://doi.org/10.1088/2632-2153/adf702
2731

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
# -- Project information -----------------------------------------------------
1010

1111
project = 'jaxkan'
12-
copyright = '2025, Spyros Rigas, Michalis Papachristou'
12+
copyright = '2026, Spyros Rigas, Michalis Papachristou'
1313
author = 'Spyros Rigas, Michalis Papachristou'
1414

15-
release = '0.3.5'
15+
release = '0.3.6'
1616

1717
# -- General configuration ------------------------------------------------
1818

jaxkan/layers/Dense.py

Lines changed: 69 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,137 +5,112 @@
55
from typing import Union
66

77

8-
class Dense(nnx.Module):
8+
class DenseLayer(nnx.Module):
99
"""
10-
Weight-normalized Dense layer for use in MLP architectures.
11-
12-
This layer implements weight normalization as described in:
13-
"Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks"
14-
by Salimans & Kingma (arXiv:1602.07868)
10+
Dense layer with random weight factorization (RWF) for use in MLP architectures.
1511
1612
Note: This is not a KAN layer, but a standard MLP building block used in advanced
1713
KAN architectures like KKAN (see jaxkan.models module).
1814
1915
Attributes:
20-
rngs (nnx.Rngs):
21-
Random number generator state.
22-
W (nnx.Param):
23-
Weight matrix.
2416
g (nnx.Param):
25-
Scale parameter for weight normalization.
26-
b (Union[nnx.Param, None]):
27-
Bias parameter if add_bias is True, else None.
17+
Scale factor vector of shape (n_out,) from the RWF reparameterization.
18+
v (nnx.Param):
19+
Direction matrix of shape (n_in, n_out) from the RWF reparameterization.
20+
b (nnx.Param or None):
21+
Bias vector of shape (n_out,), or None if add_bias is False.
22+
activation (callable or None):
23+
Activation function applied after the linear transformation, or None.
2824
"""
2925

30-
def __init__(self, n_in: int, n_out: int, init_scheme: str = 'glorot',
26+
def __init__(self, n_in: int, n_out: int, activation = None,
27+
RWF: dict = {"mean": 1.0, "std": 0.1},
3128
add_bias: bool = True, seed: int = 42):
3229
"""
33-
Initializes a Dense layer with weight normalization.
30+
Initializes a Dense layer with RWF.
3431
3532
Args:
3633
n_in (int):
3734
Number of input features.
3835
n_out (int):
3936
Number of output features.
40-
init_scheme (str):
41-
Initialization scheme for weight matrix W. Options:
42-
- 'glorot' or 'xavier': Glorot/Xavier normal initialization (default)
43-
- 'glorot_uniform': Glorot/Xavier uniform initialization
44-
- 'he' or 'kaiming': He/Kaiming normal initialization
45-
- 'he_uniform': He/Kaiming uniform initialization
46-
- 'lecun': LeCun normal initialization
47-
- 'normal': Standard normal initialization
48-
- 'uniform': Uniform initialization in [-1, 1]
49-
add_bias (bool):
50-
Whether to include a bias term.
51-
seed (int):
52-
Random seed for initialization.
53-
37+
activation (callable, optional):
38+
Activation function applied after the linear transformation.
39+
Defaults to None.
40+
RWF (dict, optional):
41+
Dictionary with keys ``'mean'`` and ``'std'`` controlling the
42+
log-normal scale of the RWF reparameterization.
43+
Defaults to ``{"mean": 1.0, "std": 0.1}``.
44+
add_bias (bool, optional):
45+
Whether to include a learnable bias term. Defaults to True.
46+
seed (int, optional):
47+
Random seed for parameter initialization. Defaults to 42.
48+
5449
Example:
55-
>>> layer = Dense(n_in=64, n_out=32, init_scheme='glorot', add_bias=True, seed=42)
50+
>>> layer = DenseLayer(n_in=64, n_out=32, add_bias=True, seed=42)
5651
"""
5752
# Setup nnx rngs
58-
self.rngs = nnx.Rngs(seed)
59-
60-
# Get the initializer based on init_scheme
61-
initializer = self._get_initializer(init_scheme.lower())
53+
rngs = nnx.Rngs(seed)
6254

63-
# Initialize weight matrix W
64-
# Shape: (n_in, n_out)
65-
self.W = nnx.Param(initializer(
66-
self.rngs.params(), (n_in, n_out), jnp.float32))
67-
68-
# Initialize scale parameter g (one per output feature)
69-
# Shape: (n_out,)
70-
self.g = nnx.Param(jnp.ones((n_out,)))
71-
72-
# Initialize bias parameter b
73-
# Shape: (n_out,)
55+
# Initialize kernel via RWF - shape (n_in, n_out)
56+
mu, sigma = RWF["mean"], RWF["std"]
57+
58+
# Glorot Initialization
59+
stddev = jnp.sqrt(2.0/(n_in + n_out))
60+
61+
# Weight matrix with shape (n_in, n_out)
62+
w = nnx.initializers.normal(stddev=stddev)(
63+
rngs.params(), (n_in, n_out), jnp.float32
64+
)
65+
66+
# Reparameterization towards g, v
67+
g = nnx.initializers.normal(stddev=sigma)(
68+
rngs.params(), (n_out,), jnp.float32
69+
)
70+
g += mu
71+
g = jnp.exp(g) # shape (n_out,)
72+
v = w/g # shape (n_in, n_out)
73+
74+
self.g = nnx.Param(g)
75+
self.v = nnx.Param(v)
76+
77+
# Initialize bias - shape (n_out,)
7478
if add_bias:
7579
self.b = nnx.Param(jnp.zeros((n_out,)))
7680
else:
7781
self.b = None
7882

79-
def _get_initializer(self, init_scheme: str):
80-
"""
81-
Returns the appropriate initializer based on the scheme name.
82-
83-
Args:
84-
init_scheme (str):
85-
Name of the initialization scheme.
86-
87-
Returns:
88-
initializer:
89-
An nnx initializer function.
90-
"""
91-
init_map = {
92-
'glorot': nnx.initializers.glorot_normal(),
93-
'xavier': nnx.initializers.glorot_normal(),
94-
'glorot_uniform': nnx.initializers.glorot_uniform(),
95-
'xavier_uniform': nnx.initializers.glorot_uniform(),
96-
'he': nnx.initializers.he_normal(),
97-
'kaiming': nnx.initializers.he_normal(),
98-
'he_uniform': nnx.initializers.he_uniform(),
99-
'kaiming_uniform': nnx.initializers.he_uniform(),
100-
'lecun': nnx.initializers.lecun_normal(),
101-
'lecun_uniform': nnx.initializers.lecun_uniform(),
102-
'normal': nnx.initializers.normal(stddev=1.0),
103-
'uniform': nnx.initializers.uniform(scale=1.0),
104-
}
105-
106-
if init_scheme not in init_map:
107-
raise ValueError(f"Unknown init_scheme: {init_scheme}. "
108-
f"Available options: {list(init_map.keys())}")
83+
self.activation = activation
10984

110-
return init_map[init_scheme]
11185

11286
def __call__(self, x):
11387
"""
114-
Forward pass with weight normalization.
115-
116-
Computes: y = g * (x @ V) + b, where V = W / ||W||_2 (column-wise)
88+
Applies the dense layer to the input.
11789
11890
Args:
119-
x (jnp.array):
120-
Input tensor, shape (batch, n_in).
91+
x (jnp.ndarray):
92+
Input array of shape (batch, n_in).
12193
12294
Returns:
123-
y (jnp.array):
124-
Output tensor, shape (batch, n_out).
125-
95+
jnp.ndarray:
96+
Output array of shape (batch, n_out).
97+
12698
Example:
127-
>>> layer = Dense(n_in=64, n_out=32, seed=42)
128-
>>> x = jax.random.uniform(jax.random.key(0), (100, 64))
129-
>>> y = layer(x) # shape: (100, 32)
99+
>>> layer = DenseLayer(n_in=4, n_out=2)
100+
>>> x = jnp.ones((3, 4))
101+
>>> y = layer(x) # shape: (3, 2)
130102
"""
131-
# Weight normalization: V = W / ||W||_2 (column-wise)
132-
W_norm = jnp.linalg.norm(self.W, axis=0, keepdims=True)
133-
V = self.W / (W_norm + 1e-8)
134-
135-
# Compute output: y = g * (x @ V) + b
136-
y = self.g * jnp.dot(x, V)
103+
# Reconstruct kernel
104+
g, v = self.g[...], self.v[...]
105+
kernel = g * v
106+
107+
# Apply kernel and bias
108+
y = jnp.dot(x, kernel)
137109

138110
if self.b is not None:
139-
y = y + self.b
111+
y = y + self.b[...]
112+
113+
if self.activation is not None:
114+
y = self.activation(y)
140115

141116
return y

jaxkan/models/KKAN.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Union, List
66

77
from ..layers import get_layer
8-
from ..layers.Dense import Dense
8+
from ..layers.Dense import DenseLayer
99
from ..layers.Chebyshev import Cb
1010
from .utils import get_activation
1111

@@ -112,7 +112,7 @@ class InnerBlock(nnx.Module):
112112
Activation function.
113113
input_embedding (ChebyshevEmbedding):
114114
Chebyshev embedding layer for input.
115-
input_layer (Dense):
115+
input_layer (DenseLayer):
116116
Dense layer after input embedding.
117117
hidden_layers (nnx.List):
118118
List of hidden Dense layers.
@@ -156,19 +156,19 @@ def __init__(self,
156156
self.input_embedding = ChebyshevEmbedding(D_e=D_e)
157157

158158
# Input Dense layer: (D_e + 1) -> H
159-
self.input_layer = Dense(n_in=D_e + 1, n_out=H, seed=seed)
159+
self.input_layer = DenseLayer(n_in=D_e + 1, n_out=H, seed=seed)
160160

161161
# Hidden Dense layers: H -> H
162162
self.hidden_layers = nnx.List([
163-
Dense(n_in=H, n_out=H, seed=seed + i + 1)
163+
DenseLayer(n_in=H, n_out=H, seed=seed + i + 1)
164164
for i in range(L)
165165
])
166166

167167
# Output Chebyshev embedding (operates on H-dimensional vector)
168168
self.output_embedding = ChebyshevEmbedding(D_e=D_e)
169169

170170
# Final Dense layer: H * (D_e + 1) -> m
171-
self.output_layer = Dense(n_in=H * (D_e + 1), n_out=m, seed=seed)
171+
self.output_layer = DenseLayer(n_in=H * (D_e + 1), n_out=m, seed=seed)
172172

173173

174174
def __call__(self, x_p):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
55

66
[project]
77
name = "jaxkan"
8-
version = "0.3.5"
8+
version = "0.3.6"
99
description = "A JAX implementation of Kolmogorov-Arnold Networks"
1010
readme = "README.md"
1111
keywords = ["JAX", "NNX", "KANs", "Kolmogorov-Arnold", "PIKAN"]

tests/layers/test_dense_layer.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import jax.numpy as jnp
55
from flax import nnx
66

7-
from jaxkan.layers.Dense import Dense
7+
from jaxkan.layers.Dense import DenseLayer
88

99

1010
@pytest.fixture
@@ -26,53 +26,57 @@ def layer_params():
2626

2727
# Tests
2828
def test_dense_layer_initialization(seed, layer_params):
29-
"""Test that Dense layer initializes correctly."""
30-
layer = Dense(**layer_params, seed=seed)
29+
"""Test that DenseLayer initializes correctly."""
30+
layer = DenseLayer(**layer_params, seed=seed)
3131

32-
assert layer.W[...].shape == (layer_params["n_in"], layer_params["n_out"]), "W shape incorrect"
32+
assert layer.v[...].shape == (layer_params["n_in"], layer_params["n_out"]), "v shape incorrect"
3333
assert layer.g[...].shape == (layer_params["n_out"],), "g shape incorrect"
3434
assert layer.b[...].shape == (layer_params["n_out"],), "b shape incorrect"
3535

3636

3737
def test_dense_layer_no_bias(seed, layer_params):
38-
"""Test Dense layer without bias."""
39-
layer = Dense(**layer_params, add_bias=False, seed=seed)
38+
"""Test DenseLayer without bias."""
39+
layer = DenseLayer(**layer_params, add_bias=False, seed=seed)
4040

4141
assert layer.b is None, "Bias should be None when add_bias=False"
4242

4343

4444
def test_dense_layer_forward_pass(seed, layer_params, x):
4545
"""Test forward pass produces correct output shape."""
46-
layer = Dense(**layer_params, seed=seed)
46+
layer = DenseLayer(**layer_params, seed=seed)
4747
y = layer(x)
4848

4949
assert y.shape == (x.shape[0], layer_params["n_out"]), "Forward pass output shape incorrect"
5050

5151

5252
def test_dense_layer_weight_normalization(seed, layer_params, x):
53-
"""Test that weight normalization is applied (columns of V have unit norm)."""
54-
layer = Dense(**layer_params, seed=seed)
53+
"""Test that the RWF kernel is correctly reconstructed as g * v."""
54+
layer = DenseLayer(**layer_params, seed=seed)
5555

56-
# Compute normalized weights as in forward pass
57-
W_norm = jnp.linalg.norm(layer.W[...], axis=0, keepdims=True)
58-
V = layer.W[...] / (W_norm + 1e-8)
56+
# Reconstruct kernel as in forward pass
57+
expected = jnp.dot(x, layer.g[...] * layer.v[...]) + layer.b[...]
58+
actual = layer(x)
5959

60-
# Check that each column has approximately unit norm
61-
col_norms = jnp.linalg.norm(V, axis=0)
62-
assert jnp.allclose(col_norms, 1.0, atol=1e-5), "Normalized weight columns should have unit norm"
60+
assert jnp.allclose(actual, expected, atol=1e-5), "RWF kernel reconstruction mismatch"
6361

6462

6563
def test_dense_layer_init_schemes(seed, layer_params):
66-
"""Test different initialization schemes."""
67-
schemes = ['glorot', 'he', 'lecun', 'normal', 'uniform']
64+
"""Test different RWF configurations."""
65+
rwf_configs = [
66+
{"mean": 1.0, "std": 0.1},
67+
{"mean": 0.5, "std": 0.2},
68+
{"mean": 2.0, "std": 0.05},
69+
{"mean": 0.0, "std": 0.3},
70+
{"mean": 1.5, "std": 0.15},
71+
]
6872

69-
for scheme in schemes:
70-
layer = Dense(**layer_params, init_scheme=scheme, seed=seed)
71-
assert layer.W[...].shape == (layer_params["n_in"], layer_params["n_out"]), \
72-
f"Initialization failed for scheme: {scheme}"
73+
for rwf in rwf_configs:
74+
layer = DenseLayer(**layer_params, RWF=rwf, seed=seed)
75+
assert layer.v[...].shape == (layer_params["n_in"], layer_params["n_out"]), \
76+
f"Initialization failed for RWF config: {rwf}"
7377

7478

7579
def test_dense_layer_invalid_init_scheme(seed, layer_params):
76-
"""Test that invalid init_scheme raises ValueError."""
77-
with pytest.raises(ValueError, match="Unknown init_scheme"):
78-
Dense(**layer_params, init_scheme="invalid_scheme", seed=seed)
80+
"""Test that an incomplete RWF dict raises KeyError."""
81+
with pytest.raises(KeyError):
82+
DenseLayer(**layer_params, RWF={"mean": 1.0}, seed=seed) # missing 'std'

0 commit comments

Comments (0)