Confusion Matrix And Test Accuracy For PyTorch Transfer Learning Tutorial
Answer:
This answer was given by ptrblck of the PyTorch community. Thanks a lot!
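import torch

# model_ft, dataloaders and device are defined in the transfer learning tutorial
nb_classes = 9

# Rows index the true class, columns the predicted class
confusion_matrix = torch.zeros(nb_classes, nb_classes)
model_ft.eval()  # disable dropout and use running batch-norm statistics
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)
        # Increment the cell for each (true label, prediction) pair
        for t, p in zip(classes.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

print(confusion_matrix)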
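The per-sample Python loop above is easy to read but can be slow on large validation sets. A minimal vectorized sketch, assuming integer class labels in the range 0..nb_classes-1, encodes each (true, predicted) pair as one flat index and counts the indices with torch.bincount. The function name confusion_matrix_bincount and the toy tensors are hypothetical, for illustration only.

```python
import torch


def confusion_matrix_bincount(targets: torch.Tensor, preds: torch.Tensor, nb_classes: int) -> torch.Tensor:
    """Hypothetical helper: build an nb_classes x nb_classes confusion matrix without a Python loop.

    Rows index the true class, columns the predicted class, matching the
    loop-based version (confusion_matrix[t, p] += 1).
    """
    targets = targets.view(-1).long()
    preds = preds.view(-1).long()
    # Encode each (true, predicted) pair as a single flat index t * nb_classes + p
    flat = targets * nb_classes + preds
    # Count each flat index; minlength guarantees exactly nb_classes**2 bins
    counts = torch.bincount(flat, minlength=nb_classes * nb_classes)
    return counts.reshape(nb_classes, nb_classes)


# Toy labels and predictions standing in for one validation pass
targets = torch.tensor([0, 1, 2, 2, 1, 0])
preds = torch.tensor([0, 2, 2, 2, 1, 1])
cm = confusion_matrix_bincount(targets, preds, nb_classes=3)
print(cm)
```

Inside the evaluation loop, the inner for-loop could then be replaced by confusion_matrix += confusion_matrix_bincount(classes.cpu(), preds.cpu(), nb_classes); the .cpu() calls keep both operands on the same device as the CPU-resident matrix.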
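To get the per-class accuracy (the fraction of each true class that was predicted correctly), divide the diagonal by the row sums: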
print(confusion_matrix.diag()/confusion_matrix.sum(1))
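Dividing the diagonal by the row sums gives each class's recall. The same matrix also yields the overall accuracy (trace over total count). A small sketch on a made-up 3-class matrix, including a class with no samples to show the 0/0 case:

```python
import torch

# Made-up 3-class confusion matrix: rows = true class, columns = predicted class.
# Class 2 has no samples in this split, so its row is all zeros.
confusion_matrix = torch.tensor([[5., 1., 0.],
                                 [2., 6., 2.],
                                 [0., 0., 0.]])

correct = confusion_matrix.diag()    # correctly classified samples per class
support = confusion_matrix.sum(1)    # true samples per class (row sums)

# Per-class accuracy (recall); a class with no samples gives 0/0 = NaN
per_class_acc = correct / support
print(per_class_acc)

# Overall accuracy: trace divided by the total number of samples
overall_acc = correct.sum() / confusion_matrix.sum()
print(overall_acc)

# Mean of the per-class accuracies, ignoring classes with no samples
valid = ~torch.isnan(per_class_acc)
mean_class_acc = per_class_acc[valid].mean()
print(mean_class_acc)
```

Overall accuracy weights every sample equally, so frequent classes dominate it; the mean of per-class accuracies weights every class equally instead.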
Here is a slightly modified, more direct approach using sklearn's confusion_matrix:
import torch
from sklearn.metrics import confusion_matrix

nb_classes = 9

# Initialize the prediction and label lists (as empty tensors)
predlist = torch.zeros(0, dtype=torch.long, device='cpu')
lbllist = torch.zeros(0, dtype=torch.long, device='cpu')

model_ft.eval()  # disable dropout and use running batch-norm statistics
with torch.no_grad():
    for i, (inputs, classes) in enumerate(dataloaders['val']):
        inputs = inputs.to(device)
        classes = classes.to(device)
        outputs = model_ft(inputs)
        _, preds = torch.max(outputs, 1)

        # Append batch prediction results
        predlist = torch.cat([predlist, preds.view(-1).cpu()])
        lbllist = torch.cat([lbllist, classes.view(-1).cpu()])

# Confusion matrix; labels= keeps it nb_classes x nb_classes even if a class is missing
conf_mat = confusion_matrix(lbllist.numpy(), predlist.numpy(), labels=list(range(nb_classes)))
print(conf_mat)

# Per-class accuracy (in percent)
class_accuracy = 100 * conf_mat.diagonal() / conf_mat.sum(1)
print(class_accuracy)
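Calling torch.cat inside the loop copies the whole accumulated tensor on every batch. A common alternative, sketched here, appends each batch to a Python list and concatenates once at the end. The tiny nn.Linear model and the random val_batches are hypothetical stand-ins for model_ft and dataloaders['val'] so the snippet runs on its own.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix

nb_classes = 3
torch.manual_seed(0)

# Hypothetical stand-ins for model_ft and dataloaders['val'] from the tutorial
model = nn.Linear(4, nb_classes)
val_batches = [(torch.randn(8, 4), torch.randint(0, nb_classes, (8,))) for _ in range(5)]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

pred_chunks, label_chunks = [], []
with torch.no_grad():
    for inputs, classes in val_batches:
        preds = model(inputs.to(device)).argmax(dim=1)
        # Keep each batch in a Python list; one torch.cat at the end avoids
        # re-copying the growing tensor on every iteration
        pred_chunks.append(preds.cpu())
        label_chunks.append(classes.cpu())

predlist = torch.cat(pred_chunks)
lbllist = torch.cat(label_chunks)

conf_mat = confusion_matrix(lbllist.numpy(), predlist.numpy(), labels=list(range(nb_classes)))
print(conf_mat)

# Overall accuracy = correctly classified (diagonal) / all samples
overall_accuracy = np.trace(conf_mat) / conf_mat.sum()
print(overall_accuracy)
```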
Another simple way to get accuracy is to use sklearn's "accuracy_score".
Here's an example:
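import numpy as np
from sklearn.metrics import accuracy_score
y_pred = y_pred.detach().cpu().numpy()  # detach from autograd and move to CPU before converting
labels = labels.cpu().numpy()  # only needed if labels is a tensor
accuracy = accuracy_score(labels, np.argmax(y_pred, axis=1))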
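First, pull the data out of the tensor: detach it from the autograd graph and move it to the CPU (with the old Variable API this was y_pred.data).
"y_pred" holds your model's outputs (one row of class scores per sample), and "labels" holds the ground-truth class indices.
np.argmax(..., axis=1) returns, for each row, the index of the largest value. We want that index because it corresponds to the highest-probability class when using softmax for multi-class classification; since softmax preserves ordering, taking argmax of the raw logits gives the same class. accuracy_score then returns the fraction of samples whose predicted class matches the label, a value between 0 and 1 (multiply by 100 for a percentage).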
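To sanity-check what accuracy_score returns, here is a tiny example with made-up logits for four samples. It also shows the count variant (normalize=False) and the equivalent computation in plain PyTorch without sklearn.

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score

# Made-up logits for 4 samples and 3 classes (raw model outputs, before softmax)
y_pred = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 1.5, 0.2],
                       [0.0, 0.2, 3.0],
                       [1.0, 2.0, 0.5]])
labels = torch.tensor([0, 1, 2, 0])

# Index of the largest score in each row: [0, 1, 2, 1]
pred_classes = np.argmax(y_pred.detach().cpu().numpy(), axis=1)

acc = accuracy_score(labels.numpy(), pred_classes)
print(acc)  # 0.75 (3 of 4 correct)

# normalize=False returns the number of correct predictions instead of the fraction
n_correct = accuracy_score(labels.numpy(), pred_classes, normalize=False)
print(n_correct)

# The same accuracy computed in plain PyTorch, without sklearn
torch_acc = (y_pred.argmax(dim=1) == labels).float().mean().item()
print(torch_acc)
```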