-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
23 lines (17 loc) · 824 Bytes
/
utils.py
File metadata and controls
23 lines (17 loc) · 824 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
def getUncertainty(V: torch.Tensor, classNum: int) -> torch.Tensor:
    """Return a per-sample uncertainty score from class-probability vectors.

    For classification, each row of ``V`` is the probability vector over the
    categories for one data point (uncertainty from OUI, task model).  The
    score is intended for unlabeled data only.

    The score is ``1 - (min_var / var) * max_prob`` where, per row:

    * ``max_prob`` is the largest entry,
    * ``var`` is the population variance of the row,
    * ``min_var`` is the variance of the row that keeps ``max_prob`` and
      spreads the remaining ``1 - max_prob`` evenly over the other
      ``classNum - 1`` classes.

    Args:
        V: Tensor of shape ``(dataNum, classNum)``.  The formula assumes
            each row sums to 1; this is not checked here.
        classNum: Number of classes (the size of ``V`` along dim 1).

    Returns:
        Tensor of shape ``(dataNum,)`` holding the uncertainty of every row.
    """
    max_prob = torch.amax(V, dim=1)  # largest probability of each data point
    var = torch.var(V, dim=1, unbiased=False)  # variance of each data point

    C = classNum
    uniform = 1 / C
    peak_term = torch.pow(uniform - max_prob, 2)
    if C > 1:
        rest_term = (C - 1) * torch.pow(uniform - (1 - max_prob) / (C - 1), 2)
    else:
        # Single class: there are no "other" classes to spread mass over,
        # and dividing by (C - 1) would produce NaN.
        rest_term = torch.zeros_like(max_prob)
    min_var = uniform * (peak_term + rest_term)

    # A perfectly uniform row has var == 0 and min_var == 0, so the plain
    # division would be 0/0 = NaN.  In that case the row already *is* its own
    # minimum-variance configuration, so the ratio is taken to be 1.
    # The denominator is patched before dividing so that no NaN/inf is ever
    # produced, even in the branch torch.where discards.
    has_spread = var > 0
    safe_var = torch.where(has_spread, var, torch.ones_like(var))
    ratio = torch.where(has_spread, min_var / safe_var, torch.ones_like(var))

    return 1 - ratio * max_prob