"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch

from .common import sparsify_matrix


class GRUSparsifier:
    def __init__(self, task_list, start, stop, interval, exponent=3):
        """ Sparsifier for torch.nn.GRUs

        Parameters:
        -----------
        task_list : list
            task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
            of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
            'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
            update, and new gates. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
            where density is the target density in [0, 1], [m, n] is the shape of the sub-blocks to
            which sparsification is applied, and keep_diagonal is a bool indicating whether the
            diagonal should be kept.

        start : int
            training step after which sparsification will be started.

        stop : int
            training step after which sparsification will be completed.

        interval : int
            sparsification interval for steps between start and stop. After stop, sparsification
            will be carried out after every call to GRUSparsifier.step()

        exponent : float
            interpolation exponent for the sparsification schedule. In step i, sparsification is
            carried out with density (alpha + target_density * (1 - alpha)), where
            alpha = ((stop - i) / (stop - start)) ** exponent
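
            For example, with start=0, stop=100, and exponent=3, step i=50 gives
            alpha = ((100 - 50) / (100 - 0)) ** 3 = 0.125, so a target density of 0.5 is
            relaxed to 0.125 + 0.5 * (1 - 0.125) = 0.5625 at that step.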
|
        Example:
        --------
        >>> import torch
        >>> gru = torch.nn.GRU(10, 20)
        >>> sparsify_dict = {
        ...     'W_ir' : (0.5, [2, 2], False),
        ...     'W_iz' : (0.6, [2, 2], False),
        ...     'W_in' : (0.7, [2, 2], False),
        ...     'W_hr' : (0.1, [4, 4], True),
        ...     'W_hz' : (0.2, [4, 4], True),
        ...     'W_hn' : (0.3, [4, 4], True),
        ... }
        >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
        >>> for i in range(100):
        ...     sparsifier.step()
        """
        # just copying parameters...
        self.start = start
        self.stop = stop
        self.interval = interval
        self.exponent = exponent
        self.task_list = task_list

        # ... and setting counter to 0
        self.step_counter = 0

        self.last_masks = {key: None for key in ['W_ir', 'W_iz', 'W_in', 'W_hr', 'W_hz', 'W_hn']}

    def step(self, verbose=False):
        """ Carries out one sparsification step

        Call this function after optimizer.step in your
        training loop.

        Parameters:
        -----------
        verbose : bool
            if true, densities are printed out

        Returns:
        --------
        None

        """
        # compute current interpolation factor
        self.step_counter += 1

        if self.step_counter < self.start:
            return
        elif self.step_counter < self.stop:
            # update only every self.interval-th step
            if self.step_counter % self.interval:
                return

            alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
        else:
            alpha = 0

        with torch.no_grad():
            for gru, params in self.task_list:
                hidden_size = gru.hidden_size

                # input weights: weight_ih_l0 stacks the reset, update, and new
                # gate weights along dim 0, in this order
                for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
                            gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, :],
                            density,         # density
                            params[key][1],  # block_size
                            params[key][2],  # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask

                # recurrent weights: weight_hh_l0 uses the same gate ordering
                for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
                    if key in params:
                        density = alpha + (1 - alpha) * params[key][0]
                        if verbose:
                            print(f"[{self.step_counter}]: {key} density: {density}")

                        gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :], new_mask = sparsify_matrix(
                            gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, :],
                            density,         # density
                            params[key][1],  # block_size
                            params[key][2],  # keep_diagonal (might want to set this to False)
                            return_mask=True
                        )

                        if self.last_masks[key] is not None:
                            if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
                                print(f"sparsification mask {key} changed for gru {gru}")

                        self.last_masks[key] = new_mask

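
# A minimal sketch of how the sparsifier is meant to be used in training
# (see the step() docstring); `model` and `batches` are hypothetical
# placeholders, not part of this module:
#
#   sparsifier = GRUSparsifier([(model.gru, sparsify_dict)], start=1000, stop=20000, interval=100)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   for features, targets in batches:
#       optimizer.zero_grad()
#       loss = torch.nn.functional.mse_loss(model(features), targets)
#       loss.backward()
#       optimizer.step()
#       sparsifier.step()  # sparsify after the weight update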
if __name__ == "__main__":
    print("Testing sparsifier")

    gru = torch.nn.GRU(10, 20)
    sparsify_dict = {
        'W_ir' : (0.5, [2, 2], False),
        'W_iz' : (0.6, [2, 2], False),
        'W_in' : (0.7, [2, 2], False),
        'W_hr' : (0.1, [4, 4], True),
        'W_hz' : (0.2, [4, 4], True),
        'W_hn' : (0.3, [4, 4], True),
    }

    sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)

    for i in range(100):
        sparsifier.step(verbose=True)
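
    # Sanity check (not part of the original test): measure the realized density
    # of each recurrent sub-matrix after the schedule has run. Values may sit
    # slightly above the target where keep_diagonal=True.
    with torch.no_grad():
        h = gru.hidden_size
        for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
            block = gru.weight_hh_l0[i * h : (i + 1) * h, :]
            density = float((block != 0).sum()) / block.numel()
            print(f"final {key} density: {density:.3f}")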