-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathUtils.py
More file actions
29 lines (24 loc) · 971 Bytes
/
Utils.py
File metadata and controls
29 lines (24 loc) · 971 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F
# Accuracy function
def accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose top-scoring index equals the label.

    Args:
        outputs: Tensor of scores; the predicted index for each row is the
            position of the maximum along dimension 1.
        labels: Tensor of target indices, compared element-wise against the
            predicted indices.

    Returns:
        A new 0-dim ``torch.Tensor`` holding ``matches / number_of_predictions``.
        It is built from a plain Python float, so it carries no autograd
        history from ``outputs``.
    """
    # torch.max with ``dim`` yields (values, indices); only the indices matter.
    predicted = torch.max(outputs, dim=1)[1]
    hits = torch.sum(predicted == labels).item()
    fraction_correct = hits / len(predicted)
    return torch.tensor(fraction_correct)
# Evaluate function
@torch.no_grad()
def evaluate(model, val_loader, device):
    """Compute the mean cross-entropy loss and accuracy of ``model`` over ``val_loader``.

    The model is switched to eval mode (and left in that mode), and gradient
    tracking is disabled for the whole call by the ``torch.no_grad`` decorator.

    Both metrics are averaged per *sample*: each batch contributes in
    proportion to its size, so a smaller final batch does not skew the result
    the way an unweighted mean of per-batch means would.

    Args:
        model: Module called as ``model(images)``; its output is treated as
            per-class scores with classes along dimension 1.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        device: Target passed to ``Tensor.to`` for every batch.

    Returns:
        dict with two Python floats: ``"val_loss"`` (mean cross-entropy) and
        ``"val_acc"`` (fraction of predictions equal to their label).

    Raises:
        RuntimeError: If ``val_loader`` yields no samples at all.
    """
    model.eval()
    loss_total = 0.0
    correct_total = 0
    sample_total = 0
    for images, labels in val_loader:
        batch_size = len(labels)
        if batch_size == 0:
            # An empty batch has no defined mean loss; skip it instead of
            # poisoning the running totals with NaN.
            continue
        images = images.to(device)
        labels = labels.to(device)
        out = model(images)
        # cross_entropy returns the batch mean; scale it back up by the batch
        # size so the final division yields a true per-sample average.
        loss_total += F.cross_entropy(out, labels).item() * batch_size
        _, preds = torch.max(out, dim=1)
        correct_total += torch.sum(preds == labels).item()
        sample_total += batch_size
    if sample_total == 0:
        # Same exception type the old torch.stack([]) path produced, but with
        # a message that names the actual problem.
        raise RuntimeError("evaluate() received no samples from val_loader")
    return {
        "val_loss": loss_total / sample_total,
        "val_acc": correct_total / sample_total,
    }