Skip to content

Commit 6b543fb

Browse files
committed
house cleaning
1 parent 57c2175 commit 6b543fb

File tree

5 files changed

+58
-47
lines changed

5 files changed

+58
-47
lines changed

datasets/compiler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,23 +167,20 @@ def run(self, max_samples:int=None):
167167

168168

169169
if __name__ == "__main__":
170+
ROOT = os.getenv("ROOT_DIRECTORY")
170171
SAMPLE_RATE = int(os.getenv("SAMPLE_RATE"))
171172
CHUNK_DURATION = int(os.getenv("CHUNK_DURATION"))
172173
HOP_LENGTH = int(os.getenv("HOP_LENGTH"))
173174
N_SAMPLES = int(SAMPLE_RATE * CHUNK_DURATION)
175+
TARGET_LUFFS = -float(os.getenv("TARGET_LUFFS", 23))
174176
N_CHANNELS = 1
175-
TARGET_LUFFS = -23
176-
ROOT = os.getenv("ROOT_DIRECTORY")
177177

178178
buffer_schemas = {
179179
"audio": (N_CHANNELS, N_SAMPLES),
180180
}
181181

182182
datasets_folders = [
183183
'gtzan/train',
184-
'mvsep_multisong_dataset/train',
185-
'mvsep_synth_dataset/train',
186-
'djmax_respectv_22050/train',
187184
]
188185

189186
datasets_config = [
@@ -198,7 +195,7 @@ def run(self, max_samples:int=None):
198195
for folder in datasets_folders
199196
]
200197

201-
compile_name = "mixture"
198+
compile_name = "gtzan"
202199
compile_dir = os.path.join(ROOT, "_compiled")
203200
os.makedirs(compile_dir, exist_ok=True)
204201

demo.ipynb

Lines changed: 45 additions & 32 deletions
Large diffs are not rendered by default.

notebooks/transfer.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

training/train.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ def train(self):
126126

127127

128128
if __name__ =="__main__":
129-
from datasets.configs import GTZANConfig
129+
from datasets.configs import BaseDatasetConfig
130130
from datasets.dataset import ChunkDataset
131131
from sincnet.model import SincNet
132132

133-
model = SincNet()
134-
dataset_config = GTZANConfig(id="gtzan")
135-
learning_rate = 1e-3
133+
model = SincNet(scale="mel")
134+
dataset_config = BaseDatasetConfig(id="gtzan")
135+
learning_rate = 1e-4
136136
train_config = TrainConfig(**{
137137
"batch_size": 8,
138138
"n_epoch": 500,
@@ -157,4 +157,10 @@ def train(self):
157157
val_set=datasets["test"],
158158
config=train_config
159159
)
160+
161+
try:
162+
model.load_pretrained_weights(weights_folder="pretrained", freeze=False)
163+
except:
164+
pass
165+
160166
trainer.train()

0 commit comments

Comments
 (0)