66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
from torch.utils.data import Dataset, DataLoader
|
|
import os
|
|
from PIL import Image
|
|
from torchvision import transforms as T
|
|
|
|
class UTKFaceDataset(Dataset):
|
|
def __init__(self, root_dir, transform = None):
|
|
super().__init__()
|
|
|
|
self.root_dir = root_dir
|
|
self.transform = transform
|
|
self.filename_list = os.listdir(root_dir)
|
|
|
|
def __len__(self):
|
|
return len(self.filename_list)
|
|
|
|
def age_to_class(self, age):
|
|
range_ages = [0, 4, 9, 15, 25, 35, 45, 60, 75]
|
|
if age > max(range_ages):
|
|
return len(range_ages) - 1
|
|
for i in range(len(range_ages) - 1):
|
|
if range_ages[i] <= age <= range_ages[i + 1]:
|
|
return i
|
|
|
|
def __getitem__(self, idx):
|
|
filename = self.filename_list[idx]
|
|
|
|
info = filename.split("_")
|
|
age = int(info[0])
|
|
age_label = self.age_to_class(age)
|
|
gender = int(info[1])
|
|
|
|
filename = os.path.join(self.root_dir, filename)
|
|
|
|
image = Image.open(filename)
|
|
|
|
if self.transform:
|
|
image = self.transform(image)
|
|
|
|
return image, gender, age_label, age
|
|
|
|
if __name__ == '__main__':
|
|
|
|
image_size = (64,64)
|
|
root_dir = '/kaggle/input/utkface-new/UTKFace'
|
|
batch_size = 512
|
|
num_workers = os.cpu_count()
|
|
|
|
train_transform = T.Compose(
|
|
[
|
|
T.Resize(image_size),
|
|
T.RandomHorizontalFlip(0.5),
|
|
T.RandomRotation(10),
|
|
T.ToTensor(),
|
|
T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
|
|
]
|
|
)
|
|
|
|
train_dataset = UTKFaceDataset(root_dir, transform = train_transform)
|
|
trainloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True)
|
|
|
|
for images, genders, age_labels, ages in trainloader:
|
|
print(images.shape, genders.shape, ages.shape)
|
|
break
|
|
|