diff --git a/synaptogen_ml/memristor_modules/conv.py b/synaptogen_ml/memristor_modules/conv.py
index e62f81c..ae2abff 100644
--- a/synaptogen_ml/memristor_modules/conv.py
+++ b/synaptogen_ml/memristor_modules/conv.py
@@ -86,7 +86,9 @@ def init_from_conv_quant(
         conv_quant: Conv1DQuant,
         num_cycles_init: int,
     ):
-        quant_weights = conv_quant.weight_quantizer(conv_quant.weight).detach()
+        quant_weights = conv_quant.weight_quantizer(
+            conv_quant.weight
+        ).detach()  # [out, in, kernel]
         self.bias = conv_quant.bias
 
         # handle weight sign separately because integer division with negative numbers does not work as expected
@@ -276,7 +278,10 @@ def init_from_conv_quant(
         conv_quant: Conv2dQuant,
         num_cycles_init: int,
     ):
-        quant_weights = conv_quant.weight_quantizer(conv_quant.weight).detach()
+        quant_weights = conv_quant.weight_quantizer(
+            conv_quant.weight
+        ).detach()  # out channels, in channels, kernel[0], kernel[1]
+
         self.bias = conv_quant.bias
 
         # handle weight sign separately because integer division with negative numbers does not work as expected
@@ -294,9 +299,9 @@ def init_from_conv_quant(
             quant_weights_scaled_abs = quant_weights_scaled_abs % (2**bit)
 
             # re-apply sign and transpose
-            quant_weights_scaled_transposed = torch.transpose(
-                quant_weights_scaled_bit * weights_sign, 0, 1
-            )  # [out, in] -> [in, out]
+            quant_weights_scaled_transposed = (
+                quant_weights_scaled_bit * weights_sign
+            )  # [in, out, k[0], k[1]]
 
             # Arrays need flat input
             flat = torch.flatten(quant_weights_scaled_transposed).cpu()
@@ -409,6 +414,208 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         result = mem_out.reshape(
             batch_size, in1.shape[2], in1.shape[3], self.out_channels
         )  # [Batch, out_channels, T//S[0], F//S[1]]
+        if self.bias is not None:
+            result = result + self.bias
+        return result.permute(0, 3, 2, 1)  # [..., O, T']
+
+
+class SingleKernelMemristorConv2d(nn.Module):
+    """
+    Memristive 2d-convolution.
+    Currently supports only groups == 1.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]],
+        padding: Union[int, Tuple[int, int], Literal["same", "valid"]] = 0,
+        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
+        groups: int,
+        weight_precision: int,
+        converter_hardware_settings: DacAdcHardwareSettings,
+    ):
+        super().__init__()
+
+        assert in_channels > 0
+        self.in_channels = in_channels
+        assert out_channels > 0
+        self.out_channels = out_channels
+        assert groups == 1
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        assert all(x > 0 for x in kernel_size)
+        self.kernel_size = kernel_size
+        if isinstance(padding, int):
+            padding = (padding, padding)
+
+        if isinstance(padding, tuple):
+            assert all(x >= 0 for x in padding)
+        else:
+            assert padding in ["same", "valid"]
+        self.padding = padding
+        assert padding_mode in ["zeros", "reflect", "replicate", "circular"]
+        self.padding_mode = padding_mode
+
+        if isinstance(stride, int):
+            stride = (stride, stride)
+        assert all(x > 0 for x in stride)
+        self.stride = stride
+
+        self.memristors = torch.nn.ModuleList(
+            [
+                PairedMemristorArrayV2(
+                    in_features=in_channels * kernel_size[0] * kernel_size[1],
+                    out_features=out_channels,
+                )
+                for _ in range(weight_precision - 1)
+            ]
+        )
+
+        self.converter = DacAdcPair(hardware_settings=converter_hardware_settings)
+        assert weight_precision > 0
+        self.weight_precision = weight_precision
+
+        self.input_factor = None
+        self.output_factor = None
+
+        self.initialized = False
+        self.bias = None
+
+    def init_from_conv_quant(
+        self,
+        activation_quant: ActivationQuantizer,
+        conv_quant: Conv2dQuant,
+        num_cycles_init: int,
+    ):
+        quant_weights = conv_quant.weight_quantizer(
+            conv_quant.weight
+        ).detach()  # out channels, in channels, kernel[0], kernel[1]
+        # shape into memristor form with extra axis over in channels
+        quant_weights = torch.reshape(
+            quant_weights,
+            (
+                self.out_channels,
+                self.in_channels,
+                self.kernel_size[0] * self.kernel_size[1],
+            ),
+        )
+        # [out, in, k[0]*k[1]]
+        quant_weights = torch.permute(quant_weights, (1, 0, 2))  # [in, out, k[0]*k[1]]
+        self.bias = conv_quant.bias
+
+        # handle weight sign separately because integer division with negative numbers does not work as expected
+        # for this case here, e.g. -5 // 2 = -3 instead of -2
+        weights_sign = torch.sign(quant_weights)
+        quant_weights_scaled_abs = torch.round(
+            torch.absolute(quant_weights / conv_quant.weight_quantizer.scale)
+        ).to(dtype=torch.int32)
+
+        for i, bit in enumerate(reversed(range(0, self.weight_precision - 1))):
+            # the weights we want to apply
+            quant_weights_scaled_bit = quant_weights_scaled_abs // (2**bit)
+
+            # the residual weights for the next step
+            quant_weights_scaled_abs = quant_weights_scaled_abs % (2**bit)
+
+            # re-apply sign and transpose
+            quant_weights_scaled_transposed = torch.transpose(
+                quant_weights_scaled_bit * weights_sign, 1, 2
+            )  # [in, out, k[0]*k[1]] -> [in, k[0]*k[1], out]
+
+            # Arrays need flat input
+            flat = torch.flatten(quant_weights_scaled_transposed).cpu()
+
+            # positive numbers go in the positive line array, negative numbers as positive weight in the negative array
+            positive_weights = torch.clamp(flat, 0, 1).numpy()
+            negative_weights = torch.abs(torch.clamp(flat, -1, 0)).numpy()
+
+            # apply negative voltage where a weight is set
+            size = flat.shape[0]
+            positive_cells = CellArrayCPU(size)
+            negative_cells = CellArrayCPU(size)
+            for _ in range(num_cycles_init * 15):
+                positive_cells.applyVoltage(np.random.uniform(-2.0, 2.0))
+                negative_cells.applyVoltage(np.random.uniform(-2.0, 2.0))
+
+            positive_cells.applyVoltage(2.0)
+            negative_cells.applyVoltage(2.0)
+            positive_cells.applyVoltage(positive_weights * -2.0)
+            negative_cells.applyVoltage(negative_weights * -2.0)
+
+            self.memristors[i].init_from_paired_cell_array_input_major(
+                positive_cells, negative_cells
+            )
+
+        self.input_factor = 1.0 / (activation_quant.scale * activation_quant.quant_max)
+        self.output_factor = (
+            conv_quant.weight_quantizer.scale
+            * activation_quant.scale
+            * activation_quant.quant_max
+        )
+        self.initialized = True
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """
+        Applies 2d-convolution.
+
+        :param inputs: [..., C, T1, T2]
+        :return: [..., C', T1', T2']
+        """
+        assert self.initialized
+
+        inputs = self.converter.dac(inputs * self.input_factor)
+        batch_size, in_channels, time_dim1, time_dim2 = inputs.shape
+
+        if isinstance(self.padding, tuple):
+            padding_amount = self.padding
+        elif self.padding == "same":
+            padding_amount = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
+        elif self.padding == "valid":
+            padding_amount = (0, 0)
+        else:
+            raise ValueError(f"Unknown padding mode: {self.padding}")
+        if any(x > 0 for x in padding_amount):
+            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
+            inputs = F.pad(
+                inputs,
+                (
+                    padding_amount[1],
+                    padding_amount[1],
+                    padding_amount[0],
+                    padding_amount[0],
+                ),
+                mode=mode,
+            )
+
+        in0 = inputs
+        in1 = in0.unfold(-2, self.kernel_size[0], self.stride[0]).unfold(
+            -2, self.kernel_size[1], self.stride[1]
+        )  # [B, in_c, T//S[0], F//S[1], K[0], K[1]]
+        in2 = in1.reshape(
+            batch_size, in_channels, -1, self.kernel_size[0], self.kernel_size[1]
+        )  # [Batch, in_channels, T//S[0] * F//S[1], K[0], K[1]]
+        in3 = in2.permute(
+            0, 2, 1, 3, 4
+        )  # [Batch, T//S[0] * F//S[1], in_channels, K[0], K[1]]
+        in4 = in3.reshape(
+            batch_size, -1, in_channels * self.kernel_size[0] * self.kernel_size[1]
+        )
+        out = self.memristors[-1].forward(
+            in4
+        )  # [Batch, T//S[0] * F//S[1], out_channels]
+        mem_out = self.converter.adc(out)  # [Batch, T//S[0] * F//S[1], out_channels]
+        for i, bit in enumerate(reversed(range(0, self.weight_precision - 1))):
+            out = self.memristors[i].forward(in4)
+            mem_out += self.converter.adc(out) * (2 ** (bit))
+        mem_out = mem_out.reshape(
+            batch_size, in1.shape[2], in1.shape[3], self.out_channels
+        )  # [Batch, T//S[0], F//S[1], out_channels]
+        result = mem_out * self.output_factor
         if self.bias is not None:
             result = result + self.bias
         return result.permute(0, 3, 1, 2)  # [..., O, T']
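
Reviewer note, not part of the patch: the per-bit weight decomposition that init_from_conv_quant programs into the paired arrays can be sanity-checked in isolation. The sketch below is a minimal illustration with plain tensors; it assumes a quantizer scale of 1.0 and weight_precision = 4 (one sign plus three magnitude bits, matching the size of the ModuleList), and its variable names are illustrative rather than taken from synaptogen_ml.

import torch

weight_precision = 4  # 1 sign bit + 3 magnitude bits -> 3 paired arrays
quant_weights = torch.tensor([-5.0, 3.0, 0.0, 7.0])  # already-quantized weights, scale assumed to be 1.0

# sign is split off first: integer floor division misbehaves on negatives, e.g. -5 // 2 == -3
assert -5 // 2 == -3
weights_sign = torch.sign(quant_weights)
scaled_abs = torch.round(torch.absolute(quant_weights)).to(dtype=torch.int32)

bit_planes = []
for bit in reversed(range(0, weight_precision - 1)):  # MSB first, as in the module's loop
    plane = scaled_abs // (2**bit)      # 0/1 plane programmed into one paired array
    scaled_abs = scaled_abs % (2**bit)  # residual magnitude for the lower bits
    bit_planes.append(plane * weights_sign)

# forward() weights each array's output by 2**bit; doing the same here recovers the signed weights
recombined = sum(
    plane * (2**bit)
    for plane, bit in zip(bit_planes, reversed(range(0, weight_precision - 1)))
)
print(recombined)  # tensor([-5., 3., 0., 7.])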