94 lines
4.1 KiB
Python
94 lines
4.1 KiB
Python
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
|
|
from .dataset import UTKFaceDataset
|
|
from torchvision import transforms as T
|
|
from matplotlib import pyplot as plt
|
|
import os
|
|
|
|
def get_dataloader(root_dir, image_size = 64, batch_size = 128, shuffle = True, num_workers = 1):
    """Build a DataLoader over ``UTKFaceDataset`` with augmentation applied.

    Args:
        root_dir: Passed unchanged to ``UTKFaceDataset`` as its first argument.
        image_size: Target size handed to ``T.Resize``. A single int ``s`` is
            expanded to ``(s, s)``; any other value is passed through as is.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``.
        num_workers: Forwarded to ``DataLoader(num_workers=...)``.

    Returns:
        A ``DataLoader`` created with ``drop_last=True``, so a trailing
        incomplete batch is discarded.
    """
    # A bare int means "square image"; Resize gets an explicit 2-tuple so the
    # output is exactly (s, s) rather than an aspect-preserving resize.
    target_size = (image_size, image_size) if isinstance(image_size, int) else image_size

    # Pipeline order matters: the PIL-level ops run first, then the tensor
    # conversion, then the per-channel normalization.
    pipeline = [
        T.Resize(target_size),
        T.RandomHorizontalFlip(0.2),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
    ]
    faces = UTKFaceDataset(root_dir, transform = T.Compose(pipeline))

    return DataLoader(
        faces,
        batch_size = batch_size,
        shuffle = shuffle,
        num_workers = num_workers,
        drop_last = True,
    )
|
|
|
|
def split_dataloader(train_data, validation_split = 0.2):
    """Split one DataLoader into a training and a validation DataLoader.

    The split is positional: the first ``int((1 - validation_split) * N)``
    dataset indices go to the training loader and the remaining indices go to
    the validation loader. Each loader visits its own indices in random order
    through a ``SubsetRandomSampler``. Both loaders wrap the same dataset
    object as ``train_data`` and reuse its ``batch_size`` and ``num_workers``.

    Args:
        train_data: The ``DataLoader`` to split. Only its ``dataset``,
            ``batch_size`` and ``num_workers`` attributes are read.
        validation_split: Fraction of the dataset reserved for validation;
            must lie in ``[0, 1]``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Both loaders are created with
        ``drop_last=True``, so a trailing incomplete batch is discarded.

    Raises:
        ValueError: If ``validation_split`` is outside ``[0, 1]``.
    """
    # Out-of-range fractions would yield a negative or oversized train_size,
    # and the slices below would then silently produce a wrong split.
    if not 0 <= validation_split <= 1:
        raise ValueError(
            f"validation_split must be between 0 and 1, got {validation_split!r}"
        )

    # Reuse the source loader's dataset and settings.
    dataset = train_data.dataset
    batch_size = train_data.batch_size
    num_workers = train_data.num_workers

    # Number of samples assigned to the training portion.
    num_samples = len(dataset)
    train_size = int((1 - validation_split) * num_samples)

    # Leading indices -> train, trailing indices -> validation.
    indices = list(range(num_samples))
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]

    # Samplers restrict each loader to its own subset of the dataset.
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    # Build the new loaders on top of the samplers.
    train_loader = DataLoader(dataset, batch_size = batch_size, sampler = train_sampler, num_workers = num_workers, drop_last = True)
    val_loader = DataLoader(dataset, batch_size = batch_size, sampler = val_sampler, num_workers = num_workers, drop_last = True)

    return train_loader, val_loader
|
|
|
|
def _next_experiment_dir(parent):
    """Return ``<parent>/exp<N>`` where N is one above the highest existing run."""
    # os.listdir returns entries in arbitrary order, so take the maximum
    # instead of the last element. Entries that are not "exp<digits>" are
    # skipped so stray files cannot crash the int() conversion.
    run_numbers = [
        int(name[3:])
        for name in os.listdir(parent)
        if name.startswith("exp") and name[3:].isdecimal()
    ]
    return os.path.join(parent, "exp" + str(max(run_numbers, default=0) + 1))


def visualize_history(history, save_history=True):
    """Plot accuracy and loss curves from a training history.

    Draws one 20x8 figure with three side-by-side panels: age and gender
    accuracy, age loss, and gender loss. Each curve is plotted against epoch
    numbers starting at 1.

    Args:
        history: Mapping that must contain the keys ``train_age_acc``,
            ``val_age_acc``, ``train_gender_acc``, ``val_gender_acc``,
            ``train_age_loss``, ``val_age_loss``, ``train_gender_loss`` and
            ``val_gender_loss``, each holding a sequence of values.
        save_history: When true, the figure is saved as ``results.png`` in a
            newly created ``runs/train/exp<N>`` directory, where N is one
            above the highest existing run number (1 if there is none).

    Raises:
        KeyError: If ``history`` is a dict and one of the keys above is missing.
    """
    # (subplot position, title, y-axis label, [(history key, line colour), ...])
    panels = (
        (131, 'Accuracy', 'Gender And Age Prediction Accuracy', (
            ('train_age_acc', 'r'),
            ('val_age_acc', 'g'),
            ('train_gender_acc', 'b'),
            ('val_gender_acc', 'y'),
        )),
        (132, 'Age Loss', 'Age Loss', (
            ('train_age_loss', 'r'),
            ('val_age_loss', 'g'),
        )),
        (133, 'Gender Loss', 'Gender Loss', (
            ('train_gender_loss', 'b'),
            ('val_gender_loss', 'y'),
        )),
    )

    plt.figure(figsize = (20,8))
    for position, title, y_label, curves in panels:
        plt.subplot(position)
        for key, colour in curves:
            values = history[key]
            plt.plot(range(1, len(values) + 1), values, label = key, c = colour)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    if save_history:
        train_root = os.path.join("runs", "train")
        # Creates "runs" and "runs/train" as needed; no error if they exist.
        os.makedirs(train_root, exist_ok=True)

        last_exp = _next_experiment_dir(train_root)
        os.mkdir(last_exp)
        plt.savefig(os.path.join(last_exp, "results.png"))
|