Skip to content

Commit f4700a6

Browse files
committed
Initial sparse matrix support copied from Opus
Now training model with data listed in the README
1 parent 84fe83b commit f4700a6

File tree

6 files changed

+350
-4
lines changed

6 files changed

+350
-4
lines changed

model_version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
37cf35f
1+
2828242

torch/rnnoise/rnnoise.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
import torch
22
from torch import nn
33
import torch.nn.functional as F
4+
import sys
5+
import os
6+
7+
sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
8+
from sparsification import GRUSparsifier
9+
10+
sparsify_start = 2500
11+
sparsify_stop = 8000
12+
sparsify_interval = 50
13+
sparsify_exponent = 3
14+
15+
sparse_params1 = {
16+
'W_hr' : (0.3, [8, 4], True),
17+
'W_hz' : (0.2, [8, 4], True),
18+
'W_hn' : (0.5, [8, 4], True),
19+
'W_ir' : (0.3, [8, 4], False),
20+
'W_iz' : (0.2, [8, 4], False),
21+
'W_in' : (0.5, [8, 4], False)
22+
}
23+
24+
def init_weights(module):
25+
if isinstance(module, nn.GRU):
26+
for p in module.named_parameters():
27+
if p[0].startswith('weight_hh_'):
28+
nn.init.orthogonal_(p[1])
429

530
class RNNoise(nn.Module):
631
def __init__(self, input_dim=65, output_dim=32, cond_size=128, gru_size=256):
@@ -19,6 +44,16 @@ def __init__(self, input_dim=65, output_dim=32, cond_size=128, gru_size=256):
1944
self.vad_dense = nn.Linear(self.gru_size, 1)
2045
nb_params = sum(p.numel() for p in self.parameters())
2146
print(f"model: {nb_params} weights")
47+
self.apply(init_weights)
48+
self.sparsifier = []
49+
self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
50+
self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
51+
self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
52+
53+
54+
def sparsify(self):
55+
for sparsifier in self.sparsifier:
56+
sparsifier.step()
2257

2358
def forward(self, features, states=None):
2459
#print(states)

torch/rnnoise/train_rnnoise.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
model_group = parser.add_argument_group(title="model parameters")
2020
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128)
21-
model_group.add_argument('--gru-size', type=int, help="first conditioning size, default: 256", default=256)
21+
model_group.add_argument('--gru-size', type=int, help="first conditioning size, default: 384", default=384)
2222

2323
training_group = parser.add_argument_group(title="training parameters")
24-
training_group.add_argument('--batch-size', type=int, help="batch size, default: 192", default=192)
24+
training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
2525
training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3)
2626
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 200', default=200)
2727
training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 2000', default=2000)
28-
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 1e-3', default=1e-3)
28+
training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 5e-5', default=5e-5)
2929
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
3030
training_group.add_argument('--gamma', type=float, help='perceptual exponent (default 0.1667)', default=0.1667)
3131

@@ -127,6 +127,7 @@ def mask(g):
127127

128128
loss.backward()
129129
optimizer.step()
130+
model.sparsify()
130131

131132
scheduler.step()
132133

torch/sparsification/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .gru_sparsifier import GRUSparsifier
2+
from .common import sparsify_matrix, calculate_gru_flops_per_step

torch/sparsification/common.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
/* Copyright (c) 2023 Amazon
3+
Written by Jan Buethe */
4+
/*
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions
7+
are met:
8+
9+
- Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
- Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in the
14+
documentation and/or other materials provided with the distribution.
15+
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17+
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20+
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24+
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25+
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
*/
28+
"""
29+
30+
import torch
31+
32+
def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
33+
""" sparsifies matrix with specified block size
34+
35+
Parameters:
36+
-----------
37+
matrix : torch.tensor
38+
matrix to sparsify
39+
density : int
40+
target density
41+
block_size : [int, int]
42+
block size dimensions
43+
keep_diagonal : bool
44+
If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
45+
"""
46+
47+
m, n = matrix.shape
48+
m1, n1 = block_size
49+
50+
if m % m1 or n % n1:
51+
raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
52+
53+
# extract diagonal if keep_diagonal = True
54+
if keep_diagonal:
55+
if m != n:
56+
raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
57+
58+
to_spare = torch.diag(torch.diag(matrix))
59+
matrix = matrix - to_spare
60+
else:
61+
to_spare = torch.zeros_like(matrix)
62+
63+
# calculate energy in sub-blocks
64+
x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
65+
x = x ** 2
66+
block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
67+
68+
number_of_blocks = (m * n) // (m1 * n1)
69+
number_of_survivors = round(number_of_blocks * density)
70+
71+
# masking threshold
72+
if number_of_survivors == 0:
73+
threshold = 0
74+
else:
75+
threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
76+
77+
# create mask
78+
mask = torch.ones_like(block_energies)
79+
mask[block_energies < threshold] = 0
80+
mask = torch.repeat_interleave(mask, m1, dim=0)
81+
mask = torch.repeat_interleave(mask, n1, dim=1)
82+
83+
# perform masking
84+
masked_matrix = mask * matrix + to_spare
85+
86+
if return_mask:
87+
return masked_matrix, mask
88+
else:
89+
return masked_matrix
90+
91+
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
92+
input_size = gru.input_size
93+
hidden_size = gru.hidden_size
94+
flops = 0
95+
96+
input_density = (
97+
sparsification_dict.get('W_ir', [1])[0]
98+
+ sparsification_dict.get('W_in', [1])[0]
99+
+ sparsification_dict.get('W_iz', [1])[0]
100+
) / 3
101+
102+
recurrent_density = (
103+
sparsification_dict.get('W_hr', [1])[0]
104+
+ sparsification_dict.get('W_hn', [1])[0]
105+
+ sparsification_dict.get('W_hz', [1])[0]
106+
) / 3
107+
108+
# input matrix vector multiplications
109+
if not drop_input:
110+
flops += 2 * 3 * input_size * hidden_size * input_density
111+
112+
# recurrent matrix vector multiplications
113+
flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
114+
115+
# biases
116+
flops += 6 * hidden_size
117+
118+
# activations estimated by 10 flops per activation
119+
flops += 30 * hidden_size
120+
121+
return flops
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""
2+
/* Copyright (c) 2023 Amazon
3+
Written by Jan Buethe */
4+
/*
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions
7+
are met:
8+
9+
- Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
- Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in the
14+
documentation and/or other materials provided with the distribution.
15+
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17+
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20+
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24+
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25+
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
*/
28+
"""
29+
30+
import torch
31+
32+
from .common import sparsify_matrix
33+
34+
35+
class GRUSparsifier:
36+
def __init__(self, task_list, start, stop, interval, exponent=3):
37+
""" Sparsifier for torch.nn.GRUs
38+
39+
Parameters:
40+
-----------
41+
task_list : list
42+
task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
43+
of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
44+
'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
45+
update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
46+
where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
47+
sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
48+
should be kept.
49+
50+
start : int
51+
training step after which sparsification will be started.
52+
53+
stop : int
54+
training step after which sparsification will be completed.
55+
56+
interval : int
57+
sparsification interval for steps between start and stop. After stop sparsification will be
58+
carried out after every call to GRUSparsifier.step()
59+
60+
exponent : float
61+
Interpolation exponent for sparsification interval. In step i sparsification will be carried out
62+
with density (alpha + target_density * (1 * alpha)), where
63+
alpha = ((stop - i) / (start - stop)) ** exponent
64+
65+
Example:
66+
--------
67+
>>> import torch
68+
>>> gru = torch.nn.GRU(10, 20)
69+
>>> sparsify_dict = {
70+
... 'W_ir' : (0.5, [2, 2], False),
71+
... 'W_iz' : (0.6, [2, 2], False),
72+
... 'W_in' : (0.7, [2, 2], False),
73+
... 'W_hr' : (0.1, [4, 4], True),
74+
... 'W_hz' : (0.2, [4, 4], True),
75+
... 'W_hn' : (0.3, [4, 4], True),
76+
... }
77+
>>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
78+
>>> for i in range(100):
79+
... sparsifier.step()
80+
"""
81+
# just copying parameters...
82+
self.start = start
83+
self.stop = stop
84+
self.interval = interval
85+
self.exponent = exponent
86+
self.task_list = task_list
87+
88+
# ... and setting counter to 0
89+
self.step_counter = 0
90+
91+
self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
92+
93+
def step(self, verbose=False):
94+
""" carries out sparsification step
95+
96+
Call this function after optimizer.step in your
97+
training loop.
98+
99+
Parameters:
100+
----------
101+
verbose : bool
102+
if true, densities are printed out
103+
104+
Returns:
105+
--------
106+
None
107+
108+
"""
109+
# compute current interpolation factor
110+
self.step_counter += 1
111+
112+
if self.step_counter < self.start:
113+
return
114+
elif self.step_counter < self.stop:
115+
# update only every self.interval-th interval
116+
if self.step_counter % self.interval:
117+
return
118+
119+
alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
120+
else:
121+
alpha = 0
122+
123+
124+
with torch.no_grad():
125+
for gru, params in self.task_list:
126+
hidden_size = gru.hidden_size
127+
128+
# input weights
129+
for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
130+
if key in params:
131+
density = alpha + (1 - alpha) * params[key][0]
132+
if verbose:
133+
print(f"[{self.step_counter}]: {key} density: {density}")
134+
135+
gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
136+
gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
137+
density, # density
138+
params[key][1], # block_size
139+
params[key][2], # keep_diagonal (might want to set this to False)
140+
return_mask=True
141+
)
142+
143+
if type(self.last_masks[key]) != type(None):
144+
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
145+
print(f"sparsification mask {key} changed for gru {gru}")
146+
147+
self.last_masks[key] = new_mask
148+
149+
# recurrent weights
150+
for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
151+
if key in params:
152+
density = alpha + (1 - alpha) * params[key][0]
153+
if verbose:
154+
print(f"[{self.step_counter}]: {key} density: {density}")
155+
gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
156+
gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
157+
density,
158+
params[key][1], # block_size
159+
params[key][2], # keep_diagonal (might want to set this to False)
160+
return_mask=True
161+
)
162+
163+
if type(self.last_masks[key]) != type(None):
164+
if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
165+
print(f"sparsification mask {key} changed for gru {gru}")
166+
167+
self.last_masks[key] = new_mask
168+
169+
170+
171+
if __name__ == "__main__":
172+
print("Testing sparsifier")
173+
174+
gru = torch.nn.GRU(10, 20)
175+
sparsify_dict = {
176+
'W_ir' : (0.5, [2, 2], False),
177+
'W_iz' : (0.6, [2, 2], False),
178+
'W_in' : (0.7, [2, 2], False),
179+
'W_hr' : (0.1, [4, 4], True),
180+
'W_hz' : (0.2, [4, 4], True),
181+
'W_hn' : (0.3, [4, 4], True),
182+
}
183+
184+
sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
185+
186+
for i in range(100):
187+
sparsifier.step(verbose=True)

0 commit comments

Comments
 (0)