Confusion Matrix And Test Accuracy For PyTorch Transfer Learning Tutorial


Answer:

This answer was given by ptrblck of the PyTorch community. Thanks a lot!



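import torch

nb_classes = 9

# Rows index the true class, columns index the predicted class
confusion_matrix = torch.zeros(nb_classes, nb_classes)
model_ft.eval()  # evaluation mode: disables dropout, uses running batch-norm statistics
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        # Index of the highest-scoring class for each sample in the batch
        _, preds = torch.max(outputs, 1)
        for t, p in zip(classes.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

print(confusion_matrix)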
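
The inner Python loop updates one matrix cell per sample, which can be slow on large validation sets. A possible vectorized alternative, offered as a sketch rather than as part of ptrblck's answer, encodes each (true, predicted) pair as a single index and counts the pairs with torch.bincount. The helper name confusion_matrix_from_labels and the toy tensors are made up for illustration.

```python
import torch

def confusion_matrix_from_labels(targets, preds, nb_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    targets = targets.reshape(-1).long()
    preds = preds.reshape(-1).long()
    # Map each (true, predicted) pair to one index in [0, nb_classes ** 2)
    flat_index = targets * nb_classes + preds
    counts = torch.bincount(flat_index, minlength=nb_classes * nb_classes)
    return counts.reshape(nb_classes, nb_classes)

# Toy labels and predictions (illustrative data only)
targets = torch.tensor([0, 0, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 1, 2, 0, 2])

cm = confusion_matrix_from_labels(targets, preds, nb_classes=3)
print(cm)  # cm.tolist() -> [[1, 1, 0], [0, 1, 0], [1, 0, 2]]
```

Inside the validation loop, the per-batch result could be added to a running integer matrix instead of filling cells one at a time.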


To get the per-class accuracy:



print(confusion_matrix.diag()/confusion_matrix.sum(1))
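
Because rows hold the true classes, dividing the diagonal by the row sums gives, for each class, the fraction of its samples that were predicted correctly (per-class recall). The sketch below works through this on a small made-up matrix and also shows one way to avoid dividing by zero when a class has no validation samples. Clamping the denominator is an assumption of this sketch, and it reports such classes as 0.

```python
import torch

# Made-up 3-class confusion matrix; class 2 has no validation samples
cm = torch.tensor([[5., 1., 0.],
                   [2., 6., 0.],
                   [0., 0., 0.]])

row_totals = cm.sum(1)  # number of true samples per class: 6, 8, 0
# clamp(min=1) avoids 0/0 = nan for classes without samples
per_class_acc = cm.diag() / row_totals.clamp(min=1)
# Overall accuracy: all correct predictions over all samples
overall_acc = cm.diag().sum() / cm.sum()

print(per_class_acc)  # approximately 0.8333, 0.7500, 0.0000
print(overall_acc)    # approximately 0.7857 (11 / 14)
```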


Here is a slightly modified, more direct approach using sklearn's confusion_matrix:



import torch
from sklearn.metrics import confusion_matrix

nb_classes = 9

# Initialize the prediction and label lists (tensors)
predlist = torch.zeros(0, dtype=torch.long, device='cpu')
lbllist = torch.zeros(0, dtype=torch.long, device='cpu')

with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)

        # Append batch prediction results
        predlist = torch.cat([predlist, preds.view(-1).cpu()])
        lbllist = torch.cat([lbllist, classes.view(-1).cpu()])

# Confusion matrix (rows: true labels, columns: predicted labels)
conf_mat = confusion_matrix(lbllist.numpy(), predlist.numpy())
print(conf_mat)

# Per-class accuracy, in percent
class_accuracy = 100 * conf_mat.diagonal() / conf_mat.sum(1)
print(class_accuracy)
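
One caveat with this approach: when no labels argument is given, sklearn's confusion_matrix infers the set of classes from the values it sees, so a class that never appears in the labels or predictions is dropped and the matrix can be smaller than nb_classes x nb_classes. The sketch below, using made-up arrays, shows how passing labels explicitly keeps the shape fixed.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

nb_classes = 4
# Made-up labels and predictions; class 3 never occurs
y_true = np.array([0, 1, 1, 2])
y_pred = np.array([0, 1, 2, 2])

# Classes inferred from the data: only 0, 1 and 2 are present
inferred = confusion_matrix(y_true, y_pred)
# Classes listed explicitly: the row and column for class 3 are kept (all zeros)
fixed = confusion_matrix(y_true, y_pred, labels=np.arange(nb_classes))

print(inferred.shape)  # (3, 3)
print(fixed.shape)     # (4, 4)
```

With a fixed shape, row i always refers to class i, which keeps per-class results aligned with the class names.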


Another simple way to get the overall accuracy is to use sklearn's accuracy_score.
Here's an example:



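import numpy as np
from sklearn.metrics import accuracy_score

# Detach from the autograd graph and move to the CPU before converting to NumPy
y_pred = y_pred.detach().cpu().numpy()
accuracy = accuracy_score(labels, np.argmax(y_pred, axis=1))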


First, convert the model output from a tensor to a NumPy array.
Here y_pred holds the raw predictions from your model (one score per class for each sample), and labels holds the ground-truth class indices.



np.argmax returns the index of the largest value along the given axis; with axis=1 this is the highest-scoring class for each sample. Because softmax preserves the ordering of the scores, that index is also the most probable class in multi-class classification. accuracy_score returns the fraction of predictions that match the labels, a value between 0 and 1; multiply it by 100 to get a percentage.
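
As a small worked example of these two steps, the sketch below applies np.argmax and accuracy_score to a made-up array of scores for four samples and three classes. The names logits and pred_classes are illustrative and not from the original answer.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up raw scores: 4 samples (rows) x 3 classes (columns)
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.3, 1.5],
                   [1.0, 1.2, 0.1],
                   [0.1, 2.2, 0.3]])
labels = np.array([0, 2, 0, 1])

# Highest-scoring class per row
pred_classes = np.argmax(logits, axis=1)  # [0, 2, 1, 1]

# Default: fraction of correct predictions
fraction = accuracy_score(labels, pred_classes)  # 0.75
# normalize=False: number of correct predictions instead of the fraction
count = accuracy_score(labels, pred_classes, normalize=False)  # 3

print(f"{100 * fraction:.1f}% correct ({int(count)} of {len(labels)})")
```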


