smart-interactive-display/Assets/StreamingAssets/MergeFace/AgeNet/train.py
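
"""Training script for the AgeNet face models: a gender classifier, an
age-range classifier and an age-estimation regressor (defined in the models
module), trained jointly on a face dataset such as UTKFace."""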

import os

import numpy as np
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from .utils import *
from .eval import *
from .models import *


def train(train_data_dir, weights, device='cpu', image_size=64, batch_size=128, num_epochs=100, steps_per_epoch=None,
          val_data_dir=None, validation_split=None, save_history=True):
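    """Train the gender, age-range and age-estimation models.

    Only one of val_data_dir (a separate validation folder) and validation_split
    (a fraction held out from the training data) may be given. The weights of
    all three models are saved to `weights` as a single state_dict.
    """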
    # val_data_dir and validation_split must not both be given at the same time
    assert not (val_data_dir is not None and validation_split is not None)
    if isinstance(device, str):
        if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
    # build the training dataloader
    train_data = get_dataloader(train_data_dir, image_size=image_size, batch_size=batch_size)
    # build the validation data: a separate directory, a split of the training data, or none
    if val_data_dir is not None:
        val_data = get_dataloader(val_data_dir, image_size=image_size, batch_size=batch_size)
    elif validation_split is not None:
        train_data, val_data = split_dataloader(train_data, validation_split)
    else:
        val_data = None
    if steps_per_epoch is None:
        steps_per_epoch = len(train_data)
    # number of batches in the training dataloader
    num_steps = len(train_data)
    iterator = iter(train_data)
    count_steps = 1
    # move the models to the device
    gender_model = GenderClassificationModel().to(device)
    age_range_model = AgeRangeModel().to(device)
    age_estimation_model = AgeEstimationModel().to(device)
    # loss functions
    AgeRangeLoss = nn.CrossEntropyLoss()
    GenderLoss = nn.BCELoss()
    AgeEstimationLoss = nn.L1Loss()
    # optimizers (one per model head)
    gen_optimizer = torch.optim.Adam(gender_model.parameters(), lr=1e-4)
    age_range_optimizer = torch.optim.Adam(age_range_model.parameters(), lr=5e-3)
    age_estimation_optimizer = torch.optim.Adam(age_estimation_model.parameters(), lr=1e-3)
    # schedulers: reduce the learning rate when the corresponding validation loss plateaus
    age_range_scheduler = ReduceLROnPlateau(age_range_optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    gender_scheduler = ReduceLROnPlateau(gen_optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    age_estimation_scheduler = ReduceLROnPlateau(age_estimation_optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    # history of per-epoch metrics
    history = {
        'train_gender_loss': [],
        'train_age_loss': [],
        'train_gender_acc': [],
        'train_age_acc': [],
        'val_gender_loss': [],
        'val_age_loss': [],
        'val_gender_acc': [],
        'val_age_acc': [],
    }
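    # Training loop: every step draws a batch, runs the gender and age-range heads
    # and back-propagates each loss through its own optimizer, then trains the
    # age-estimation head conditioned on the ground-truth age-range labels.
    # The iterator is recycled once the whole dataset has been consumed.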
    for epoch in range(1, num_epochs + 1):
        total_gender_loss = 0
        total_age_loss = 0
        age_accuracy = 0
        gender_accuracy = 0
        total_age_estimation_loss = 0
        gender_model.train()
        age_range_model.train()
        age_estimation_model.train()
        for step in tqdm(range(steps_per_epoch), desc=f'Epoch {epoch}/{num_epochs}: ', ncols=100):
            images, genders, age_labels, ages = next(iterator)
            batch_size = images.shape[0]
            images, genders, age_labels, ages = images.to(device), genders.to(device), age_labels.to(device), ages.to(device)
            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)
            age_loss = AgeRangeLoss(pred_age_labels, age_labels.long())
            gender_loss = GenderLoss(pred_genders, genders.float())
            total_age_loss += age_loss.item()
            total_gender_loss += gender_loss.item()
            gender_acc = torch.sum(torch.round(pred_genders) == genders) / batch_size
            age_acc = torch.sum(torch.argmax(pred_age_labels, dim=1) == age_labels) / batch_size
            age_accuracy += age_acc
            gender_accuracy += gender_acc
            age_range_optimizer.zero_grad()
            age_loss.backward()
            age_range_optimizer.step()
            gen_optimizer.zero_grad()
            gender_loss.backward()
            gen_optimizer.step()
            # age estimation loss (L1 between predicted and true ages)
            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            age_estimation_loss = AgeEstimationLoss(estimated_ages, ages)
            age_estimation_optimizer.zero_grad()
            age_estimation_loss.backward()
            age_estimation_optimizer.step()
            total_age_estimation_loss += age_estimation_loss.item()
            # restart the iterator once the whole dataset has been traversed
            if count_steps == num_steps:
                iterator = iter(train_data)
                count_steps = 0
            count_steps += 1
        train_age_loss = total_age_loss / steps_per_epoch
        train_gender_loss = total_gender_loss / steps_per_epoch
        train_age_accuracy = age_accuracy / steps_per_epoch
        train_gender_accuracy = gender_accuracy / steps_per_epoch
        train_age_estimation_loss = total_age_estimation_loss / steps_per_epoch
        history['train_age_loss'].append(float(train_age_loss))
        history['train_gender_loss'].append(float(train_gender_loss))
        history['train_age_acc'].append(float(train_age_accuracy))
        history['train_gender_acc'].append(float(train_gender_accuracy))
        print(f'train_age_loss: {train_age_loss: .2f}, train_gender_loss: {train_gender_loss: .3f}, train_age_accuracy: {train_age_accuracy: .2f}, train_gender_accuracy: {train_gender_accuracy: .2f}, train_age_estimation_loss: {train_age_estimation_loss: .3f}')
        if val_data:
            val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss = evaluate(gender_model, age_range_model, age_estimation_model, val_data, device=device)
            history['val_age_loss'].append(float(val_age_loss))
            history['val_gender_loss'].append(float(val_gender_loss))
            history['val_age_acc'].append(float(val_age_accuracy))
            history['val_gender_acc'].append(float(val_gender_accuracy))
            age_range_scheduler.step(np.round(val_age_loss, 3))
            gender_scheduler.step(np.round(val_gender_loss, 3))
            age_estimation_scheduler.step(np.round(val_age_estimation_loss, 3))
            print(f'val_age_loss: {val_age_loss: .2f}, val_gender_loss: {val_gender_loss: .3f}, val_age_accuracy: {val_age_accuracy: .2f}, val_gender_accuracy: {val_gender_accuracy: .2f}, val_age_estimation_loss: {val_age_estimation_loss: .3f}')
    if weights:
        # wrap the three models in a single nn.Module so all their weights
        # can be saved into one state_dict file
        class dummy_model(nn.Module):
            def __init__(self):
                super().__init__()
                self.gender_model = gender_model
                self.age_range_model = age_range_model
                self.age_estimation_model = age_estimation_model

            def forward(self, x):
                return

        model = dummy_model()
        torch.save(model.state_dict(), weights)
        print(f'Saved last weights to: {weights}')
    if save_history:
        visualize_history(history, save_history)


if __name__ == '__main__':
    batch_size = 128
    image_size = 64
    train_data_dir = '/kaggle/input/utkface-new/UTKFace'
    device = 'cuda'
    weights = os.path.join('weights', 'AgeGenderWeights.pt')
    epochs = 100
    train(train_data_dir, weights, device=device, num_epochs=epochs, steps_per_epoch=None,
          validation_split=0.2, image_size=image_size, batch_size=batch_size, save_history=True)