Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion automation/notebooks-table-data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ MLP Mixer,architectures/mlp-mixer.ipynb,,https://arxiv.org/abs/2105.01601
GloVe Word Embeddings, data_exploration/glove-word-embeddings.ipynb,https://github.com/stanfordnlp/GloVe,https://nlp.stanford.edu/pubs/glove.pdf
Vision Transformer (ViT),architectures/vit.ipynb,,https://arxiv.org/pdf/2010.11929
Multi-Head Attention, modules/multihead-self-attention.ipynb,,https://arxiv.org/abs/1706.03762
ResNet,architectures/resnet.ipynb,,https://arxiv.org/abs/1512.03385
ResNet,architectures/resnet.ipynb,,https://arxiv.org/abs/1512.03385
Convolutional Block Attention, modules/convolutional-block-attention.ipynb,,https://arxiv.org/abs/1807.06521
4 changes: 2 additions & 2 deletions notebooks/architectures/mlp-mixer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "venv",
"language": "python",
"name": "python3"
},
Expand All @@ -243,7 +243,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
219 changes: 219 additions & 0 deletions notebooks/modules/convolutional-block-attention.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![deep-learning-notes](https://github.com/semilleroCV/deep-learning-notes/raw/main/assets/banner-notebook.png)](https://github.com/semilleroCV/deep-learning-notes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## **CBAM: Convolutional Block Attention Module**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"#@title **Install required packages**\n",
"\n",
"! pip install torchinfo"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#@title **Importing libraries**\n",
"\n",
"import torch # 2.3.0+cu121\n",
"import torch.nn as nn\n",
"\n",
"import torchinfo #1.8.0"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.3.1+cu121\n"
]
}
],
"source": [
"# Note: Not all dependencies have the __version__ method.\n",
"\n",
"print(f\"torch version: {torch.__version__}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## **CBAM code**"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"class ChannelAttention(nn.Module):\n",
" def __init__(self, in_channels: int, ratio: int = 16):\n",
" super().__init__()\n",
"\n",
" self.avg_pooling = nn.AdaptiveAvgPool2d(1)\n",
" self.max_pooling = nn.AdaptiveAvgPool2d(1)\n",
"\n",
" self.net = nn.Sequential(nn.Conv2d(in_channels, in_channels//ratio, 1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels//ratio, in_channels, 1))\n",
" \n",
" self.act = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" avg_pool = self.net(self.avg_pooling(x))\n",
" max_pool = self.net(self.max_pooling(x))\n",
"\n",
" out = self.act(avg_pool + max_pool)\n",
"\n",
" return out * x\n",
"\n",
"class SpatialAttention(nn.Module):\n",
" def __init__(self, kernel_size: int = 7):\n",
" super().__init__()\n",
"\n",
" self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2)\n",
" self.act = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
"\n",
" avg_x = torch.mean(x, dim=1, keepdim=True)\n",
" max_x = torch.amax(x, dim=1, keepdim=True)\n",
"\n",
" out = torch.cat([avg_x, max_x], dim=1)\n",
" out = self.act(self.conv(out))\n",
"\n",
" return out * x\n",
" \n",
"\n",
"class CBAM(nn.Module):\n",
" def __init__(self, in_channels: int, ratio: int = 16):\n",
" super().__init__()\n",
"\n",
" self.ca = ChannelAttention(in_channels, ratio)\n",
" self.sa = SpatialAttention()\n",
"\n",
" def forward(self, x):\n",
"\n",
" ca_out = self.ca(x) \n",
"\n",
" print(ca_out.shape)\n",
" sa_out = self.sa(ca_out) \n",
"\n",
" x = sa_out + x\n",
"\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 256, 32, 32])\n"
]
},
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"CBAM [1, 256, 32, 32] --\n",
"├─ChannelAttention: 1-1 [1, 256, 32, 32] --\n",
"│ └─AdaptiveAvgPool2d: 2-1 [1, 256, 1, 1] --\n",
"│ └─Sequential: 2-2 [1, 256, 1, 1] --\n",
"│ │ └─Conv2d: 3-1 [1, 16, 1, 1] 4,112\n",
"│ │ └─ReLU: 3-2 [1, 16, 1, 1] --\n",
"│ │ └─Conv2d: 3-3 [1, 256, 1, 1] 4,352\n",
"│ └─AdaptiveAvgPool2d: 2-3 [1, 256, 1, 1] --\n",
"│ └─Sequential: 2-4 [1, 256, 1, 1] (recursive)\n",
"│ │ └─Conv2d: 3-4 [1, 16, 1, 1] (recursive)\n",
"│ │ └─ReLU: 3-5 [1, 16, 1, 1] --\n",
"│ │ └─Conv2d: 3-6 [1, 256, 1, 1] (recursive)\n",
"│ └─Sigmoid: 2-5 [1, 256, 1, 1] --\n",
"├─SpatialAttention: 1-2 [1, 256, 32, 32] --\n",
"│ └─Conv2d: 2-6 [1, 1, 32, 32] 99\n",
"│ └─Sigmoid: 2-7 [1, 1, 32, 32] --\n",
"==========================================================================================\n",
"Total params: 8,563\n",
"Trainable params: 8,563\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 0.12\n",
"==========================================================================================\n",
"Input size (MB): 1.05\n",
"Forward/backward pass size (MB): 0.01\n",
"Params size (MB): 0.03\n",
"Estimated Total Size (MB): 1.10\n",
"=========================================================================================="
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cbam_module = CBAM(in_channels=256)\n",
"torchinfo.summary(cbam_module, (256, 32, 32), batch_dim = 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading