smart-interactive-display/Assets/StreamingAssets/MergeFace/AgeNet/eval.py

67 lines
2.7 KiB
Python

from torch import nn
from tqdm import tqdm
import torch
def evaluate(gender_model, age_range_model, age_estimation_model, val_dataloader, device='cpu', verbose=0):
    """Evaluate the gender, age-range and age-estimation models on a validation set.

    Args:
        gender_model: Module mapping a batch of images to one value per image.
            Its output is fed to ``nn.BCELoss``, so it must already be a
            probability in [0, 1] (e.g. sigmoid output).
        age_range_model: Module mapping a batch of images to per-class scores
            of shape (batch, num_classes); fed to ``nn.CrossEntropyLoss``.
        age_estimation_model: Module called as ``model(images, age_labels)``
            returning one age estimate per image. It receives the
            ground-truth age-range labels (presumably to condition the
            estimate on the range — confirm against the model definition).
        val_dataloader: Iterable with ``len()`` yielding
            ``(images, genders, age_labels, ages)`` batches of tensors.
        device: ``torch.device`` or string. The strings ``'cuda'``, ``'gpu'``
            (alias for ``'cuda'``) and ``'cuda:N'`` select CUDA when it is
            available; every other string falls back to CPU.
        verbose: If 1, wrap the loader in a tqdm progress bar.

    Returns:
        ``(age_range_loss, gender_loss, age_range_accuracy, gender_accuracy,
        age_estimation_loss)``. The losses are Python floats and the
        accuracies are 0-dim tensors. All are per-sample averages over the
        whole validation set.

    Raises:
        ValueError: If ``val_dataloader`` yields no samples.

    Side effects: all three models are moved to ``device`` and left in eval mode.
    """
    # Resolve the target device. torch.device('gpu') is not a valid device,
    # so 'gpu' is mapped to 'cuda' instead of being passed through.
    if isinstance(device, str):
        wants_cuda = device in ('cuda', 'gpu') or device.startswith('cuda:')
        if wants_cuda and torch.cuda.is_available():
            device = torch.device('cuda' if device == 'gpu' else device)
        else:
            device = torch.device('cpu')

    # Loss functions
    age_range_criterion = nn.CrossEntropyLoss()
    gender_criterion = nn.BCELoss()
    age_estimation_criterion = nn.L1Loss()

    gender_model = gender_model.to(device)
    age_range_model = age_range_model.to(device)
    age_estimation_model = age_estimation_model.to(device)

    if verbose == 1:
        val_dataloader = tqdm(val_dataloader, desc = 'Evaluate: ', ncols = 100)

    total_samples = 0
    age_correct = 0
    gender_correct = 0
    total_age_loss = 0.0
    total_gender_loss = 0.0
    total_age_estimation_loss = 0.0

    with torch.no_grad():
        age_range_model.eval()
        gender_model.eval()
        age_estimation_model.eval()
        for images, genders, age_labels, ages in val_dataloader:
            batch_size = images.shape[0]
            images = images.to(device)
            genders = genders.to(device)
            age_labels = age_labels.to(device)
            ages = ages.to(device)

            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)

            # The criteria return batch means; re-weight by batch size so a
            # short final batch does not count as much as a full one.
            age_loss = age_range_criterion(pred_age_labels, age_labels.long())
            gender_loss = gender_criterion(pred_genders, genders.float())
            total_age_loss += age_loss.item() * batch_size
            total_gender_loss += gender_loss.item() * batch_size

            # Count correct predictions; divide by the sample total at the end.
            gender_correct += torch.sum(torch.round(pred_genders) == genders)
            age_correct += torch.sum(torch.argmax(pred_age_labels, dim=1) == age_labels)

            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            # L1Loss takes (input, target); cast the target so integer ages work.
            age_estimation_loss = age_estimation_criterion(estimated_ages, ages.float())
            total_age_estimation_loss += age_estimation_loss.item() * batch_size

            total_samples += batch_size

    if total_samples == 0:
        raise ValueError('val_dataloader yielded no samples to evaluate')

    val_age_loss = total_age_loss / total_samples
    val_gender_loss = total_gender_loss / total_samples
    val_age_accuracy = age_correct / total_samples
    val_gender_accuracy = gender_correct / total_samples
    val_age_estimation_loss = total_age_estimation_loss / total_samples
    return val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss