188 lines
7.6 KiB
Python
188 lines
7.6 KiB
Python
from tqdm import tqdm
|
|
from torch import nn
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
|
|
from .utils import *
|
|
from .eval import *
|
|
from .models import *
|
|
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
def _resolve_device(device):
    """Translate a user-supplied device spec into something ``.to()`` accepts.

    Strings 'cuda', 'gpu' and 'cuda:N' select CUDA when it is available;
    every other string, or any CUDA request without CUDA, falls back to CPU.
    Non-string values (e.g. an existing ``torch.device``) are returned as-is.
    """
    if not isinstance(device, str):
        return device
    wants_cuda = device in ('cuda', 'gpu') or device.startswith('cuda:')
    if wants_cuda and torch.cuda.is_available():
        # 'gpu' is not a valid torch device type, so map it onto 'cuda'.
        return torch.device('cuda' if device == 'gpu' else device)
    return torch.device('cpu')


def _save_weights(path, gender_model, age_range_model, age_estimation_model):
    """Save all three models into one state_dict file at *path*.

    Keys are prefixed with 'gender_model.', 'age_range_model.' and
    'age_estimation_model.'. The parent directory is created if missing.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # A ModuleDict with these names yields the same key prefixes as a
    # wrapper module holding the three models as attributes.
    bundle = nn.ModuleDict({
        'gender_model': gender_model,
        'age_range_model': age_range_model,
        'age_estimation_model': age_estimation_model,
    })
    torch.save(bundle.state_dict(), path)
    print('Saved successfully last weights to:', path)


def train(train_data_dir, weights, device='cpu', image_size=64, batch_size=128, num_epochs=100,
          steps_per_epoch=None, val_data_dir=None, validation_split=None, save_history=True):
    """Train the gender, age-range and age-estimation models side by side.

    The three models see the same batches. Each model has its own loss, its own
    Adam optimizer and its own ReduceLROnPlateau scheduler. The schedulers are
    stepped only when validation data is available.

    Args:
        train_data_dir: Directory passed to ``get_dataloader`` for training data.
        weights: Output path for the final weights; falsy to skip saving.
        device: 'cpu', 'cuda', 'gpu', 'cuda:N' or a ``torch.device``. CUDA
            requests fall back to CPU when CUDA is unavailable.
        image_size: Forwarded to ``get_dataloader``.
        batch_size: Forwarded to ``get_dataloader``.
        num_epochs: Number of training epochs.
        steps_per_epoch: Batches per epoch; defaults to ``len`` of the training
            loader. It may exceed that length, because the loader is restarted
            whenever it is exhausted.
        val_data_dir: Optional directory for a separate validation loader.
        validation_split: Optional value forwarded to ``split_dataloader`` to
            carve validation data out of the training loader (e.g. 0.2).
            Mutually exclusive with ``val_data_dir``.
        save_history: If truthy, the metric history is passed to
            ``visualize_history``.

    Raises:
        ValueError: If both ``val_data_dir`` and ``validation_split`` are given,
            if ``steps_per_epoch`` is < 1, or if the training loader is empty.
    """
    if val_data_dir is not None and validation_split is not None:
        raise ValueError('val_data_dir and validation_split are mutually exclusive; pass at most one')
    if steps_per_epoch is not None and steps_per_epoch < 1:
        raise ValueError(f'steps_per_epoch must be >= 1, got {steps_per_epoch}')

    device = _resolve_device(device)

    train_data = get_dataloader(train_data_dir, image_size=image_size, batch_size=batch_size)

    # Validation data comes from its own directory, a split of the training
    # data, or is absent.
    if val_data_dir is not None:
        val_data = get_dataloader(val_data_dir, image_size=image_size, batch_size=batch_size)
    elif validation_split is not None:
        train_data, val_data = split_dataloader(train_data, validation_split)
    else:
        val_data = None

    if steps_per_epoch is None:
        steps_per_epoch = len(train_data)
        if steps_per_epoch < 1:
            raise ValueError('training data loader is empty')

    gender_model = GenderClassificationModel().to(device)
    age_range_model = AgeRangeModel().to(device)
    age_estimation_model = AgeEstimationModel().to(device)

    age_range_criterion = nn.CrossEntropyLoss()
    # BCELoss expects probabilities, so gender_model is assumed to end in a sigmoid.
    gender_criterion = nn.BCELoss()
    age_estimation_criterion = nn.L1Loss()

    gender_optimizer = torch.optim.Adam(gender_model.parameters(), lr=1e-4)
    age_range_optimizer = torch.optim.Adam(age_range_model.parameters(), lr=5e-3)
    age_estimation_optimizer = torch.optim.Adam(age_estimation_model.parameters(), lr=1e-3)

    # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch
    # releases; drop it if the installed version rejects or warns about it.
    age_range_scheduler = ReduceLROnPlateau(age_range_optimizer, mode='min', factor=0.1, patience=3, verbose=1)
    gender_scheduler = ReduceLROnPlateau(gender_optimizer, mode='min', factor=0.1, patience=3, verbose=1)
    age_estimation_scheduler = ReduceLROnPlateau(age_estimation_optimizer, mode='min', factor=0.1, patience=3, verbose=1)

    history = {
        'train_gender_loss': [],
        'train_age_loss': [],
        'train_gender_acc': [],
        'train_age_acc': [],
        'val_gender_loss': [],
        'val_age_loss': [],
        'val_gender_acc': [],
        'val_age_acc': [],
    }

    # One iterator shared across epochs, so an epoch may stop mid-pass and the
    # next epoch continues where it left off.
    train_iter = iter(train_data)

    for epoch in range(1, num_epochs + 1):
        total_gender_loss = 0.0
        total_age_loss = 0.0
        total_age_estimation_loss = 0.0
        age_accuracy = 0.0
        gender_accuracy = 0.0

        gender_model.train()
        age_range_model.train()
        age_estimation_model.train()

        for _ in tqdm(range(steps_per_epoch), desc=f'Epoch {epoch}/{num_epochs}: ', ncols=100):
            # Restart the loader once it is exhausted (i.e. cycle over the dataset).
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_data)
                batch = next(train_iter)

            images, genders, age_labels, ages = batch
            images, genders, age_labels, ages = (
                images.to(device), genders.to(device), age_labels.to(device), ages.to(device)
            )
            n = images.shape[0]

            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)

            age_loss = age_range_criterion(pred_age_labels, age_labels.long())
            gender_loss = gender_criterion(pred_genders, genders.float())

            # Per-batch accuracy: gender probability rounded to 0/1, age range by argmax.
            gender_accuracy += (torch.round(pred_genders) == genders).sum().item() / n
            age_accuracy += (torch.argmax(pred_age_labels, dim=1) == age_labels).sum().item() / n

            age_range_optimizer.zero_grad()
            age_loss.backward()
            age_range_optimizer.step()

            gender_optimizer.zero_grad()
            gender_loss.backward()
            gender_optimizer.step()

            # The age regressor also receives the ground-truth age-range labels.
            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            age_estimation_loss = age_estimation_criterion(estimated_ages, ages.float())

            age_estimation_optimizer.zero_grad()
            age_estimation_loss.backward()
            age_estimation_optimizer.step()

            total_age_loss += age_loss.item()
            total_gender_loss += gender_loss.item()
            total_age_estimation_loss += age_estimation_loss.item()

        train_age_loss = total_age_loss / steps_per_epoch
        train_gender_loss = total_gender_loss / steps_per_epoch
        train_age_accuracy = age_accuracy / steps_per_epoch
        train_gender_accuracy = gender_accuracy / steps_per_epoch
        train_age_estimation_loss = total_age_estimation_loss / steps_per_epoch

        history['train_age_loss'].append(float(train_age_loss))
        history['train_gender_loss'].append(float(train_gender_loss))
        history['train_age_acc'].append(float(train_age_accuracy))
        history['train_gender_acc'].append(float(train_gender_accuracy))

        print(f'train_age_loss: {train_age_loss: .2f}, train_gender_loss: {train_gender_loss: .3f}, train_age_accuracy: {train_age_accuracy: .2f}, train_gender_accuracy: {train_gender_accuracy: .2f}, train_age_estimation_loss: {train_age_estimation_loss: .3f}')

        # A None loader (or, via __len__, a zero-length one) skips validation.
        if val_data:
            (val_age_loss, val_gender_loss, val_age_accuracy,
             val_gender_accuracy, val_age_estimation_loss) = evaluate(
                gender_model, age_range_model, age_estimation_model, val_data, device=device)

            history['val_age_loss'].append(float(val_age_loss))
            history['val_gender_loss'].append(float(val_gender_loss))
            history['val_age_acc'].append(float(val_age_accuracy))
            history['val_gender_acc'].append(float(val_gender_accuracy))

            # Losses are rounded to 3 decimals so tiny fluctuations do not
            # count as improvement for the plateau schedulers.
            age_range_scheduler.step(np.round(val_age_loss, 3))
            gender_scheduler.step(np.round(val_gender_loss, 3))
            age_estimation_scheduler.step(np.round(val_age_estimation_loss, 3))

            print(f'val_age_loss: {val_age_loss: .2f}, val_gender_loss: {val_gender_loss: .3f}, val_age_accuracy: {val_age_accuracy: .2f}, val_gender_accuracy: {val_gender_accuracy: .2f}, val_age_estimation_loss: {val_age_estimation_loss : .3f}')

    if weights:
        _save_weights(weights, gender_model, age_range_model, age_estimation_model)

    if save_history:
        visualize_history(history, save_history)
|
|
|
|
if __name__ == '__main__':
    # Example run: UTKFace on Kaggle with 20% of the training data held out for validation.
    batch_size = 128
    image_size = 64
    train_data_dir = '/kaggle/input/utkface-new/UTKFace'
    device = 'cuda'
    epochs = 100

    # The previous literal 'weights\AgeGenderWeights.pt' used an invalid escape
    # sequence. On POSIX it also produced a file literally named with a backslash
    # instead of a file inside the weights/ directory. os.path.join is portable,
    # and the directory is created up front so torch.save cannot fail on it.
    weights_dir = 'weights'
    os.makedirs(weights_dir, exist_ok=True)
    weights = os.path.join(weights_dir, 'AgeGenderWeights.pt')

    train(train_data_dir, weights, device=device, num_epochs=epochs, steps_per_epoch=None,
          validation_split=0.2, image_size=image_size, batch_size=batch_size, save_history=True)
|
|
|