Skip to content

Commit 3fc85c9

Browse files
Add notebook on ConvMixer implementation (#28)
* Add notebook on ConvMixer implementation * Change the name of the notebook and the CSV name * Change the CSV name to lowercase * Temporary change to force the rename * Lowercase file * Only lowercase * CSV conflicts resolved
1 parent 933a743 commit 3fc85c9

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

automation/notebooks-table-data.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ Multi-Head Attention, modules/multihead-self-attention.ipynb,,https://arxiv.org/
1616
ResNet,architectures/resnet.ipynb,,https://arxiv.org/abs/1512.03385
1717
Convolutional Block Attention, modules/convolutional-block-attention.ipynb,,https://arxiv.org/abs/1807.06521
1818
Transformer, architectures/transformer.ipynb,,https://arxiv.org/abs/1706.03762
19+
ConvMixer,architectures/convmixer.ipynb,,https://arxiv.org/abs/2201.09792
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "IJveajFZvdXK"
7+
},
8+
"source": [
9+
"## ConvMixer"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {
16+
"id": "q5AuzFB2tA12"
17+
},
18+
"outputs": [],
19+
"source": [
20+
"%%capture\n",
21+
"#@title **Install required packages**\n",
22+
"\n",
23+
"%pip install torchinfo==1.8.0"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {
30+
"id": "118P89a1osHb"
31+
},
32+
"outputs": [],
33+
"source": [
34+
"#@title **Importing libraries**\n",
35+
"import torch # 2.5.1+cu121\n",
36+
"import torch.nn as nn\n",
37+
"import torchinfo #1.8.0"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"metadata": {
44+
"colab": {
45+
"base_uri": "https://localhost:8080/"
46+
},
47+
"id": "QchOSIbro1zf",
48+
"outputId": "94003f5a-622f-4f54-e916-26fbbff86c17"
49+
},
50+
"outputs": [
51+
{
52+
"name": "stdout",
53+
"output_type": "stream",
54+
"text": [
55+
"torch version: 2.5.1+cu121\n",
56+
"torchinfo version: 1.8.0\n"
57+
]
58+
}
59+
],
60+
"source": [
61+
"# Note: not every dependency exposes a __version__ attribute.\n",
62+
"print(f\"torch version: {torch.__version__}\")\n",
63+
"print(f\"torchinfo version: {torchinfo.__version__}\")"
64+
]
65+
},
66+
{
67+
"cell_type": "markdown",
68+
"metadata": {
69+
"id": "hrG5QYB4pRMu"
70+
},
71+
"source": [
72+
"**ConvMixer architecture code**\n"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {
79+
"id": "r2d2e2P0pdan"
80+
},
81+
"outputs": [],
82+
"source": [
83+
"class Residual(nn.Module):\n",
84+
" def __init__(self, fn):\n",
85+
" super().__init__()\n",
86+
" self.fn = fn\n",
87+
"\n",
88+
" def forward(self, x):\n",
89+
" return self.fn(x) + x\n",
90+
"\n",
91+
"def ConvMixer(dim, depth, kernel_size = 9, patch_size = 7, n_classes = 1000):\n",
92+
" return nn.Sequential(\n",
93+
" nn.Conv2d(3, dim, kernel_size = patch_size, stride = patch_size),\n",
94+
" nn.GELU(),\n",
95+
" nn.BatchNorm2d(dim),\n",
96+
" *[nn.Sequential(\n",
97+
" Residual(nn.Sequential(\n",
98+
" nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=\"same\"),\n",
99+
" nn.GELU(),\n",
100+
" nn.BatchNorm2d(dim)\n",
101+
" )),\n",
102+
" nn.Conv2d(dim,dim, kernel_size = 1),\n",
103+
" nn.GELU(),\n",
104+
" nn.BatchNorm2d(dim)\n",
105+
" )for i in range(depth)],\n",
106+
" nn.AdaptiveAvgPool2d((1,1)),\n",
107+
" nn.Flatten(),\n",
108+
" nn.Linear(dim, n_classes)\n",
109+
" )"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"metadata": {
116+
"colab": {
117+
"base_uri": "https://localhost:8080/"
118+
},
119+
"id": "9YM5zfNoslTx",
120+
"outputId": "1b5a1295-d3ac-49c9-d61d-ae986021ad00"
121+
},
122+
"outputs": [
123+
{
124+
"data": {
125+
"text/plain": [
126+
"=================================================================\n",
127+
"Layer (type:depth-idx) Param #\n",
128+
"=================================================================\n",
129+
"Sequential --\n",
130+
"├─Conv2d: 1-1 8,192\n",
131+
"├─GELU: 1-2 --\n",
132+
"├─BatchNorm2d: 1-3 4,096\n",
133+
"├─Sequential: 1-4 --\n",
134+
"│ └─Residual: 2-1 --\n",
135+
"│ │ └─Sequential: 3-1 172,032\n",
136+
"│ └─Conv2d: 2-2 4,196,352\n",
137+
"│ └─GELU: 2-3 --\n",
138+
"│ └─BatchNorm2d: 2-4 4,096\n",
139+
"├─Sequential: 1-5 --\n",
140+
"│ └─Residual: 2-5 --\n",
141+
"│ │ └─Sequential: 3-2 172,032\n",
142+
"│ └─Conv2d: 2-6 4,196,352\n",
143+
"│ └─GELU: 2-7 --\n",
144+
"│ └─BatchNorm2d: 2-8 4,096\n",
145+
"├─Sequential: 1-6 --\n",
146+
"│ └─Residual: 2-9 --\n",
147+
"│ │ └─Sequential: 3-3 172,032\n",
148+
"│ └─Conv2d: 2-10 4,196,352\n",
149+
"│ └─GELU: 2-11 --\n",
150+
"│ └─BatchNorm2d: 2-12 4,096\n",
151+
"├─Sequential: 1-7 --\n",
152+
"│ └─Residual: 2-13 --\n",
153+
"│ │ └─Sequential: 3-4 172,032\n",
154+
"│ └─Conv2d: 2-14 4,196,352\n",
155+
"│ └─GELU: 2-15 --\n",
156+
"│ └─BatchNorm2d: 2-16 4,096\n",
157+
"├─Sequential: 1-8 --\n",
158+
"│ └─Residual: 2-17 --\n",
159+
"│ │ └─Sequential: 3-5 172,032\n",
160+
"│ └─Conv2d: 2-18 4,196,352\n",
161+
"│ └─GELU: 2-19 --\n",
162+
"│ └─BatchNorm2d: 2-20 4,096\n",
163+
"├─Sequential: 1-9 --\n",
164+
"│ └─Residual: 2-21 --\n",
165+
"│ │ └─Sequential: 3-6 172,032\n",
166+
"│ └─Conv2d: 2-22 4,196,352\n",
167+
"│ └─GELU: 2-23 --\n",
168+
"│ └─BatchNorm2d: 2-24 4,096\n",
169+
"├─Sequential: 1-10 --\n",
170+
"│ └─Residual: 2-25 --\n",
171+
"│ │ └─Sequential: 3-7 172,032\n",
172+
"│ └─Conv2d: 2-26 4,196,352\n",
173+
"│ └─GELU: 2-27 --\n",
174+
"│ └─BatchNorm2d: 2-28 4,096\n",
175+
"├─Sequential: 1-11 --\n",
176+
"│ └─Residual: 2-29 --\n",
177+
"│ │ └─Sequential: 3-8 172,032\n",
178+
"│ └─Conv2d: 2-30 4,196,352\n",
179+
"│ └─GELU: 2-31 --\n",
180+
"│ └─BatchNorm2d: 2-32 4,096\n",
181+
"├─AdaptiveAvgPool2d: 1-12 --\n",
182+
"├─Flatten: 1-13 --\n",
183+
"├─Linear: 1-14 2,049,000\n",
184+
"=================================================================\n",
185+
"Total params: 37,041,128\n",
186+
"Trainable params: 37,041,128\n",
187+
"Non-trainable params: 0\n",
188+
"================================================================="
189+
]
190+
},
191+
"execution_count": 41,
192+
"metadata": {},
193+
"output_type": "execute_result"
194+
}
195+
],
196+
"source": [
197+
"model = ConvMixer(2048, 8, kernel_size=9, patch_size=1, n_classes=1000)\n",
198+
"torchinfo.summary(model)"
199+
]
200+
}
201+
],
202+
"metadata": {
203+
"colab": {
204+
"provenance": []
205+
},
206+
"kernelspec": {
207+
"display_name": "Python 3",
208+
"name": "python3"
209+
},
210+
"language_info": {
211+
"name": "python"
212+
}
213+
},
214+
"nbformat": 4,
215+
"nbformat_minor": 0
216+
}

0 commit comments

Comments
 (0)