214 changes: 209 additions & 5 deletions synaptogen_ml/memristor_modules/conv.py
@@ -86,7 +86,9 @@ def init_from_conv_quant(
conv_quant: Conv1DQuant,
num_cycles_init: int,
):
quant_weights = conv_quant.weight_quantizer(conv_quant.weight).detach()
quant_weights = conv_quant.weight_quantizer(
conv_quant.weight
).detach() # [out, in, kernel]
self.bias = conv_quant.bias

# handle weight sign separately because integer division with negative numbers does not work as expected
@@ -276,7 +278,10 @@ def init_from_conv_quant(
conv_quant: Conv2dQuant,
num_cycles_init: int,
):
quant_weights = conv_quant.weight_quantizer(conv_quant.weight).detach()
quant_weights = conv_quant.weight_quantizer(
conv_quant.weight
).detach() # out channels, in channels, kernel[0], kernel[1]

self.bias = conv_quant.bias

# handle weight sign separately because integer division with negative numbers does not work as expected
@@ -294,9 +299,9 @@ def init_from_conv_quant(
quant_weights_scaled_abs = quant_weights_scaled_abs % (2**bit)

# re-apply sign and transpose
quant_weights_scaled_transposed = torch.transpose(
quant_weights_scaled_bit * weights_sign, 0, 1
) # [out, in] -> [in, out]
quant_weights_scaled_transposed = (
quant_weights_scaled_bit * weights_sign
) # [in, out, k[0], k[1]]

# Arrays need flat input
flat = torch.flatten(quant_weights_scaled_transposed).cpu()
@@ -409,6 +414,205 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
result = mem_out.reshape(
batch_size, in1.shape[2], in1.shape[3], self.out_channels
) # [Batch, out_channels, T//S[0], F//S[1]]
if self.bias is not None:
result = result + self.bias
return result.permute(0, 3, 2, 1) # [..., O, T']


class SingleKernelMemristorConv2d(nn.Module):
"""
    Memristive 2d-convolution.

    Currently supports groups == 1 only.
"""

def __init__(
self,
*,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int], Literal["same", "valid"]] = 0,
padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
groups: int,
weight_precision: int,
converter_hardware_settings: DacAdcHardwareSettings,
):
super().__init__()

assert in_channels > 0
self.in_channels = in_channels
assert out_channels > 0
self.out_channels = out_channels
assert groups == 1
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
assert all(x > 0 for x in kernel_size)
self.kernel_size = kernel_size
if isinstance(padding, int):
padding = (padding, padding)

if isinstance(padding, tuple):
assert all(x >= 0 for x in padding)
else:
assert padding in ["same", "valid"]
self.padding = padding
assert padding_mode in ["zeros", "reflect", "replicate", "circular"]
self.padding_mode = padding_mode

if isinstance(stride, int):
stride = (stride, stride)
assert all(x > 0 for x in stride)
self.stride = stride

self.memristors = torch.nn.ModuleList(
[
PairedMemristorArrayV2(
in_features=in_channels * kernel_size[0] * kernel_size[1],
out_features=out_channels,
)
for _ in range(weight_precision - 1)
]
)
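        # one paired array per magnitude bit of the quantized weights (weight_precision - 1
        # arrays in total); the weight sign is handled by the positive/negative line pair
        # inside each PairedMemristorArrayV2, not by an extra array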

self.converter = DacAdcPair(hardware_settings=converter_hardware_settings)
assert weight_precision > 0
self.weight_precision = weight_precision

self.input_factor = None
self.output_factor = None

self.initialized = False
self.bias = None

def init_from_conv_quant(
self,
activation_quant: ActivationQuantizer,
conv_quant: Conv2dQuant,
num_cycles_init: int,
):
quant_weights = conv_quant.weight_quantizer(
conv_quant.weight
).detach() # out channels, in channels, kernel[0], kernel[1]
        # reshape into memristor form by collapsing the kernel dims:
        # [out, in, k[0], k[1]] -> [out, in, k[0]*k[1]]
quant_weights = torch.reshape(
quant_weights,
(
self.out_channels,
self.in_channels,
self.kernel_size[0] * self.kernel_size[1],
),
)
        # [out, in, k[0]*k[1]]
quant_weights = torch.permute(quant_weights, (1, 0, 2)) # [in, out, k[0]*k[1]]
self.bias = conv_quant.bias

        # handle the weight sign separately because Python's integer division floors
        # toward negative infinity, e.g. -5 // 2 == -3 rather than the expected -2
weights_sign = torch.sign(quant_weights)
quant_weights_scaled_abs = torch.round(
torch.absolute(quant_weights / conv_quant.weight_quantizer.scale)
).to(dtype=torch.int32)
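        # Worked example (illustrative, not from the source): with weight_precision = 4 and
        # |w| = 5 (binary 101), the loop below stores bits 1, 0, 1 in the three arrays for
        # bit positions 2, 1, 0, so 1 * 2**2 + 0 * 2**1 + 1 * 2**0 recovers 5; keeping the
        # sign separate avoids the floor-division pitfall noted above.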

for i, bit in enumerate(reversed(range(0, self.weight_precision - 1))):
# the weights we want to apply
quant_weights_scaled_bit = quant_weights_scaled_abs // (2**bit)

# the residual weights for the next step
quant_weights_scaled_abs = quant_weights_scaled_abs % (2**bit)

# re-apply sign and transpose
quant_weights_scaled_transposed = torch.transpose(
quant_weights_scaled_bit * weights_sign, 1, 2
) # [in, out, k[0]*k[1]] -> [in, k[0]*k[1], out]

# Arrays need flat input
flat = torch.flatten(quant_weights_scaled_transposed).cpu()

# positive numbers go in the positive line array, negative numbers as positive weight in the negative array
positive_weights = torch.clamp(flat, 0, 1).numpy()
negative_weights = torch.abs(torch.clamp(flat, -1, 0)).numpy()
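            # e.g. (illustrative) flat = [1, -1, 0] yields positive_weights = [1, 0, 0]
            # and negative_weights = [0, 1, 0]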

# apply negative voltage where a weight is set
size = flat.shape[0]
positive_cells = CellArrayCPU(size)
negative_cells = CellArrayCPU(size)
for _ in range(num_cycles_init * 15):
positive_cells.applyVoltage(np.random.uniform(-2.0, 2.0))
negative_cells.applyVoltage(np.random.uniform(-2.0, 2.0))

positive_cells.applyVoltage(2.0)
negative_cells.applyVoltage(2.0)
positive_cells.applyVoltage(positive_weights * -2.0)
negative_cells.applyVoltage(negative_weights * -2.0)
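            # the random pulses above pre-cycle the fresh cells (presumably to reach a
            # realistic operating regime); +2.0 V then drives every cell to a common state,
            # and -2.0 V is applied only where the weight bit is 1 (0.0 V elsewhere, which
            # should leave those cells untouched)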

self.memristors[i].init_from_paired_cell_array_input_major(
positive_cells, negative_cells
)

self.input_factor = 1.0 / (activation_quant.scale * activation_quant.quant_max)
self.output_factor = (
conv_quant.weight_quantizer.scale
* activation_quant.scale
* activation_quant.quant_max
)
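        # scaling sketch (inferred from the code, not stated explicitly): input_factor maps
        # the largest quantized activation to 1.0 before the DAC, and output_factor undoes
        # that scaling and re-applies the weight quantization scale after the ADC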
self.initialized = True

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Applies 2d-convolution.

:param inputs: [..., C, T1, T2]
:return: [..., C', T1', T2']
"""
assert self.initialized

inputs = self.converter.dac(inputs * self.input_factor)
batch_size, in_channels, time_dim1, time_dim2 = inputs.shape

if isinstance(self.padding, tuple):
padding_amount = self.padding
elif self.padding == "same":
padding_amount = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
elif self.padding == "valid":
padding_amount = (0, 0)
else:
raise ValueError(f"Unknown padding mode: {self.padding}")
if any(x > 0 for x in padding_amount):
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
inputs = F.pad(
inputs,
(
padding_amount[1],
padding_amount[1],
padding_amount[0],
padding_amount[0],
),
mode=mode,
)
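            # F.pad takes pad widths for the last dimension first, so this pads T2 by
            # padding_amount[1] and T1 by padding_amount[0] on both sides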

in0 = inputs
in1 = in0.unfold(-2, self.kernel_size[0], self.stride[0]).unfold(
-2, self.kernel_size[1], self.stride[1]
        )  # [B, in_c, T//S[0], F//S[1], K[0], K[1]]
        in2 = in1.reshape(
            batch_size, in_channels, -1, self.kernel_size[0], self.kernel_size[1]
        )  # [Batch, in_channels, T//S[0] * F//S[1], K[0], K[1]]
        in3 = in2.permute(
            0, 2, 1, 3, 4
        )  # [Batch, T//S[0] * F//S[1], in_channels, K[0], K[1]]
in4 = in3.reshape(
batch_size, -1, in_channels * self.kernel_size[0] * self.kernel_size[1]
)
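        # im2col-style rewrite: each output position's receptive field is flattened into
        # one row so the convolution becomes a single matrix product on the memristor arrays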
out = self.memristors[-1].forward(
in4
) # [Batch, T//S[0] * F//S[1], out_channels]
mem_out = self.converter.adc(out) # [Batch, T//S[0] * F//S[1], out_channels]
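        # accumulate the per-bit arrays, weighting array i by 2**bit to invert the bit
        # decomposition performed in init_from_conv_quant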
for i, bit in enumerate(reversed(range(0, self.weight_precision - 1))):
out = self.memristors[i].forward(in4)
mem_out += self.converter.adc(out) * (2 ** (bit))
        result = mem_out * self.output_factor  # [Batch, T//S[0] * F//S[1], out_channels]
        result = result.reshape(
            batch_size, in1.shape[2], in1.shape[3], self.out_channels
        )  # [Batch, T//S[0], F//S[1], out_channels]
        if self.bias is not None:
            result = result + self.bias
        return result.permute(0, 3, 1, 2)  # [Batch, out_channels, T//S[0], F//S[1]]
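

# Usage sketch (illustrative only; `hw_settings`, `act_quant` and `conv2d_quant` stand in
# for objects constructed elsewhere in the repo and are not defined in this diff):
#
#     conv = SingleKernelMemristorConv2d(
#         in_channels=8, out_channels=16, kernel_size=3, stride=1, padding="same",
#         groups=1, weight_precision=4, converter_hardware_settings=hw_settings,
#     )
#     conv.init_from_conv_quant(act_quant, conv2d_quant, num_cycles_init=10)
#     out = conv(torch.randn(2, 8, 32, 32))  # expected shape: [2, 16, 32, 32]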