Skip to content

Commit 0a86c91

Browse files
Merge pull request #13 from renan-siqueira/feature/UpgradeArchitecture
Upgrade architecture
2 parents d482ecf + 51b36e2 commit 0a86c91

File tree

3 files changed

+22
-24
lines changed

3 files changed

+22
-24
lines changed

README.md

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
A simple implementation of an autoencoder using PyTorch.
44

5-
This project aims to provide a foundational structure to understand, train, and evaluate autoencoders on 64x64 images.
5+
This project aims to provide a basic framework for understanding, training and evaluating autoencoders on any image size.
66

77
## Features
88

@@ -25,25 +25,15 @@ This project aims to provide a foundational structure to understand, train, and
2525

2626
### Installation
2727

28-
1. Clone the repository:
29-
30-
```bash
31-
git clone https://github.com/renan-siqueira/autoencoder-project.git
32-
```
33-
2. Navigate to the project directory and install the required libraries:
34-
35-
```bash
36-
cd autoencoder-project
37-
pip install -r requirements.txt
38-
```
28+
1. Clone the repository.
29+
2. Navigate to the project directory and install the required libraries.
3930

4031
## Usage
4132

42-
1. Modify settings/settings.py to point to your training and validation dataset.
43-
2. To train the autoencoder, simply run:
33+
1. Modify the `settings/settings.py` file to point to your training and validation dataset.
34+
2. Modify the `json/params.json` file to reflext your training preferences.
35+
3. To train the autoencoder, simply run:
4436

4537
```bash
4638
python run.py
4739
```
48-
49-
By default, this will train a new model. If you wish to use a pre-trained model, modify the `main` method in `run.py`.

json/params.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
"encoding_dim": 16,
55
"num_epochs": 1000,
66
"learning_rate": 0.001,
7-
"ae_type": "ae",
7+
"ae_type": "conv",
88
"save_checkpoint": null
99
}

models/convolutional_autoencoder.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,30 @@ def __init__(self):
88
super(ConvolutionalAutoencoder, self).__init__()
99

1010
# Encoder
11-
self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
12-
self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
13-
self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
11+
self.enc0 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
12+
self.enc1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
13+
self.enc2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
14+
self.enc3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
15+
self.enc4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
1416
self.pool = nn.MaxPool2d(2, 2, return_indices=True)
1517

1618
# Decoder
17-
self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
18-
self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
19-
self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)
19+
self.dec0 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
20+
self.dec1 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
21+
self.dec2 = nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2)
22+
self.dec3 = nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2)
23+
self.dec4 = nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2)
2024

2125
def forward(self, x):
26+
x, _ = self.pool(F.relu(self.enc0(x)))
2227
x, _ = self.pool(F.relu(self.enc1(x)))
2328
x, _ = self.pool(F.relu(self.enc2(x)))
2429
x, _ = self.pool(F.relu(self.enc3(x)))
30+
x, _ = self.pool(F.relu(self.enc4(x)))
2531

32+
x = F.relu(self.dec0(x))
2633
x = F.relu(self.dec1(x))
2734
x = F.relu(self.dec2(x))
28-
x = torch.sigmoid(self.dec3(x))
35+
x = F.relu(self.dec3(x))
36+
x = torch.sigmoid(self.dec4(x))
2937
return x

0 commit comments

Comments
 (0)