# smart-interactive-display/Assets/StreamingAssets/MergeFace/AgeNet/models.py
#
# 196 lines
# 5.6 KiB
# Python

from torch import nn
import torch
from torch.nn import functional as F
class GenderClassificationModel(nn.Module):
    """VGG-Face-style network that outputs one sigmoid score per image.

    Thirteen 3x3 convolutions arranged in five stages (each stage ends with a
    2x2 max pool) are followed by three fully connected layers.  ``fc6``
    expects 512 * 2 * 2 = 2048 features, which is what a 64x64 image yields
    after the five poolings; a 224x224 image does NOT fit this head.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 64 channels.
        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        # Stage 2: 64 -> 128 channels.
        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        # Stage 3: 128 -> 256 channels.
        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        # Stage 4: 256 -> 512 channels.
        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        # Stage 5: 512 -> 512 channels.
        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        # Classifier head; 2048 = 512 channels on a 2x2 feature map.
        self.fc6 = nn.Linear(2048, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 1)

    def forward(self, x):
        """Run the classifier.

        Args:
            x: image batch of shape (batch, 3, H, W) whose spatial size pools
                down to 2x2 after five 2x2 max pools (e.g. 64x64).

        Returns:
            Tensor of shape (batch, 1) holding sigmoid scores in [0, 1]
            (not raw logits).  Which gender the high score denotes is
            determined by the training labels — confirm against the trainer.
        """
        stages = (
            (self.conv_1_1, self.conv_1_2),
            (self.conv_2_1, self.conv_2_2),
            (self.conv_3_1, self.conv_3_2, self.conv_3_3),
            (self.conv_4_1, self.conv_4_2, self.conv_4_3),
            (self.conv_5_1, self.conv_5_2, self.conv_5_3),
        )
        for stage in stages:
            for conv in stage:
                x = F.relu(conv(x))
            x = F.max_pool2d(x, 2, 2)

        # flatten() also handles non-contiguous activations (e.g. a
        # channels_last input), where view() would raise.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, p=0.5, training=self.training)
        # torch.sigmoid replaces the legacy F.sigmoid spelling.
        return torch.sigmoid(self.fc8(x))
class AgeRangeModel(nn.Module):
    """Small three-block CNN that classifies an image into an age range.

    Each block is Conv3x3 -> BatchNorm -> ReLU -> Dropout2d(0.3) -> MaxPool(2).
    An adaptive average pool then fixes the feature map at 2x2, so the linear
    head always receives 512 * 2 * 2 = 2048 features regardless of input size.

    Args:
        in_channels: number of channels of the input image.
        backbone: accepted for interface compatibility; not used here.
        pretrained: accepted for interface compatibility; not used here.
        num_classes: number of age-range classes produced by the head.
    """

    def __init__(self, in_channels=3, backbone='resnet50', pretrained=False, num_classes=9):
        super().__init__()
        self.Conv1 = self._conv_block(in_channels, 64)
        self.Conv2 = self._conv_block(64, 256)
        self.Conv3 = self._conv_block(256, 512)
        self.adap = nn.AdaptiveAvgPool2d((2, 2))
        # Kept as a Sequential so the parameter names stay "out_age.0.*".
        # No softmax is applied: the head returns raw class scores.
        self.out_age = nn.Sequential(
            nn.Linear(2048, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Build one Conv -> BN -> ReLU -> Dropout2d -> MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, num_classes).

        Args:
            x: image batch of shape (batch, in_channels, H, W).
        """
        x = self.Conv1(x)
        x = self.Conv2(x)
        x = self.Conv3(x)
        x = self.adap(x)
        # flatten() also handles non-contiguous activations (e.g. a
        # channels_last input), where view() would raise.
        x = torch.flatten(x, 1)
        return self.out_age(x)
class AgeEstimationModel(nn.Module):
    """Regress a non-negative age from an image plus an age-range index.

    The image goes through three Conv -> BN -> ReLU -> Dropout2d -> MaxPool
    blocks and an adaptive 2x2 average pool (2048 features).  The age-range
    index is mapped to a 64-dimensional embedding (9 possible ranges), the
    two vectors are concatenated, and a linear layer followed by ReLU
    produces a single value per sample.
    """

    def __init__(self):
        super().__init__()
        # One learned 64-d vector per age-range index (0..8).
        self.embedding_layer = nn.Embedding(9, 64)
        self.Conv1 = self._conv_block(3, 64)
        self.Conv2 = self._conv_block(64, 256)
        self.Conv3 = self._conv_block(256, 512)
        self.adap = nn.AdaptiveAvgPool2d((2, 2))
        # 2048 image features + 64 embedding features -> 1 output;
        # the trailing ReLU clamps the prediction to be >= 0.
        self.out_age = nn.Sequential(
            nn.Linear(2048 + 64, 1),
            nn.ReLU(),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Build one Conv -> BN -> ReLU -> Dropout2d -> MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

    def forward(self, x, y):
        """Estimate one value per sample.

        Args:
            x: image batch of shape (batch, 3, H, W).
            y: integer tensor of shape (batch,) with age-range indices in
                [0, 9); it is fed straight into the embedding layer.

        Returns:
            Tensor of shape (batch, 1) with non-negative values.
        """
        x = self.Conv1(x)
        x = self.Conv2(x)
        x = self.Conv3(x)
        x = self.adap(x)
        # flatten() also handles non-contiguous activations (e.g. a
        # channels_last input), where view() would raise.
        x = torch.flatten(x, 1)
        range_embedding = self.embedding_layer(y)
        fused = torch.cat([x, range_embedding], dim=1)
        return self.out_age(fused)
class Model(nn.Module):
    """Combined gender / age predictor built from the three sub-networks.

    The age-range classifier's most likely class is fed, together with the
    image, into the age-estimation network.
    """

    def __init__(self):
        super().__init__()
        self.gender_model = GenderClassificationModel()
        self.age_range_model = AgeRangeModel()
        self.age_estimation_model = AgeEstimationModel()

    def forward(self, x):
        """Predict gender scores and ages.

        Args:
            x: tensor of shape (batch, 3, 64, 64); a single un-batched
                3-dimensional image is also accepted and gets a leading
                batch axis added.

        Returns:
            Tuple ``(predicted_genders, estimated_ages)`` as returned by the
            gender and age-estimation sub-networks.
        """
        # Promote a lone image to a batch of one.
        if x.ndim == 3:
            x = x.unsqueeze(0)

        predicted_genders = self.gender_model(x)

        # Hard-select the most likely age range and use it as the
        # conditioning index for the age regressor.
        range_scores = self.age_range_model(x)
        range_index = range_scores.argmax(dim=1).view(-1)
        estimated_ages = self.age_estimation_model(x, range_index)

        return predicted_genders, estimated_ages
if __name__ == '__main__':
    # Smoke test: push a random batch of two 64x64 RGB images through the
    # combined model and show the shapes of both outputs.
    net = Model()
    dummy_batch = torch.rand(2, 3, 64, 64)
    outputs = net(dummy_batch)
    genders, ages = outputs
    print(genders.shape, ages.shape)