Skip to content

Commit 3f11bc1

Browse files
committed
tuning the results to show more of a diagonal on confusion matrix.. Changed epochs, training rate, more of split to training data
1 parent d7cfb5d commit 3f11bc1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

intermediate_source/char_rnn_classification_tutorial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __getitem__(self, idx):
200200
#split but the torch.utils.data has more useful utilities. Here we specify a generator since we need to use the
201201
#same device as torch defaults to above.
202202

203-
train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2], generator=torch.Generator(device=device).manual_seed(1))
203+
train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))
204204

205205
print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")
206206

@@ -336,7 +336,7 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
336336
# We can now train a dataset with mini batches for a specified number of epochs
337337

338338
start = time.time()
339-
all_losses = train(rnn, train_set, n_epoch=13, learning_rate=0.2, report_every=1)
339+
all_losses = train(rnn, train_set, n_epoch=55, learning_rate=0.15, report_every=5)
340340
end = time.time()
341341
print(f"training took {end-start}s")
342342

0 commit comments

Comments
 (0)