Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion automation/notebooks-table-data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ MLP Mixer,architectures/mlp-mixer.ipynb,,https://arxiv.org/abs/2105.01601
GloVe Word Embeddings, data_exploration/glove-word-embeddings.ipynb,https://github.com/stanfordnlp/GloVe,https://nlp.stanford.edu/pubs/glove.pdf
Vision Transformer (ViT),architectures/vit.ipynb,,https://arxiv.org/pdf/2010.11929
Multi-Head Attention, modules/multihead-self-attention.ipynb,,https://arxiv.org/abs/1706.03762
ResNet,architectures/resnet.ipynb,,https://arxiv.org/abs/1512.03385
ResNet,architectures/resnet.ipynb,,https://arxiv.org/abs/1512.03385
ConvMixer,architectures/ConvMixer.ipynb,,https://arxiv.org/pdf/2201.09792
216 changes: 216 additions & 0 deletions notebooks/architectures/ConvMixer.ipynb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said in the other comment, the name of the .ipynb must be completely lowercase

Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"## ConvMixer"
],
"metadata": {
"id": "IJveajFZvdXK"
}
},
{
"cell_type": "code",
"source": [
"#@title **Install required packages**\n",
"\n",
"%%capture\n",
"! pip install torchinfo"
],
"metadata": {
"id": "q5AuzFB2tA12"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "118P89a1osHb"
},
"outputs": [],
"source": [
"#@title **Importing libraries**\n",
"import torch # 2.5.1+cu121\n",
"import torch.nn as nn\n",
"import torchinfo #1.8.0"
]
},
{
"cell_type": "code",
"source": [
"# Note: Not all dependencies have the __version__ method.\n",
"print(f\"torch version: {torch.__version__}\")\n",
"print(f\"torchinfo version: {torchinfo.__version__}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QchOSIbro1zf",
"outputId": "94003f5a-622f-4f54-e916-26fbbff86c17"
},
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch version: 2.5.1+cu121\n",
"torchinfo version: 1.8.0\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"**ConvMixer architecture code**\n"
],
"metadata": {
"id": "hrG5QYB4pRMu"
}
},
{
"cell_type": "code",
"source": [
"class Residual(nn.Module):\n",
" def __init__(self, fn):\n",
" super().__init__()\n",
" self.fn = fn\n",
"\n",
" def forward(self, x):\n",
" return self.fn(x) + x\n",
"\n",
"def ConvMixer(dim, depth, kernel_size = 9, patch_size = 7, n_classes = 1000):\n",
" return nn.Sequential(\n",
" nn.Conv2d(3, dim, kernel_size = patch_size, stride = patch_size),\n",
" nn.GELU(),\n",
" nn.BatchNorm2d(dim),\n",
" *[nn.Sequential(\n",
" Residual(nn.Sequential(\n",
" nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=\"same\"),\n",
" nn.GELU(),\n",
" nn.BatchNorm2d(dim)\n",
" )),\n",
" nn.Conv2d(dim,dim, kernel_size = 1),\n",
" nn.GELU(),\n",
" nn.BatchNorm2d(dim)\n",
" )for i in range(depth)],\n",
" nn.AdaptiveAvgPool2d((1,1)),\n",
" nn.Flatten(),\n",
" nn.Linear(dim, n_classes)\n",
" )"
],
"metadata": {
"id": "r2d2e2P0pdan"
},
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = ConvMixer(2048, 8, kernel_size=9, patch_size=1, n_classes=1000)\n",
"torchinfo.summary(model)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9YM5zfNoslTx",
"outputId": "1b5a1295-d3ac-49c9-d61d-ae986021ad00"
},
"execution_count": 41,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"=================================================================\n",
"Layer (type:depth-idx) Param #\n",
"=================================================================\n",
"Sequential --\n",
"├─Conv2d: 1-1 8,192\n",
"├─GELU: 1-2 --\n",
"├─BatchNorm2d: 1-3 4,096\n",
"├─Sequential: 1-4 --\n",
"│ └─Residual: 2-1 --\n",
"│ │ └─Sequential: 3-1 172,032\n",
"│ └─Conv2d: 2-2 4,196,352\n",
"│ └─GELU: 2-3 --\n",
"│ └─BatchNorm2d: 2-4 4,096\n",
"├─Sequential: 1-5 --\n",
"│ └─Residual: 2-5 --\n",
"│ │ └─Sequential: 3-2 172,032\n",
"│ └─Conv2d: 2-6 4,196,352\n",
"│ └─GELU: 2-7 --\n",
"│ └─BatchNorm2d: 2-8 4,096\n",
"├─Sequential: 1-6 --\n",
"│ └─Residual: 2-9 --\n",
"│ │ └─Sequential: 3-3 172,032\n",
"│ └─Conv2d: 2-10 4,196,352\n",
"│ └─GELU: 2-11 --\n",
"│ └─BatchNorm2d: 2-12 4,096\n",
"├─Sequential: 1-7 --\n",
"│ └─Residual: 2-13 --\n",
"│ │ └─Sequential: 3-4 172,032\n",
"│ └─Conv2d: 2-14 4,196,352\n",
"│ └─GELU: 2-15 --\n",
"│ └─BatchNorm2d: 2-16 4,096\n",
"├─Sequential: 1-8 --\n",
"│ └─Residual: 2-17 --\n",
"│ │ └─Sequential: 3-5 172,032\n",
"│ └─Conv2d: 2-18 4,196,352\n",
"│ └─GELU: 2-19 --\n",
"│ └─BatchNorm2d: 2-20 4,096\n",
"├─Sequential: 1-9 --\n",
"│ └─Residual: 2-21 --\n",
"│ │ └─Sequential: 3-6 172,032\n",
"│ └─Conv2d: 2-22 4,196,352\n",
"│ └─GELU: 2-23 --\n",
"│ └─BatchNorm2d: 2-24 4,096\n",
"├─Sequential: 1-10 --\n",
"│ └─Residual: 2-25 --\n",
"│ │ └─Sequential: 3-7 172,032\n",
"│ └─Conv2d: 2-26 4,196,352\n",
"│ └─GELU: 2-27 --\n",
"│ └─BatchNorm2d: 2-28 4,096\n",
"├─Sequential: 1-11 --\n",
"│ └─Residual: 2-29 --\n",
"│ │ └─Sequential: 3-8 172,032\n",
"│ └─Conv2d: 2-30 4,196,352\n",
"│ └─GELU: 2-31 --\n",
"│ └─BatchNorm2d: 2-32 4,096\n",
"├─AdaptiveAvgPool2d: 1-12 --\n",
"├─Flatten: 1-13 --\n",
"├─Linear: 1-14 2,049,000\n",
"=================================================================\n",
"Total params: 37,041,128\n",
"Trainable params: 37,041,128\n",
"Non-trainable params: 0\n",
"================================================================="
]
},
"metadata": {},
"execution_count": 41
}
]
}
]
}