diff --git a/connectomics/config/defaults.py b/connectomics/config/defaults.py index 436ba09d..eb0a5fc3 100755 --- a/connectomics/config/defaults.py +++ b/connectomics/config/defaults.py @@ -10,7 +10,7 @@ # ----------------------------------------------------------------------------- _C.SYSTEM = CN() -_C.SYSTEM.NUM_GPUS = 4 +_C.SYSTEM.NUM_GPUS = 1 _C.SYSTEM.NUM_CPUS = 4 # Run distributed training using DistributedDataparallel model _C.SYSTEM.DISTRIBUTED = False @@ -109,6 +109,27 @@ # Predict an auxiliary output (only works with 2D DeeplabV3) _C.MODEL.AUX_OUT = False +## EXCLUSIVE TO SWINTRANSFORMERS + +_C.MODEL.PATCH_SIZE = (4,4,4) +_C.MODEL.DEPTHS = [2,2,2,2,2] +_C.MODEL.NUM_HEADS = [3,6,12,24,24] +_C.MODEL.WINDOW_SIZE = (2,7,7) +_C.MODEL.MLP_RATIO = 4. +_C.MODEL.QKV_BIAS = True +_C.MODEL.QK_SCALE = None +_C.MODEL.DROP_RATE = 0. +_C.MODEL.ATTN_DROP_RATE = 0. +_C.MODEL.DROP_PATH_RATE = 0.2 +_C.MODEL.USE_CONV = False +_C.MODEL.PATCH_NORM = False +_C.MODEL.FROZEN_STAGES = -1 +_C.MODEL.USE_CHECKPOINT = False +_C.MODEL.EMBED_DIM = 96 +_C.MODEL.DOWNSAMPLE_BEFORE = [True, True, True, True] +_C.MODEL.SWIN_ISOTROPY = [True, True, True, True] + + # ----------------------------------------------------------------------------- # Dataset # ----------------------------------------------------------------------------- diff --git a/connectomics/model/arch/fpn.py b/connectomics/model/arch/fpn.py index f660f6fa..00c12e4b 100755 --- a/connectomics/model/arch/fpn.py +++ b/connectomics/model/arch/fpn.py @@ -54,7 +54,7 @@ def __init__(self, self.filters = filters self.depth = len(filters) - assert len(isotropy) == self.depth + # assert len(isotropy) == self.depth if is_isotropic: isotropy = [True] * self.depth self.isotropy = isotropy @@ -77,6 +77,11 @@ def __init__(self, 'attention': attn, } backbone_kwargs.update(self.shared_kwargs) + self.is_swin = False + if backbone_type == 'swintransformer3d': + backbone_kwargs.update(kwargs) + self.shared_kwargs['norm_mode'] = 'layer' + self.is_swin = True self.backbone = build_backbone( backbone_type, feature_keys, **backbone_kwargs) @@ -84,14 +89,14 @@ def __init__(self, self.latplanes = filters[0] self.latlayers = nn.ModuleList([ - conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, + conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0, is_swin=self.is_swin, **self.shared_kwargs) for x in filters]) self.smooth = nn.ModuleList() for i in range(self.depth): kernel_size, padding = self._get_kernel_size(isotropy[i]) self.smooth.append(conv3d_norm_act( - self.latplanes, self.latplanes, kernel_size=kernel_size, + self.latplanes, self.latplanes, kernel_size=kernel_size, is_swin=self.is_swin, padding=padding, **self.shared_kwargs)) self.conv_out = self._get_io_conv(out_channel, isotropy[0]) @@ -100,6 +105,7 @@ def __init__(self, model_init(self, init_mode) def forward(self, x): + self.x_size = x.size() z = self.backbone(x) return self._forward_main(z) @@ -113,6 +119,11 @@ def _forward_main(self, z): out = self._up_smooth_add(out, features[i-1], self.smooth[i]) out = self.smooth[0](out) out = self.conv_out(out) + if self.is_swin: + b,c,d,h,w = self.x_size + _b,_c,_d,_h,_w = out.size() + if _d != d or _h != h or _w != w: + out = F.interpolate(out,size=(d,h,w),mode='trilinear') return out def _up_smooth_add(self, x, y, smooth): @@ -138,4 +149,4 @@ def _get_io_conv(self, out_channel, is_isotropic): return conv3d_norm_act( self.filters[0], out_channel, kernel_size_io, padding=padding_io, pad_mode=self.shared_kwargs['pad_mode'], bias=True, - act_mode='none',
norm_mode='none') + act_mode='none', norm_mode='none',is_swin=self.is_swin,) diff --git a/connectomics/model/backbone/__init__.py b/connectomics/model/backbone/__init__.py index f2a34b32..097697d9 100755 --- a/connectomics/model/backbone/__init__.py +++ b/connectomics/model/backbone/__init__.py @@ -2,3 +2,4 @@ from .resnet import ResNet3D from .repvgg import RepVGG3D, RepVGGBlock3D from .botnet import BotNet3D +from .swintr import SwinTransformer2D,SwinTransformer3D \ No newline at end of file diff --git a/connectomics/model/backbone/build.py b/connectomics/model/backbone/build.py index 6ec06cc6..6ee4099e 100755 --- a/connectomics/model/backbone/build.py +++ b/connectomics/model/backbone/build.py @@ -7,6 +7,7 @@ from .repvgg import RepVGG3D from .botnet import BotNet3D from .efficientnet import EfficientNet3D +from .swintr import SwinTransformer2D,SwinTransformer3D from ..utils.misc import IntermediateLayerGetter backbone_dict = { @@ -14,13 +15,15 @@ 'repvgg': RepVGG3D, 'botnet': BotNet3D, 'efficientnet': EfficientNet3D, + 'swintransformer2d': SwinTransformer2D, + 'swintransformer3d': SwinTransformer3D, } def build_backbone(backbone_type: str, feat_keys: List[str], **kwargs): - assert backbone_type in ['resnet', 'repvgg', 'botnet', 'efficientnet'] + assert backbone_type in ['resnet', 'repvgg', 'botnet', 'efficientnet','swintransformer2d','swintransformer3d'] return_layers = {'layer0': feat_keys[0], 'layer1': feat_keys[1], 'layer2': feat_keys[2], @@ -28,5 +31,11 @@ def build_backbone(backbone_type: str, 'layer4': feat_keys[4]} backbone = backbone_dict[backbone_type](**kwargs) - assert len(feat_keys) == backbone.num_stages + if backbone_type[:15] =='swintransformer': + if backbone.use_conv: + assert len(feat_keys) == backbone.num_layers + 2 + else: + assert len(feat_keys) == backbone.num_layers + 1 + else: + assert len(feat_keys) == backbone.num_stages return IntermediateLayerGetter(backbone, return_layers) diff --git a/connectomics/model/backbone/swintr.py b/connectomics/model/backbone/swintr.py new file mode 100644 index 00000000..004de661 --- /dev/null +++ b/connectomics/model/backbone/swintr.py @@ -0,0 +1,1334 @@ +# Code adapted from https://github.com/microsoft/Swin-Transformer +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, trunc_normal_,to_2tuple + +# from mmcv.runner import load_checkpoint +# from mmaction.utils import get_root_logger + +from functools import reduce, lru_cache +from operator import mul +from einops import rearrange + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def window_partition_3d(x, window_size): + """ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C) + windows = 
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + return windows + +def window_reverse_3d(windows, window_size, B, D, H, W): + """ + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + return x + +def get_window_size_3d(x_size, window_size, shift_size=None): + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention3D(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wd, Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape( + N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class SwinTransformerBlock3D(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint=use_checkpoint + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention3D( + dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size_3d((D, H, W), self.window_size, self.shift_size) + x = self.norm1(x) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + # partition windows + x_windows = window_partition_3d(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + # merge windows + attn_windows = attn_windows.view(-1, *(window_size+(C,))) + shifted_x = window_reverse_3d(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 >0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :].contiguous() + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. 
+ """ + + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + + return x + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class PatchMerging3D(nn.Module): + + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm,isotropy=False): + super().__init__() + self.dim = dim + self.isotropy = isotropy + if self.isotropy: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8*dim) + else: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, D, H, W, C). + """ + B, D, H, W, C = x.shape + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + if self.isotropy: + x0 = x[:, 0::2, 0::2, 0::2, :] # B D/2 H/2 W/2 C + x1 = x[:, 0::2, 1::2, 0::2, :] # B D/2 H/2 W/2 C + x2 = x[:, 0::2, 0::2, 1::2, :] # B D/2 H/2 W/2 C + x3 = x[:, 0::2, 1::2, 1::2, :] # B D/2 H/2 W/2 C + x4 = x[:, 1::2, 0::2, 0::2, :] # B D/2 H/2 W/2 C + x5 = x[:, 1::2, 1::2, 0::2, :] # B D/2 H/2 W/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B D/2 H/2 W/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B D/2 H/2 W/2 C + + x = torch.cat([x0, x1, x2, x3,x4, x5, x6, x7], -1) # B D/2 H/2 W/2 8*C + + else: + x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C + x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C + x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C + x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C + + x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class PatchMerging(nn.Module): + """ Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition_3d(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + return attn_mask + +class BasicLayer3D(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (1,7,7). + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(1,7,7), + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + isotropy=False, + downsample_before=False, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.downsample_before = downsample_before + self.downsample = downsample + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock3D( + dim=dim*2 if self.downsample_before and self.downsample else dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth)]) + + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer,isotropy=isotropy) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for SW-MSA + if self.downsample_before: + B, C, D, H, W = x.shape + x = rearrange(x, 'b c d h w -> b d h w c') + if self.downsample is not None: + x = self.downsample(x) + B, D, H, W, C = x.shape + + window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) + # x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + + else: + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size_3d((D,H,W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(B, D, H, W, -1) + if self.downsample is not None: + x = self.downsample(x) + + x = rearrange(x, 'b d h w c -> b c d h w') + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + +class PatchEmbed3D(nn.Module): + """ Video to Patch Embedding. + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_channel (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + def __init__(self, patch_size=(2,4,4), in_channel=3, embed_dim=96, norm_layer=None): + super().__init__() + self.patch_size = patch_size + + self.in_channel = in_channel + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # B C D Wh Ww + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + + return x + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_channel=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_channel = in_channel + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + +class SwinTransformer3D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + patch_size (int | tuple(int)): Patch size. Default: (4,4,4). + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer: Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 
+ -1 means not freezing any parameters. + """ + + def __init__(self, + pretrained=None, + pretrained2d=True, + patch_size=(4,4,4), + in_channel=3, + embed_dim=96, + depths=[2, 2, 2, 2], + num_heads=[3, 6, 12, 24], + window_size=(2,7,7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=False, + frozen_stages=-1, + use_checkpoint=False, + swin_isotropy = [False,False,False,False], + use_conv = False, + downsample_before = [True,True,True,True], + **kwargs): + super().__init__() + + self.pretrained = pretrained + self.pretrained2d = pretrained2d + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.frozen_stages = frozen_stages + self.window_size = window_size + self.patch_size = patch_size + self.isotropy = swin_isotropy + assert len(self.isotropy) == self.num_layers + assert len(num_heads) == self.num_layers + assert len(downsample_before) == self.num_layers + self.use_conv = use_conv + if self.use_conv: + assert self.num_layers == 3 + else: + assert self.num_layers == 4 + + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dims = [embed_dim] + for _ in range(self.num_layers-1): + dims.append(embed_dim * 2**_) + # build layers + layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer3D( + dim=dims[i_layer], + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample_before = downsample_before[i_layer], + downsample=PatchMerging3D if i_layer>0 else None, + isotropy=self.isotropy[i_layer], + use_checkpoint=use_checkpoint) + layers.append(layer) + + # split image into non-overlapping patches + patch_embed = PatchEmbed3D( + patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + pos_drop = nn.Dropout(p=drop_rate) + + if self.use_conv: + # self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # ) #ORIGINAL DIMENSIONS + # self.layer1 = nn.Sequential(patch_embed,pos_drop,) + # self.layer2 = layers[0] + # self.layer3 = layers[1] + # self.layer4 = nn.Sequential(layers[2]) + + + # self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + # patch_embed,pos_drop,) + # self.layer1 = layers[0] + # self.layer2 = layers[1] + # self.layer3 = layers[2] + # self.layer4 = layers[3] + + self.layer0 = nn.Sequential(nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 
'replicate'), + nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + nn.Conv3d(in_channels = in_channel,out_channels = in_channel,kernel_size = 3,stride = 1,padding = 1,padding_mode = 'replicate'), + ) + self.layer1 = nn.Sequential(patch_embed,pos_drop) + self.layer2 = layers[0] + self.layer3 = layers[1] + self.layer4 = layers[2] + + + else: + + self.layer0 = nn.Sequential(patch_embed,pos_drop) + self.layer1 = layers[0] + self.layer2 = layers[1] + self.layer3 = layers[2] + self.layer4 = layers[3] + + self.num_features = int(embed_dim * 2**(self.num_layers-1)) + + # add a norm layer for each output + self.norm = norm_layer(self.num_features) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1: + self.pos_drop.eval() + for i in range(0, self.frozen_stages): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def inflate_weights(self, logger): + """Inflate the swin2d parameters to swin3d. + The differences between swin3d and swin2d mainly lie in an extra + axis. To utilize the pretrained parameters in 2d model, + the weight of swin2d models should be inflated to fit in the shapes of + the 3d counterpart. + Args: + logger (logging.Logger): The logger used to print + debugging infomation. + """ + checkpoint = torch.load(self.pretrained, map_location='cpu') + state_dict = checkpoint['model'] + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] + for k in relative_position_index_keys: + del state_dict[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] + for k in attn_mask_keys: + del state_dict[k] + + state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0] + + # bicubic interpolate relative_position_bias_table if not match + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for k in relative_position_bias_table_keys: + relative_position_bias_table_pretrained = state_dict[k] + relative_position_bias_table_current = self.state_dict()[k] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1) + wd = self.window_size[0] + if nH1 != nH2: + logger.warning(f"Error in loading {k}, passing") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1), + mode='bicubic') + relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) + state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1) + + msg = self.load_state_dict(state_dict, strict=False) + logger.info(msg) + logger.info(f"=> loaded successfully '{self.pretrained}'") + del checkpoint + torch.cuda.empty_cache() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. 
+ Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # if pretrained: + # self.pretrained = pretrained + # if isinstance(self.pretrained, str): + # self.apply(_init_weights) + # logger = get_root_logger() + # logger.info(f'load model from: {self.pretrained}') + + # if self.pretrained2d: + # # Inflate 2D model into 3D model. + # self.inflate_weights(logger) + # else: + # # Directly load 3D model. + # load_checkpoint(self, self.pretrained, strict=False, logger=logger) + if self.pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + if self.use_conv: + x = self.layer0(x.contiguous()) + + x = self.patch_embed(x) + + x = self.pos_drop(x) + + x = self.layer1(x.contiguous()) + + x = self.layer2(x.contiguous()) + + x = self.layer3(x.contiguous()) + + x = self.layer4(x.contiguous()) + else: + x = self.patch_embed(x) + + x = self.pos_drop(x) + + x = self.layer0(x.contiguous()) + + x = self.layer1(x.contiguous()) + + x = self.layer2(x.contiguous()) + + x = self.layer3(x.contiguous()) + + x = self.layer4(x.contiguous()) + + x = rearrange(x, 'n c d h w -> n d h w c') + + x = self.norm(x) + + x = rearrange(x, 'n d h w c -> n c d h w') + + return x + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer3D, self).train(mode) + self._freeze_stages() + +class SwinTransformer2D(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_channel (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_channel=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + **_): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_channel=in_channel, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # if isinstance(pretrained, str): + # self.apply(_init_weights) + # logger = get_root_logger() + # load_checkpoint(self, pretrained, strict=False, logger=logger) + if pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer2D, self).train(mode) + self._freeze_stages() \ No newline at end of file diff --git a/connectomics/model/block/basic.py b/connectomics/model/block/basic.py index 1753ae0a..dbe6fd7c 100755 --- a/connectomics/model/block/basic.py +++ b/connectomics/model/block/basic.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import get_norm_2d, get_norm_3d, get_activation +from ..utils import get_norm_2d, get_norm_3d, get_activation, Rearrange def conv2d_norm_act(in_planes, planes, kernel_size=(3, 3), stride=1, groups=1, @@ -26,13 +26,18 @@ def conv2d_norm_act(in_planes, planes, kernel_size=(3, 3), stride=1, groups=1, def conv3d_norm_act(in_planes, planes, kernel_size=(3, 3, 3), stride=1, groups=1, dilation=(1, 1, 1), padding=(1, 1, 1), bias=False, pad_mode='replicate', - norm_mode='bn', act_mode='relu', return_list=False): + norm_mode='bn', act_mode='relu', return_list=False, is_swin=False): layers = [] layers += [nn.Conv3d(in_planes, planes, kernel_size=kernel_size, stride=stride, groups=groups, padding=padding, padding_mode=pad_mode, dilation=dilation, bias=bias)] - layers += [get_norm_3d(norm_mode, planes)] + if is_swin: + layers += [Rearrange()] + layers += [get_norm_3d(norm_mode, planes)] + layers += [Rearrange(before_norm=False)] + else: + layers += [get_norm_3d(norm_mode, planes)] layers += [get_activation(act_mode)] if return_list: # return a list of layers diff --git a/connectomics/model/build.py b/connectomics/model/build.py index 149c32f4..b30c82f4 100755 --- a/connectomics/model/build.py +++ b/connectomics/model/build.py @@ -3,7 +3,7 @@ import torch.nn as nn from .arch import UNet3D, UNet2D, FPN3D, DeepLabV3, UNetPlus3D -from .backbone import RepVGG3D +from .backbone import RepVGG3D, SwinTransformer2D, SwinTransformer3D MODEL_MAP = { 'unet_3d': UNet3D, @@ -41,6 +41,27 @@ def build_model(cfg, device, rank=None): kwargs['deploy'] = cfg.MODEL.DEPLOY_MODE if cfg.MODEL.BACKBONE == 'botnet': kwargs['fmap_size'] = cfg.MODEL.INPUT_SIZE + if cfg.MODEL.BACKBONE == 'swintransformer3d': + swin_kwargs = { + 'patch_size': 
cfg.MODEL.PATCH_SIZE, + 'depths': cfg.MODEL.DEPTHS, + 'num_heads': cfg.MODEL.NUM_HEADS, + 'window_size': cfg.MODEL.WINDOW_SIZE, + 'mlp_ratio': cfg.MODEL.MLP_RATIO, + 'qkv_bias': cfg.MODEL.QKV_BIAS, + 'qk_scale': cfg.MODEL.QK_SCALE, + 'drop_rate': cfg.MODEL.DROP_RATE, + 'attn_drop_rate': cfg.MODEL.ATTN_DROP_RATE, + 'drop_path_rate': cfg.MODEL.DROP_PATH_RATE, + 'embed_dim': cfg.MODEL.EMBED_DIM, + 'patch_norm': cfg.MODEL.PATCH_NORM, + 'frozen_stages': cfg.MODEL.FROZEN_STAGES, + 'use_checkpoint': cfg.MODEL.USE_CHECKPOINT, + 'swin_isotropy': cfg.MODEL.SWIN_ISOTROPY, + 'use_conv': cfg.MODEL.USE_CONV, + 'downsample_before': cfg.MODEL.DOWNSAMPLE_BEFORE, + } + kwargs.update(swin_kwargs) if model_arch[:7] == 'deeplab': kwargs['name'] = model_arch diff --git a/connectomics/model/utils/misc.py b/connectomics/model/utils/misc.py index 183241ff..24b5221c 100755 --- a/connectomics/model/utils/misc.py +++ b/connectomics/model/utils/misc.py @@ -6,7 +6,7 @@ from torch import nn import torch.nn.functional as F from torch.jit.annotations import Dict - +from einops import rearrange class IntermediateLayerGetter(nn.ModuleDict): """ @@ -256,7 +256,7 @@ def get_norm_3d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo Returns: nn.Module: the normalization layer """ - assert norm in ["bn", "sync_bn", "gn", "in", "none"], \ + assert norm in ["bn", "sync_bn", "gn", "in", "none","layer"], \ "Get unknown normalization layer key {}".format(norm) if norm == "gn": assert out_channels%8 == 0, "GN requires channels to separable into 8 groups" norm = { @@ -265,6 +265,7 @@ "in": nn.InstanceNorm3d, "gn": lambda channels: nn.GroupNorm(8, channels), "none": nn.Identity, + "layer": nn.LayerNorm, }[norm] if norm in ["bn", "sync_bn", "in"]: return norm(out_channels, momentum=bn_momentum) @@ -282,7 +283,7 @@ def get_norm_2d(norm: str, out_channels: int, bn_momentum: float = 0.1) -> nn.Mo Returns: nn.Module: the normalization layer """ - assert norm in ["bn", "sync_bn", "gn", "in", "none"], \ + assert norm in ["bn", "sync_bn", "gn", "in", "none","layer"], \ "Get unknown normalization layer key {}".format(norm) norm = { "bn": nn.BatchNorm2d, @@ -290,6 +291,7 @@ "in": nn.InstanceNorm2d, "gn": lambda channels: nn.GroupNorm(16, channels), "none": nn.Identity, + "layer": nn.LayerNorm, }[norm] if norm in ["bn", "sync_bn", "in"]: return norm(out_channels, momentum=bn_momentum) @@ -325,3 +327,18 @@ def get_num_params(model): num_param = sum([param.nelement() for param in model.parameters()]) return num_param + +# ---------------------- +# Miscellaneous Modules +# ---------------------- + +class Rearrange(nn.Module): + def __init__(self,before_norm=True): + super(Rearrange, self).__init__() + self.before_norm = before_norm + + def forward(self, x): + if self.before_norm: + return rearrange(x, 'n c d h w -> n d h w c') + else: + return rearrange(x, 'n d h w c -> n c d h w') \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index 10114253..c99d274e 100755 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -155,6 +155,48 @@ def test_build_fpn_with_botnet(self): y1 = model(x) self.assertTupleEqual(tuple(y1.shape), (2, 1, d, h, w)) + def test_build_fpn_with_default_swintransformer(self): + r"""Test building a 3D FPN model with SwinTransformer3D backbone from
configs. + """ + cfg = get_cfg_defaults() + cfg.MODEL.ARCHITECTURE = 'fpn_3d' + cfg.MODEL.BACKBONE = 'swintransformer3d' + cfg.MODEL.FILTERS = [96, 96, 192, 384, 768] + cfg.MODEL.SWIN_ISOTROPY = [False,False,False,False] + cfg.MODEL.DEPTHS = [2,2,2,2] + cfg.MODEL.NUM_HEADS = [3,6,12,24] + cfg.MODEL.USE_CONV = False + cfg.MODEL.DOWNSAMPLE_BEFORE = [True,True,True,True] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = build_model(cfg, device).eval() + + d, h, w = cfg.MODEL.INPUT_SIZE + c = cfg.MODEL.IN_PLANES + x = torch.rand(2, c, d, h, w) + y1 = model(x) + self.assertTupleEqual(tuple(y1.shape), (2, c, d, h, w)) + + def test_build_fpn_with_conv_swintransformer(self): + r"""Test building a 3D FPN model with SwinTransformer3D backbone (USE_CONV enabled) from configs. + """ + cfg = get_cfg_defaults() + cfg.MODEL.ARCHITECTURE = 'fpn_3d' + cfg.MODEL.BACKBONE = 'swintransformer3d' + cfg.MODEL.FILTERS = [1, 96, 96, 192, 384] + cfg.MODEL.SWIN_ISOTROPY = [False,False,False] + cfg.MODEL.DEPTHS = [2,2,2] + cfg.MODEL.NUM_HEADS = [3,6,12] + cfg.MODEL.USE_CONV = True + cfg.MODEL.DOWNSAMPLE_BEFORE = [True,True,True] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = build_model(cfg, device).eval() + + d, h, w = cfg.MODEL.INPUT_SIZE + c = cfg.MODEL.IN_PLANES + x = torch.rand(2, c, d, h, w) + y1 = model(x) + self.assertTupleEqual(tuple(y1.shape), (2, c, d, h, w)) + def test_build_fpn_with_efficientnet(self): r"""Test building a 3D FPN model with EfficientNet3D backbone from configs. """
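
Usage sketch: a minimal example of exercising the new Swin backbone path end to end, mirroring test_build_fpn_with_default_swintransformer above. The get_cfg_defaults and build_model import locations are assumed from the repository layout (as used by tests/test_models.py) rather than shown in this patch.

import torch

from connectomics.config import get_cfg_defaults
from connectomics.model import build_model

# Configure a 3D FPN with the new 'swintransformer3d' backbone; the Swin-specific
# defaults come from the keys added to connectomics/config/defaults.py in this patch.
cfg = get_cfg_defaults()
cfg.MODEL.ARCHITECTURE = 'fpn_3d'
cfg.MODEL.BACKBONE = 'swintransformer3d'
# Five FPN feature maps: embed_dim=96 for the first two stages, doubled after each patch merge.
cfg.MODEL.FILTERS = [96, 96, 192, 384, 768]
cfg.MODEL.DEPTHS = [2, 2, 2, 2]
cfg.MODEL.NUM_HEADS = [3, 6, 12, 24]
cfg.MODEL.SWIN_ISOTROPY = [False, False, False, False]
cfg.MODEL.DOWNSAMPLE_BEFORE = [True, True, True, True]
cfg.MODEL.USE_CONV = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_model(cfg, device).eval()

d, h, w = cfg.MODEL.INPUT_SIZE
x = torch.rand(1, cfg.MODEL.IN_PLANES, d, h, w, device=device)
with torch.no_grad():
    y = model(x)  # FPN3D interpolates the Swin output back to the input (d, h, w) when is_swin is set
print(tuple(y.shape))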