|
5 | 5 | from typing import Union |
6 | 6 |
|
7 | 7 |
|
class DenseLayer(nnx.Module):
    """
    Dense layer with random weight factorization (RWF) for use in MLP architectures.

    Note: This is not a KAN layer, but a standard MLP building block used in advanced
    KAN architectures like KKAN (see jaxkan.models module).

    Attributes:
        g (nnx.Param):
            Scale factor vector of shape (n_out,) from the RWF reparameterization.
        v (nnx.Param):
            Direction matrix of shape (n_in, n_out) from the RWF reparameterization.
        b (nnx.Param or None):
            Bias vector of shape (n_out,), or None if add_bias is False.
        activation (callable or None):
            Activation function applied after the linear transformation, or None.
    """

    def __init__(self, n_in: int, n_out: int, activation=None,
                 RWF: Union[dict, None] = None,
                 add_bias: bool = True, seed: int = 42):
        """
        Initializes a Dense layer with RWF.

        Args:
            n_in (int):
                Number of input features.
            n_out (int):
                Number of output features.
            activation (callable, optional):
                Activation function applied after the linear transformation.
                Defaults to None.
            RWF (dict, optional):
                Dictionary with keys ``'mean'`` and ``'std'`` controlling the
                log-normal scale of the RWF reparameterization.
                Defaults to ``{"mean": 1.0, "std": 0.1}``.
            add_bias (bool, optional):
                Whether to include a learnable bias term. Defaults to True.
            seed (int, optional):
                Random seed for parameter initialization. Defaults to 42.

        Example:
            >>> layer = DenseLayer(n_in=64, n_out=32, add_bias=True, seed=42)
        """
        # Build the default hyperparameters per call instead of using a
        # mutable default argument, which would be shared across all calls.
        if RWF is None:
            RWF = {"mean": 1.0, "std": 0.1}

        mu, sigma = RWF["mean"], RWF["std"]

        # Setup nnx rngs
        rngs = nnx.Rngs(seed)

        # Glorot (Xavier) normal stddev for a kernel of shape (n_in, n_out)
        stddev = jnp.sqrt(2.0 / (n_in + n_out))

        # Unfactorized weight matrix, shape (n_in, n_out)
        w = nnx.initializers.normal(stddev=stddev)(
            rngs.params(), (n_in, n_out), jnp.float32
        )

        # RWF reparameterization: kernel = g * v, where g is log-normally
        # distributed (exp of a Gaussian), hence strictly positive.
        g = nnx.initializers.normal(stddev=sigma)(
            rngs.params(), (n_out,), jnp.float32
        )
        g = jnp.exp(g + mu)  # shape (n_out,)
        v = w / g            # shape (n_in, n_out); g > 0, so safe division

        self.g = nnx.Param(g)
        self.v = nnx.Param(v)

        # Bias, shape (n_out,), or None when disabled
        self.b = nnx.Param(jnp.zeros((n_out,))) if add_bias else None

        self.activation = activation

    def __call__(self, x):
        """
        Applies the dense layer to the input.

        Args:
            x (jnp.ndarray):
                Input array of shape (batch, n_in).

        Returns:
            jnp.ndarray:
                Output array of shape (batch, n_out).

        Example:
            >>> layer = DenseLayer(n_in=4, n_out=2)
            >>> x = jnp.ones((3, 4))
            >>> y = layer(x)  # shape: (3, 2)
        """
        # Reconstruct the effective kernel from the RWF factors
        g, v = self.g[...], self.v[...]
        kernel = g * v

        # Linear transformation
        y = jnp.dot(x, kernel)

        if self.b is not None:
            y = y + self.b[...]

        if self.activation is not None:
            y = self.activation(y)

        return y
0 commit comments