Update wake-up Python script
@@ -9,6 +9,6 @@ MonoBehaviour:
   m_GameObject: {fileID: 0}
   m_Enabled: 1
   m_EditorHideFlags: 0
-  m_Script: {fileID: 11500000, guid: 2c1390943a3e4db69adfdfa5c2bfb8ed, type: 3}
+  m_Script: {fileID: 11500000, guid: 835f0e9a0e0b4ceeac7446d5f7384fb3, type: 3}
   m_Name: GenerateImageSuccess
   m_EditorClassIdentifier:

@@ -38,7 +38,7 @@ RenderSettings:
   m_ReflectionIntensity: 1
   m_CustomReflection: {fileID: 0}
   m_Sun: {fileID: 0}
-  m_IndirectSpecularColor: {r: 0.44657826, g: 0.49641263, b: 0.57481676, a: 1}
+  m_IndirectSpecularColor: {r: 0.44657898, g: 0.4964133, b: 0.5748178, a: 1}
   m_UseRadianceAmbientProbe: 0
--- !u!157 &3
LightmapSettings:

@@ -708,7 +708,7 @@ MonoBehaviour:
   m_Name:
   m_EditorClassIdentifier:
   m_Material: {fileID: 0}
-  m_Color: {r: 1, g: 1, b: 1, a: 0.39215687}
+  m_Color: {r: 1, g: 1, b: 1, a: 0.078431375}
   m_RaycastTarget: 1
   m_RaycastPadding: {x: 0, y: 0, z: 0, w: 0}
   m_Maskable: 1

@@ -12,7 +12,7 @@ namespace GadGame.Scripts.Coffee
     public class CoffeeController : MonoBehaviour
     {
         [SerializeField] private VoidEvent _engageReadyEvent;
-        [SerializeField] private VoidEvent _generateImageSuccessEvent;
+        [SerializeField] private StringEvent _generateImageSuccessEvent;
         [SerializeField] private BoolEvent _playPassByAnimEvent;
         [SerializeField] private BoolEvent _playVideoEvent;
         [SerializeField] private FloatEvent _readyCountDownEvent;

@@ -81,15 +81,17 @@ namespace GadGame.Scripts.Coffee
             _hintText.text = _texts[0];
         }

-        private void OnGenerateImageSuccess()
+        private void OnGenerateImageSuccess(string desc)
         {
             _isLoading = false;
             _loading.DOFade(0.255f, 0.5f);
-            _hintText.text = "Mô tả";
+            _hintText.text = desc;
         }

-        private void OnGetEncodeImage(string encode)
+        private void OnGetEncodeImage(string filePath)
         {
-            _userImager.LoadImage(encode);
+            _userImager.LoadImage(filePath);
         }

         private void OnEngageReady()

@@ -1,10 +1,12 @@
 using System;
 using UnityEngine;
 using UnityEngine.UI;
+using System.IO;
+using GadGame.Network;

 namespace GadGame.Scripts
 {
-    [RequireComponent(typeof(Image))]
+    [RequireComponent(typeof(RawImage))]
     public class LoadImageEncoded : MonoBehaviour
     {
         [SerializeField] private bool _preserveAspect;

@@ -19,13 +21,15 @@ namespace GadGame.Scripts

         }

-        public void LoadImage(string encodeString)
+        public void LoadImage(string streamingData)
         {
-            //Decode the Base64 string to a byte array
-            byte[] imageBytes = Convert.FromBase64String(encodeString);
+            byte[] imageBytes = UdpSocket.Instance.DataReceived.GenerateImageSuccess ? File.ReadAllBytes(streamingData) : Convert.FromBase64String(streamingData);
+
+            // byte[] imageBytes = Convert.FromBase64String(encodeString);

             _texture.LoadImage(imageBytes); // Automatically resizes the texture dimensions
             // _texture.Apply();
             var sprite = Sprite.Create(_texture, new Rect(0, 0, _texture.width, _texture.height), new Vector2(0.5f, 0.5f), 100);
             _image.sprite = sprite;
         }

@@ -17,7 +17,7 @@ namespace GadGame
         public BoolEvent PlayPassByAnim;
         public BoolEvent PlayVideo;
         public FloatEvent ReadyCountDown;
-        public VoidEvent GenerateImageSuccess;
+        public StringEvent GenerateImageSuccess;
         public StringEvent EncodeImage;

         protected override async void Awake()

@@ -22,5 +22,6 @@ namespace GadGame.Network
         public Vector2[] PosPoints;
         public string StreamingData;
         [FormerlySerializedAs("Success")] public bool GenerateImageSuccess;
+        public string Description;
     }
 }

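For context (not part of this diff): a minimal sender-side sketch of the UDP payload that the DataReceived struct above deserializes via Newtonsoft.Json. The field names mirror the struct; the host, port, and the {x, y} shape for Vector2 are assumptions.

# Hypothetical sender sketch; only the field names are taken from the PR.
import json
import socket

payload = {
    "PassBy": True,
    "PosPoints": [{"x": 0.5, "y": 0.5}],           # assumed Vector2 JSON shape
    "StreamingData": "/tmp/MergeFace/result.png",  # file path once generation succeeds
    "GenerateImageSuccess": True,                  # previously serialized as "Success"
    "Description": "caption for the generated image",
}

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.sendto(json.dumps(payload).encode("utf-8"), ("127.0.0.1", 5065))  # placeholder address/port
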
@@ -6,6 +6,9 @@ using System.Threading;
 using GadGame.Singleton;
 using Newtonsoft.Json;
 using UnityEngine;
+using System.Diagnostics;
+using Debug = UnityEngine.Debug;
+using Unity.VisualScripting;

 namespace GadGame.Network
 {

@@ -23,9 +26,28 @@ namespace GadGame.Network
         private UdpClient _client;
         private IPEndPoint _remoteEndPoint;

+        private Process _process;
+
         private void Start()
         {
+            _process = new Process();
+            _process.StartInfo.FileName = "/bin/sh";
+            _process.StartInfo.Arguments = $"{Application.streamingAssetsPath}/MergeFace/run.sh";
+            _process.StartInfo.WorkingDirectory = $"{Application.streamingAssetsPath}/MergeFace";
+
+            _process.StartInfo.RedirectStandardOutput = true;
+            _process.StartInfo.RedirectStandardError = true;
+
+            _process.StartInfo.CreateNoWindow = false;
+            _process.StartInfo.UseShellExecute = false;
+            _process.OutputDataReceived += (sender, a) => {
+                Debug.Log(a.Data);
+            };
+            _process.ErrorDataReceived += (sender, a) => {
+                Debug.LogError(a.Data);
+            };
+            _process.Start();
+
             // Create remote endpoint
             _remoteEndPoint = new IPEndPoint(IPAddress.Parse(_ip), _sendPort);

@@ -87,6 +109,10 @@ namespace GadGame.Network
             _receiveThread.Abort();

             _client.Close();
+
+            _process.Close();
+            _process.CloseMainWindow();
+            _process.WaitForExit();
         }
     }
 }

@@ -20,6 +20,7 @@ namespace GadGame.State.MainFlowState

         public override void Update(float time)
         {
+            Runner.EncodeImage.Raise(UdpSocket.Instance.DataReceived.StreamingData);
             if(!UdpSocket.Instance.DataReceived.PassBy) {
                 Runner.SetState<IdleState>();
                 return;

@@ -6,7 +6,7 @@ namespace GadGame.State.MainFlowState
     {
         public override void Enter()
         {
-            Runner.GenerateImageSuccess.Raise();
+            Runner.GenerateImageSuccess.Raise(UdpSocket.Instance.DataReceived.Description);
         }

         public override void Update(float time)

(binary image added: 394 KiB)
|
@ -0,0 +1,140 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 72c164f587bf6158ebb542742fefd4d9
|
||||
TextureImporter:
|
||||
internalIDToNameTable: []
|
||||
externalObjects: {}
|
||||
serializedVersion: 12
|
||||
mipmaps:
|
||||
mipMapMode: 0
|
||||
enableMipMap: 0
|
||||
sRGBTexture: 1
|
||||
linearTexture: 0
|
||||
fadeOut: 0
|
||||
borderMipMap: 0
|
||||
mipMapsPreserveCoverage: 0
|
||||
alphaTestReferenceValue: 0.5
|
||||
mipMapFadeDistanceStart: 1
|
||||
mipMapFadeDistanceEnd: 3
|
||||
bumpmap:
|
||||
convertToNormalMap: 0
|
||||
externalNormalMap: 0
|
||||
heightScale: 0.25
|
||||
normalMapFilter: 0
|
||||
flipGreenChannel: 0
|
||||
isReadable: 0
|
||||
streamingMipmaps: 0
|
||||
streamingMipmapsPriority: 0
|
||||
vTOnly: 0
|
||||
ignoreMipmapLimit: 0
|
||||
grayScaleToAlpha: 0
|
||||
generateCubemap: 6
|
||||
cubemapConvolution: 0
|
||||
seamlessCubemap: 0
|
||||
textureFormat: 1
|
||||
maxTextureSize: 2048
|
||||
textureSettings:
|
||||
serializedVersion: 2
|
||||
filterMode: 1
|
||||
aniso: 1
|
||||
mipBias: 0
|
||||
wrapU: 1
|
||||
wrapV: 1
|
||||
wrapW: 0
|
||||
nPOTScale: 0
|
||||
lightmap: 0
|
||||
compressionQuality: 50
|
||||
spriteMode: 1
|
||||
spriteExtrude: 1
|
||||
spriteMeshType: 1
|
||||
alignment: 0
|
||||
spritePivot: {x: 0.5, y: 0.5}
|
||||
spritePixelsToUnits: 100
|
||||
spriteBorder: {x: 0, y: 0, z: 0, w: 0}
|
||||
spriteGenerateFallbackPhysicsShape: 1
|
||||
alphaUsage: 1
|
||||
alphaIsTransparency: 1
|
||||
spriteTessellationDetail: -1
|
||||
textureType: 8
|
||||
textureShape: 1
|
||||
singleChannelComponent: 0
|
||||
flipbookRows: 1
|
||||
flipbookColumns: 1
|
||||
maxTextureSizeSet: 0
|
||||
compressionQualitySet: 0
|
||||
textureFormatSet: 0
|
||||
ignorePngGamma: 0
|
||||
applyGammaDecoding: 0
|
||||
swizzle: 50462976
|
||||
cookieLightType: 0
|
||||
platformSettings:
|
||||
- serializedVersion: 3
|
||||
buildTarget: DefaultTexturePlatform
|
||||
maxTextureSize: 2048
|
||||
resizeAlgorithm: 0
|
||||
textureFormat: -1
|
||||
textureCompression: 1
|
||||
compressionQuality: 50
|
||||
crunchedCompression: 0
|
||||
allowsAlphaSplitting: 0
|
||||
overridden: 0
|
||||
ignorePlatformSupport: 0
|
||||
androidETC2FallbackOverride: 0
|
||||
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||
- serializedVersion: 3
|
||||
buildTarget: Standalone
|
||||
maxTextureSize: 2048
|
||||
resizeAlgorithm: 0
|
||||
textureFormat: -1
|
||||
textureCompression: 1
|
||||
compressionQuality: 50
|
||||
crunchedCompression: 0
|
||||
allowsAlphaSplitting: 0
|
||||
overridden: 0
|
||||
ignorePlatformSupport: 0
|
||||
androidETC2FallbackOverride: 0
|
||||
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||
- serializedVersion: 3
|
||||
buildTarget: Android
|
||||
maxTextureSize: 2048
|
||||
resizeAlgorithm: 0
|
||||
textureFormat: -1
|
||||
textureCompression: 1
|
||||
compressionQuality: 50
|
||||
crunchedCompression: 0
|
||||
allowsAlphaSplitting: 0
|
||||
overridden: 0
|
||||
ignorePlatformSupport: 0
|
||||
androidETC2FallbackOverride: 0
|
||||
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||
- serializedVersion: 3
|
||||
buildTarget: Server
|
||||
maxTextureSize: 2048
|
||||
resizeAlgorithm: 0
|
||||
textureFormat: -1
|
||||
textureCompression: 1
|
||||
compressionQuality: 50
|
||||
crunchedCompression: 0
|
||||
allowsAlphaSplitting: 0
|
||||
overridden: 0
|
||||
ignorePlatformSupport: 0
|
||||
androidETC2FallbackOverride: 0
|
||||
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||
spriteSheet:
|
||||
serializedVersion: 2
|
||||
sprites: []
|
||||
outline: []
|
||||
physicsShape: []
|
||||
bones: []
|
||||
spriteID: 5e97eb03825dee720800000000000000
|
||||
internalID: 0
|
||||
vertices: []
|
||||
indices:
|
||||
edges: []
|
||||
weights: []
|
||||
secondaryTextures: []
|
||||
nameFileIdTable: {}
|
||||
mipmapLimitGroupName:
|
||||
pSDRemoveMatte: 0
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 560a88da2bbc70140bed167f0ba7fe37
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: fb01be13d6e88ca488dda82150319bfc
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 117dcc671050f5247bd8743b91ecaab7
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 9db2f9566d5996ffbac9835756d8f951
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 88c2d151c368129198308dbd30d9d750
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,126 @@
from Facenet.models.mtcnn import MTCNN
from AgeNet.models import Model
import torch
from torchvision import transforms as T
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import argparse


age_range_list = ["1-10", "11-20", "21-30", "31-40", "41-50", "51-60", "61-70", "71-80", "81-90", "91-100"]


class AgeEstimator():
    def __init__(self, face_size=224, weights=None, device='cpu', tpx=500):
        self.thickness_per_pixels = tpx

        if isinstance(face_size, int):
            self.face_size = (face_size, face_size)
        else:
            self.face_size = face_size

        # Set device
        self.device = device
        if isinstance(device, str):
            if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
                self.device = torch.device(device)
            else:
                self.device = torch.device('cpu')

        self.facenet_model = MTCNN(device=self.device)

        self.model = Model().to(self.device)
        self.model.eval()
        if weights:
            self.model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))
            # print('Weights loaded successfully from path:', weights)
            # print('====================================================')

    def transform(self, image):
        return T.Compose(
            [
                T.Resize(self.face_size),
                T.ToTensor(),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ]
        )(image)

    @staticmethod
    def preprocess_image(image, face_size=(224, 224)):
        # Resize the image to the target size
        resized_transform = T.Resize(face_size)
        resized_image = resized_transform(image)
        return resized_image

    @staticmethod
    def padding_face(box, padding=20):
        return [
            box[0] - padding,
            box[1] - padding,
            box[2] + padding,
            box[3] + padding
        ]

    def predict_from_frame(self, frame, min_prob=0.9):
        image = Image.fromarray(frame)

        ndarray_image = np.array(image)
        image_shape = ndarray_image.shape

        try:
            bboxes, prob = self.facenet_model.detect(image)
            bboxes = bboxes[prob > min_prob]

            face_images = []

            for box in bboxes:
                box = np.clip(box, 0, np.inf).astype(np.uint32)

                padding = max(image_shape) * 5 / self.thickness_per_pixels
                padding = int(max(padding, 10))
                box = self.padding_face(box, padding)

                face = image.crop(box)
                transformed_face = self.transform(face)
                face_images.append(transformed_face)

            face_images = torch.stack(face_images, dim=0).to(self.device)

            genders, ages = self.model(face_images)
            ages = torch.round(ages).long()

            for i, box in enumerate(bboxes):
                # box/thickness were used by a removed drawing routine; kept for reference
                box = np.clip(box, 0, np.inf).astype(np.uint32)
                thickness = max(image_shape) / 400
                thickness = int(max(np.ceil(thickness), 1))
                # clamp so predicted ages >= 100 do not index past the last range
                age_idx = min(int(ages[i].item() / 10), len(age_range_list) - 1)
                age_range = age_range_list[age_idx]
                gender = round(genders[i].item(), 2)

                return [gender, age_range]  # only the first detected face is reported

        except Exception:
            # no face passed the probability threshold, or detection failed
            return None

        return None


def Prediction(frame, weights="weights.pt", face_size=64, device='cuda'):
    # Initialize the AgeEstimator
    model = AgeEstimator(weights=weights, face_size=face_size, device=device)

    # Convert the frame to RGB format (OpenCV uses BGR by default)
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Predict age and gender from the frame
    predicted_image = model.predict_from_frame(rgb_frame)

    return predicted_image

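A hedged usage sketch for Prediction() above; the module name, weights path, and camera index are assumptions, not part of the committed file:

# Assumed usage: grab one webcam frame and run the age/gender predictor.
import cv2
from wake_up import Prediction  # module name assumed from the PR title

cap = cv2.VideoCapture(0)       # default camera
ok, frame = cap.read()
cap.release()
if ok:
    result = Prediction(frame, weights="weights.pt", face_size=64, device="cpu")
    if result is not None:
        gender_score, age_range = result
        print(f"gender score: {gender_score}, age range: {age_range}")
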
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: dee94761336911fe091215ccd6f1527c
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 29274157d02d7d616ac2fc3558f8efcb
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: a6af8aa688a0742cab6a5c8e00e9f00a
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 1cb595d88a310b7729f9182bcd8e2ae6
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,65 @@
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms as T

class UTKFaceDataset(Dataset):
    def __init__(self, root_dir, transform = None):
        super().__init__()

        self.root_dir = root_dir
        self.transform = transform
        self.filename_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.filename_list)

    def age_to_class(self, age):
        range_ages = [0, 4, 9, 15, 25, 35, 45, 60, 75]
        if age > max(range_ages):
            return len(range_ages) - 1
        for i in range(len(range_ages) - 1):
            if range_ages[i] <= age <= range_ages[i + 1]:
                return i

    def __getitem__(self, idx):
        filename = self.filename_list[idx]

        # UTKFace filenames encode labels as "age_gender_race_timestamp.jpg"
        info = filename.split("_")
        age = int(info[0])
        age_label = self.age_to_class(age)
        gender = int(info[1])

        filename = os.path.join(self.root_dir, filename)

        image = Image.open(filename)

        if self.transform:
            image = self.transform(image)

        return image, gender, age_label, age

if __name__ == '__main__':

    image_size = (64,64)
    root_dir = '/kaggle/input/utkface-new/UTKFace'
    batch_size = 512
    num_workers = os.cpu_count()

    train_transform = T.Compose(
        [
            T.Resize(image_size),
            T.RandomHorizontalFlip(0.5),
            T.RandomRotation(10),
            T.ToTensor(),
            T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
        ]
    )

    train_dataset = UTKFaceDataset(root_dir, transform = train_transform)
    trainloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True)

    for images, genders, age_labels, ages in trainloader:
        print(images.shape, genders.shape, ages.shape)
        break

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: a9132692d4a48bed7872f706c57c2e91
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,66 @@
from torch import nn
from tqdm import tqdm

import torch

def evaluate(gender_model, age_range_model, age_estimation_model, val_dataloader, device = 'cpu', verbose = 0):
    # set device
    if isinstance(device, str):
        if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
            device = torch.device(device)
        else:
            device = torch.device('cpu')

    # Loss functions
    AgeRangeLoss = nn.CrossEntropyLoss()
    GenderLoss = nn.BCELoss()
    AgeEstimationLoss = nn.L1Loss()

    gender_model = gender_model.to(device)
    age_range_model = age_range_model.to(device)
    age_estimation_model = age_estimation_model.to(device)

    with torch.no_grad():
        age_range_model.eval()
        gender_model.eval()
        age_estimation_model.eval()
        age_accuracy = 0
        gender_accuracy = 0
        total_age_loss = 0
        total_gender_loss = 0
        total_age_estimation_loss = 0
        if verbose == 1:
            val_dataloader = tqdm(val_dataloader, desc = 'Evaluate: ', ncols = 100)

        for images, genders, age_labels, ages in val_dataloader:
            batch_size = images.shape[0]

            images, genders, age_labels, ages = images.to(device), genders.to(device), age_labels.to(device), ages.to(device)

            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)

            age_loss = AgeRangeLoss(pred_age_labels, age_labels.long())
            gender_loss = GenderLoss(pred_genders, genders.float())

            total_age_loss += age_loss.item()
            total_gender_loss += gender_loss.item()

            gender_acc = torch.sum(torch.round(pred_genders) == genders)/batch_size
            age_acc = torch.sum(torch.argmax(pred_age_labels, dim = 1) == age_labels)/batch_size

            age_accuracy += age_acc
            gender_accuracy += gender_acc

            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            age_estimation_loss = AgeEstimationLoss(ages, estimated_ages)

            total_age_estimation_loss += age_estimation_loss.item()

        val_age_loss = total_age_loss / len(val_dataloader)
        val_gender_loss = total_gender_loss / len(val_dataloader)
        val_age_accuracy = age_accuracy / len(val_dataloader)
        val_gender_accuracy = gender_accuracy / len(val_dataloader)
        val_age_estimation_loss = total_age_estimation_loss / len(val_dataloader)

    return val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss

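A hedged sketch of wiring evaluate() to the model and dataloader helpers added elsewhere in this PR; the package layout and data path are assumptions:

# Assumed wiring: build the three models and a validation loader, then evaluate.
from AgeNet.models import GenderClassificationModel, AgeRangeModel, AgeEstimationModel
from AgeNet.utils import get_dataloader  # package path assumed

val_loader = get_dataloader("/data/UTKFace", image_size=64, batch_size=128, shuffle=False)
models = (GenderClassificationModel(), AgeRangeModel(), AgeEstimationModel())

stats = evaluate(*models, val_loader, device="cpu", verbose=1)
age_loss, gender_loss, age_acc, gender_acc, est_loss = stats
print(f"age acc {age_acc:.2f}, gender acc {gender_acc:.2f}")
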
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: b1bc431225702cd32bc5814aa87bfcba
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,196 @@
from torch import nn
import torch
from torch.nn import functional as F

class GenderClassificationModel(nn.Module):
    "VGG-Face"
    def __init__(self):
        super().__init__()
        self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
        self.fc6 = nn.Linear(2048, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 1)

    def forward(self, x):
        """Pytorch forward

        Args:
            x: input image batch (64x64 in this pipeline; fc6 expects 512*2*2 = 2048 features)

        Returns: gender probability in [0, 1]
        """
        x = F.relu(self.conv_1_1(x))
        x = F.relu(self.conv_1_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2_1(x))
        x = F.relu(self.conv_2_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_3_1(x))
        x = F.relu(self.conv_3_2(x))
        x = F.relu(self.conv_3_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_4_1(x))
        x = F.relu(self.conv_4_2(x))
        x = F.relu(self.conv_4_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_5_1(x))
        x = F.relu(self.conv_5_2(x))
        x = F.relu(self.conv_5_3(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, 0.5, self.training)
        x = F.relu(self.fc7(x))
        x = F.dropout(x, 0.5, self.training)
        return torch.sigmoid(self.fc8(x))  # torch.sigmoid: F.sigmoid is deprecated

class AgeRangeModel(nn.Module):
    def __init__(self, in_channels = 3, backbone = 'resnet50', pretrained = False, num_classes = 9):
        super().__init__()

        self.Conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, padding = 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.Conv2 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.Conv3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, padding = 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.adap = nn.AdaptiveAvgPool2d((2,2))

        self.out_age = nn.Sequential(
            nn.Linear(2048, num_classes)
            # nn.Softmax(dim = 1)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.Conv1(x)
        x = self.Conv2(x)
        x = self.Conv3(x)

        x = self.adap(x)

        x = x.view(batch_size, -1)

        x = self.out_age(x)

        return x

class AgeEstimationModel(nn.Module):
    def __init__(self):
        super().__init__()
        # age-range class index is embedded and concatenated with image features
        self.embedding_layer = nn.Embedding(9, 64)

        self.Conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, padding = 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.Conv2 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.Conv3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, padding = 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(0.3),
            nn.MaxPool2d(2),
        )

        self.adap = nn.AdaptiveAvgPool2d((2,2))

        self.out_age = nn.Sequential(
            nn.Linear(2048 + 64, 1),
            nn.ReLU()
        )

    def forward(self, x, y):
        batch_size = x.shape[0]
        x = self.Conv1(x)
        x = self.Conv2(x)
        x = self.Conv3(x)

        x = self.adap(x)

        x = x.view(batch_size, -1)

        y = self.embedding_layer(y)

        x = torch.cat([x,y], dim = 1)

        x = self.out_age(x)

        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.gender_model = GenderClassificationModel()

        self.age_range_model = AgeRangeModel()

        self.age_estimation_model = AgeEstimationModel()

    def forward(self,x):
        """x: batch, 3, 64, 64"""
        if len(x.shape) == 3:
            x = x[None, ...]

        predicted_genders = self.gender_model(x)

        age_ranges = self.age_range_model(x)

        y = torch.argmax(age_ranges, dim = 1).view(-1,)

        estimated_ages = self.age_estimation_model(x, y)

        return predicted_genders, estimated_ages

if __name__ == '__main__':
    model = Model()
    x = torch.rand(2,3,64,64)
    genders, ages = model(x)

    print(genders.shape, ages.shape)

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: bfdd42c2ab7b1f10ab08c676ab819c1b
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,188 @@
from tqdm import tqdm
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from .utils import *
from .eval import *
from .models import *

import os
import torch
import numpy as np


def train(train_data_dir, weights, device = 'cpu', image_size = 64, batch_size = 128, num_epochs = 100, steps_per_epoch = None,
          val_data_dir = None, validation_split = None, save_history = True):

    # val_data_dir and validation_split must not both be set
    assert not(val_data_dir is not None and validation_split is not None)

    if isinstance(device, str):
        if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
            device = torch.device(device)
        else:
            device = torch.device('cpu')

    # get train_loader
    train_data = get_dataloader(train_data_dir, image_size = image_size, batch_size = batch_size)

    # split the data into train and validation sets
    if val_data_dir is not None:
        val_data = get_dataloader(val_data_dir, image_size = image_size, batch_size = batch_size)
    elif validation_split is not None:
        train_data, val_data = split_dataloader(train_data, validation_split)
    else:
        val_data = None

    if steps_per_epoch is None:
        steps_per_epoch = len(train_data)

    # steps per epoch
    num_steps = len(train_data)
    iterator = iter(train_data)
    count_steps = 1

    # move models to device
    gender_model = GenderClassificationModel().to(device)
    age_range_model = AgeRangeModel().to(device)
    age_estimation_model = AgeEstimationModel().to(device)

    # Loss functions
    AgeRangeLoss = nn.CrossEntropyLoss()
    GenderLoss = nn.BCELoss()
    AgeEstimationLoss = nn.L1Loss()

    # Optimizers
    gen_optimizer = torch.optim.Adam(gender_model.parameters(), lr = 1e-4)
    age_range_optimizer = torch.optim.Adam(age_range_model.parameters(), lr = 5e-3)
    age_estimation_optimizer = torch.optim.Adam(age_estimation_model.parameters(), lr = 1e-3)

    # Schedulers
    age_range_scheduler = ReduceLROnPlateau(age_range_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)
    gender_scheduler = ReduceLROnPlateau(gen_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)
    age_estimation_scheduler = ReduceLROnPlateau(age_estimation_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)

    # history
    history = {
        'train_gender_loss': [],
        'train_age_loss': [],
        'train_gender_acc': [],
        'train_age_acc': [],
        'val_gender_loss': [],
        'val_age_loss': [],
        'val_gender_acc': [],
        'val_age_acc': [],
    }

    for epoch in range(1, num_epochs + 1):
        total_gender_loss = 0
        total_age_loss = 0
        age_accuracy = 0
        gender_accuracy = 0

        total_age_estimation_loss = 0

        gender_model.train()
        age_range_model.train()
        age_estimation_model.train()

        for step in tqdm(range(steps_per_epoch), desc = f'Epoch {epoch}/{num_epochs}: ', ncols = 100):
            images, genders, age_labels, ages = next(iterator)
            batch_size = images.shape[0]

            images, genders, age_labels, ages = images.to(device), genders.to(device), age_labels.to(device), ages.to(device)

            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)

            age_loss = AgeRangeLoss(pred_age_labels, age_labels.long())

            gender_loss = GenderLoss(pred_genders, genders.float())

            total_age_loss += age_loss.item()
            total_gender_loss += gender_loss.item()

            gender_acc = torch.sum(torch.round(pred_genders) == genders)/batch_size
            age_acc = torch.sum(torch.argmax(pred_age_labels, dim = 1) == age_labels)/batch_size

            age_accuracy += age_acc
            gender_accuracy += gender_acc

            age_range_optimizer.zero_grad()
            age_loss.backward()
            age_range_optimizer.step()

            gen_optimizer.zero_grad()
            gender_loss.backward()
            gen_optimizer.step()

            # age estimation loss
            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            age_estimation_loss = AgeEstimationLoss(ages, estimated_ages)

            age_estimation_optimizer.zero_grad()
            age_estimation_loss.backward()
            age_estimation_optimizer.step()

            total_age_estimation_loss += age_estimation_loss.item()

            # once the whole dataset has been consumed, restart the iterator
            if count_steps == num_steps:
                iterator = iter(train_data)
                count_steps = 0
            count_steps += 1

        train_age_loss = total_age_loss / steps_per_epoch
        train_gender_loss = total_gender_loss / steps_per_epoch
        train_age_accuracy = age_accuracy / steps_per_epoch
        train_gender_accuracy = gender_accuracy / steps_per_epoch

        train_age_estimation_loss = total_age_estimation_loss/steps_per_epoch

        history['train_age_loss'].append(float(train_age_loss))
        history['train_gender_loss'].append(float(train_gender_loss))
        history['train_age_acc'].append(float(train_age_accuracy))
        history['train_gender_acc'].append(float(train_gender_accuracy))

        print(f'train_age_loss: {train_age_loss: .2f}, train_gender_loss: {train_gender_loss: .3f}, train_age_accuracy: {train_age_accuracy: .2f}, train_gender_accuracy: {train_gender_accuracy: .2f}, train_age_estimation_loss: {train_age_estimation_loss: .3f}')
        if val_data:
            val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss = evaluate(gender_model, age_range_model, age_estimation_model, val_data, device = device)
            history['val_age_loss'].append(float(val_age_loss))
            history['val_gender_loss'].append(float(val_gender_loss))
            history['val_age_acc'].append(float(val_age_accuracy))
            history['val_gender_acc'].append(float(val_gender_accuracy))

            age_range_scheduler.step(np.round(val_age_loss, 3))
            gender_scheduler.step(np.round(val_gender_loss, 3))
            age_estimation_scheduler.step(np.round(val_age_estimation_loss, 3))
            print(f'val_age_loss: {val_age_loss: .2f}, val_gender_loss: {val_gender_loss: .3f}, val_age_accuracy: {val_age_accuracy: .2f}, val_gender_accuracy: {val_gender_accuracy: .2f}, val_age_estimation_loss: {val_age_estimation_loss : .3f}')

    if weights:
        # Wrap the three sub-models in one module so a single state_dict can be
        # saved and later loaded by AgeNet.models.Model.
        class dummy_model(nn.Module):
            def __init__(self):
                super().__init__()

                self.gender_model = gender_model
                self.age_range_model = age_range_model
                self.age_estimation_model = age_estimation_model

            def forward(self, x):
                return

        model = dummy_model()
        torch.save(model.state_dict(), weights)
        print('Saved last weights to:', weights)

    if save_history:
        visualize_history(history, save_history)

if __name__ == '__main__':
    batch_size = 128
    image_size = 64
    train_data_dir = '/kaggle/input/utkface-new/UTKFace'
    device = 'cuda'
    weights = 'weights/AgeGenderWeights.pt'  # forward slash: the original backslash path breaks on POSIX
    epochs = 100
    train(train_data_dir, weights, device = device, steps_per_epoch = None,
          validation_split = 0.2, image_size = image_size, batch_size = batch_size, save_history = True)

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: d70c80bb1ccf8c73297c6a434b5035dc
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,93 @@
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from .dataset import UTKFaceDataset
from torchvision import transforms as T
from matplotlib import pyplot as plt
import os

def get_dataloader(root_dir, image_size = 64, batch_size = 128, shuffle = True, num_workers = 1):
    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    train_transform = T.Compose(
        [
            T.Resize(image_size),
            T.RandomHorizontalFlip(0.2),
            T.RandomRotation(10),
            T.ToTensor(),
            T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
        ]
    )

    train_dataset = UTKFaceDataset(root_dir, transform = train_transform)
    trainloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers, drop_last = True)
    return trainloader

def split_dataloader(train_data, validation_split = 0.2):
    # Split the DataLoader into train and validation parts
    train_ratio = 1 - validation_split  # fraction kept for training (e.g. 80%)
    train_size = int(train_ratio * len(train_data.dataset))  # number of training samples

    indices = list(range(len(train_data.dataset)))  # all dataset indices
    train_indices = indices[:train_size]  # indices used for training
    val_indices = indices[train_size:]    # indices used for validation

    # reuse the settings of the incoming dataloader
    dataset = train_data.dataset
    batch_size = train_data.batch_size
    num_workers = train_data.num_workers

    # SubsetRandomSamplers draw from the train/validation index subsets
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    # build new DataLoaders on top of the samplers
    train_data = DataLoader(dataset, batch_size = batch_size, sampler = train_sampler, num_workers = num_workers, drop_last = True)
    val_data = DataLoader(dataset, batch_size = batch_size, sampler = val_sampler, num_workers = num_workers, drop_last = True)

    return train_data, val_data

def visualize_history(history, save_history=True):
    plt.figure(figsize = (20,8))
    plt.subplot(131)
    plt.plot(range(1, len(history['train_age_acc']) + 1), history['train_age_acc'], label = 'train_age_acc', c = 'r')
    plt.plot(range(1, len(history['val_age_acc']) + 1), history['val_age_acc'], label = 'val_age_acc', c = 'g')
    plt.plot(range(1, len(history['train_gender_acc']) + 1), history['train_gender_acc'], label = 'train_gender_acc', c = 'b')
    plt.plot(range(1, len(history['val_gender_acc']) + 1), history['val_gender_acc'], label = 'val_gender_acc', c = 'y')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Gender And Age Prediction Accuracy')
    plt.legend()

    plt.subplot(132)
    plt.plot(range(1, len(history['train_age_loss']) + 1), history['train_age_loss'], label = 'train_age_loss', c = 'r')
    plt.plot(range(1, len(history['val_age_loss']) + 1), history['val_age_loss'], label = 'val_age_loss', c = 'g')
    plt.title('Age Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Age Loss')
    plt.legend()

    plt.subplot(133)
    plt.plot(range(1, len(history['train_gender_loss']) + 1), history['train_gender_loss'], label = 'train_gender_loss', c = 'b')
    plt.plot(range(1, len(history['val_gender_loss']) + 1), history['val_gender_loss'], label = 'val_gender_loss', c = 'y')
    plt.title('Gender Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Gender Loss')
    plt.legend()

    if save_history:
        if not os.path.exists("runs"):
            os.mkdir("runs")

        if not os.path.exists(os.path.join('runs', "train")):
            os.mkdir(os.path.join('runs', "train"))

        exp = os.listdir(os.path.join("runs", 'train'))
        if len(exp) == 0:
            last_exp = os.path.join("runs", 'train', 'exp1')
            os.mkdir(last_exp)
        else:
            exp_list = [int(i[3:]) for i in exp]
            # use max(): os.listdir order is arbitrary, so exp_list[-1] is unreliable
            last_exp = os.path.join("runs", 'train', 'exp' + str(max(exp_list) + 1))
            os.mkdir(last_exp)
        plt.savefig(os.path.join(last_exp, "results.png"))

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: b4f35a03a63ebf040b025624d23fe471
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1 @@
Subproject commit 98e25faaf07e489a624e965684ee8de0756f3ebc

@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: fc1c2c2dc78674ff5aa4d67db1ed5b52
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 8688372304eed0b45bb5efd33ebc3392
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Timothy Esler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: f630914c0f0b597b08fd828a65985720
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,11 @@
from .models.inception_resnet_v1 import InceptionResnetV1
from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization
from .models.utils.detect_face import extract_face
from .models.utils import training

import warnings
warnings.filterwarnings(
    action="ignore",
    message="This overload of nonzero is deprecated:\n\tnonzero()",
    category=UserWarning
)

@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 5663da78e37b6f3a6857c1c6577c7cbd
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 67cd73539a2de632781235ca7d7aa494
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 4e61512c59f6bb79cbd37d2aea24fbeb
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: ba1c774d77bdcf8998b29d38e63a737c
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@@ -0,0 +1,3 @@
2018*
*.json
profile.txt

(binary image added: 195 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: ab8572daa090e73df85f1b9940917e3e
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 294 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: a4f62947b8e078759b49c81457a887a9
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 1.2 MiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: f4e477d7c34ebe2148591e73fea94996
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 4a4a65ea65f71d3ed9d91a8d4c72f251
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 142d74e0017340668b10f92b8e102eed
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: fa5170ee8c30cf9ae834223e383f10d2
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 1e13d7bf50f86d2b2a29f13492ce0a16
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 5cb20437f0c28e5729e29dac4c25e7ca
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 760 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: d8a521ed41d9eaaeda79409eb7012b7b
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 232f83bfc7690e5629ed698a372308f8
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 4.4 MiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 3d8e513bbbbd722d3afe95cc13f44c24
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: ac46b8e53a69d92b4a55cb1b5a93161d
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 767 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 77ab9c4d3c60fba85857bf9de03df810
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: a4bcc08969f47fae09852a8e9a1fac79
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 921 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 18e226964c759b93bbae123cf4c06a30
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: c9c38cb76534c388890108443f8a7777
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 2.6 MiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: d56516d445a8d2729bc5b6367905199a
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 648250ee8f06135a3ba628c9c1484bfb
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 4bcc4387ae55c265ea87679d87399f2c
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 75 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: ff6b67c1e9331c820a98c1248c0b07c4
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 7f63815e904b464cca4739266ce7caba
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 76 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: aca472f8bd3620e63ac4dc792da4f5f5
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 9ea2b9aead09af7929e7faa6f80a30fc
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 75 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 4e71cfcf1e433cfb5a49b8ea3ae5e5c7
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: bf52100662173c6c7affb1431041e3ab
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 76 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 040d973d841ceee05b7e5e5e499713fa
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 2e747d2b1743b4ac28ebd50040a94842
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
(binary image added: 78 KiB)
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 121246c20607f1fbfa66d525a9289025
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 3dcc96f5e2b21bb19a85662ada3056cb
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 6af4b7dac79e10f35a9563fa58f5adb3
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,279 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Face detection and recognition training pipeline\n",
|
||||
"\n",
|
||||
"The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard pytorch training patterns."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
|
||||
"from torch import optim\n",
|
||||
"from torch.optim.lr_scheduler import MultiStepLR\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"import numpy as np\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define run parameters\n",
|
||||
"\n",
|
||||
"The dataset should follow the VGGFace2/ImageNet-style directory layout. Modify `data_dir` to the location of the dataset on wish to finetune on."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_dir = '../data/test_images'\n",
|
||||
"\n",
|
||||
"batch_size = 32\n",
|
||||
"epochs = 8\n",
|
||||
"workers = 0 if os.name == 'nt' else 8"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Determine if an nvidia GPU is available"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"print('Running on device: {}'.format(device))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define MTCNN module\n",
|
||||
"\n",
|
||||
"See `help(MTCNN)` for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mtcnn = MTCNN(\n",
|
||||
" image_size=160, margin=0, min_face_size=20,\n",
|
||||
" thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
|
||||
" device=device\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Perfom MTCNN facial detection\n",
|
||||
"\n",
|
||||
"Iterate through the DataLoader object and obtain cropped faces."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n",
|
||||
"dataset.samples = [\n",
|
||||
" (p, p.replace(data_dir, data_dir + '_cropped'))\n",
|
||||
" for p, _ in dataset.samples\n",
|
||||
"]\n",
|
||||
" \n",
|
||||
"loader = DataLoader(\n",
|
||||
" dataset,\n",
|
||||
" num_workers=workers,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" collate_fn=training.collate_pil\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for i, (x, y) in enumerate(loader):\n",
|
||||
" mtcnn(x, save_path=y)\n",
|
||||
" print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')\n",
|
||||
" \n",
|
||||
"# Remove mtcnn to reduce GPU memory usage\n",
|
||||
"del mtcnn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define Inception Resnet V1 module\n",
|
||||
"\n",
|
||||
"See `help(InceptionResnetV1)` for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"resnet = InceptionResnetV1(\n",
|
||||
" classify=True,\n",
|
||||
" pretrained='vggface2',\n",
|
||||
" num_classes=len(dataset.class_to_idx)\n",
|
||||
").to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define optimizer, scheduler, dataset, and dataloader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n",
|
||||
"scheduler = MultiStepLR(optimizer, [5, 10])\n",
|
||||
"\n",
|
||||
"trans = transforms.Compose([\n",
|
||||
" np.float32,\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" fixed_image_standardization\n",
|
||||
"])\n",
|
||||
"dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
|
||||
"img_inds = np.arange(len(dataset))\n",
|
||||
"np.random.shuffle(img_inds)\n",
|
||||
"train_inds = img_inds[:int(0.8 * len(img_inds))]\n",
|
||||
"val_inds = img_inds[int(0.8 * len(img_inds)):]\n",
|
||||
"\n",
|
||||
"train_loader = DataLoader(\n",
|
||||
" dataset,\n",
|
||||
" num_workers=workers,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" sampler=SubsetRandomSampler(train_inds)\n",
|
||||
")\n",
|
||||
"val_loader = DataLoader(\n",
|
||||
" dataset,\n",
|
||||
" num_workers=workers,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" sampler=SubsetRandomSampler(val_inds)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define loss and evaluation functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
||||
"metrics = {\n",
|
||||
" 'fps': training.BatchTimer(),\n",
|
||||
" 'acc': training.accuracy\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Train model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"writer = SummaryWriter()\n",
|
||||
"writer.iteration, writer.interval = 0, 10\n",
|
||||
"\n",
|
||||
"print('\\n\\nInitial')\n",
|
||||
"print('-' * 10)\n",
|
||||
"resnet.eval()\n",
|
||||
"training.pass_epoch(\n",
|
||||
" resnet, loss_fn, val_loader,\n",
|
||||
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||
" writer=writer\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for epoch in range(epochs):\n",
|
||||
" print('\\nEpoch {}/{}'.format(epoch + 1, epochs))\n",
|
||||
" print('-' * 10)\n",
|
||||
"\n",
|
||||
" resnet.train()\n",
|
||||
" training.pass_epoch(\n",
|
||||
" resnet, loss_fn, train_loader, optimizer, scheduler,\n",
|
||||
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||
" writer=writer\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" resnet.eval()\n",
|
||||
" training.pass_epoch(\n",
|
||||
" resnet, loss_fn, val_loader,\n",
|
||||
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||
" writer=writer\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"writer.close()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 7212bac47cb61dc798ee7316d53b8f88
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,245 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Face detection and recognition inference pipeline\n",
|
||||
"\n",
|
||||
"The following example illustrates how to use the `facenet_pytorch` python package to perform face detection and recogition on an image dataset using an Inception Resnet V1 pretrained on the VGGFace2 dataset.\n",
|
||||
"\n",
|
||||
"The following Pytorch methods are included:\n",
|
||||
"* Datasets\n",
|
||||
"* Dataloaders\n",
|
||||
"* GPU/CPU processing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from facenet_pytorch import MTCNN, InceptionResnetV1\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"from torchvision import datasets\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"workers = 0 if os.name == 'nt' else 4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Determine if an nvidia GPU is available"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running on device: cuda:0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"print('Running on device: {}'.format(device))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define MTCNN module\n",
|
||||
"\n",
|
||||
"Default params shown for illustration, but not needed. Note that, since MTCNN is a collection of neural nets and other code, the device must be passed in the following way to enable copying of objects when needed internally.\n",
|
||||
"\n",
|
||||
"See `help(MTCNN)` for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mtcnn = MTCNN(\n",
|
||||
" image_size=160, margin=0, min_face_size=20,\n",
|
||||
" thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
|
||||
" device=device\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define Inception Resnet V1 module\n",
|
||||
"\n",
|
||||
"Set classify=True for pretrained classifier. For this example, we will use the model to output embeddings/CNN features. Note that for inference, it is important to set the model to `eval` mode.\n",
|
||||
"\n",
|
||||
"See `help(InceptionResnetV1)` for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)"
|
||||
]
|
||||
},
|
||||
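  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An added sketch, for contrast only: constructing the same pretrained network as a classifier. With `classify=True`, the model outputs class logits instead of embeddings. `resnet_clf` is a hypothetical name and is not used below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (not used below): classifier variant of the same pretrained network\n",
    "resnet_clf = InceptionResnetV1(classify=True, pretrained='vggface2').eval().to(device)"
   ]
  },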
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Define a dataset and data loader\n",
|
||||
"\n",
|
||||
"We add the `idx_to_class` attribute to the dataset to enable easy recoding of label indices to identity names later one."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def collate_fn(x):\n",
|
||||
" return x[0]\n",
|
||||
"\n",
|
||||
"dataset = datasets.ImageFolder('../data/test_images')\n",
|
||||
"dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n",
|
||||
"loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Perfom MTCNN facial detection\n",
|
||||
"\n",
|
||||
"Iterate through the DataLoader object and detect faces and associated detection probabilities for each. The `MTCNN` forward method returns images cropped to the detected face, if a face was detected. By default only a single detected face is returned - to have `MTCNN` return all detected faces, set `keep_all=True` when creating the MTCNN object above.\n",
|
||||
"\n",
|
||||
"To obtain bounding boxes rather than cropped face images, you can instead call the lower-level `mtcnn.detect()` function. See `help(mtcnn.detect)` for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Face detected with probability: 0.999957\n",
|
||||
"Face detected with probability: 0.999927\n",
|
||||
"Face detected with probability: 0.999662\n",
|
||||
"Face detected with probability: 0.999873\n",
|
||||
"Face detected with probability: 0.999991\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"aligned = []\n",
|
||||
"names = []\n",
|
||||
"for x, y in loader:\n",
|
||||
" x_aligned, prob = mtcnn(x, return_prob=True)\n",
|
||||
" if x_aligned is not None:\n",
|
||||
" print('Face detected with probability: {:8f}'.format(prob))\n",
|
||||
" aligned.append(x_aligned)\n",
|
||||
" names.append(dataset.idx_to_class[y])"
|
||||
]
|
||||
},
|
||||
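  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An added sketch of the lower-level `mtcnn.detect()` call mentioned above, assuming the `mtcnn` and `dataset` objects defined earlier. It returns bounding boxes and detection probabilities rather than cropped face tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: bounding boxes and probabilities for the first dataset image\n",
    "img, _ = dataset[0]\n",
    "boxes, probs = mtcnn.detect(img)\n",
    "print(boxes, probs)"
   ]
  },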
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Calculate image embeddings\n",
|
||||
"\n",
|
||||
"MTCNN will return images of faces all the same size, enabling easy batch processing with the Resnet recognition module. Here, since we only have a few images, we build a single batch and perform inference on it. \n",
|
||||
"\n",
|
||||
"For real datasets, code should be modified to control batch sizes being passed to the Resnet, particularly if being processed on a GPU. For repeated testing, it is best to separate face detection (using MTCNN) from embedding or classification (using InceptionResnetV1), as calculation of cropped faces or bounding boxes can then be performed a single time and detected faces saved for future use."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"aligned = torch.stack(aligned).to(device)\n",
|
||||
"embeddings = resnet(aligned).detach().cpu()"
|
||||
]
|
||||
},
|
||||
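  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An added sketch (not in the original notebook) of the chunked inference suggested above, assuming the `aligned` tensor from the previous cell; it bounds how many faces are sent through the Resnet at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: embed faces in fixed-size chunks to bound GPU memory usage\n",
    "chunk = 32\n",
    "emb_list = []\n",
    "with torch.no_grad():\n",
    "    for i in range(0, len(aligned), chunk):\n",
    "        emb_list.append(resnet(aligned[i:i + chunk].to(device)).cpu())\n",
    "embeddings = torch.cat(emb_list)"
   ]
  },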
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Print distance matrix for classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" angelina_jolie bradley_cooper kate_siegel paul_rudd \\\n",
|
||||
"angelina_jolie 0.000000 1.344806 0.781201 1.425579 \n",
|
||||
"bradley_cooper 1.344806 0.000000 1.256238 0.922126 \n",
|
||||
"kate_siegel 0.781201 1.256238 0.000000 1.366423 \n",
|
||||
"paul_rudd 1.425579 0.922126 1.366423 0.000000 \n",
|
||||
"shea_whigham 1.448495 0.891145 1.416447 0.985438 \n",
|
||||
"\n",
|
||||
" shea_whigham \n",
|
||||
"angelina_jolie 1.448495 \n",
|
||||
"bradley_cooper 0.891145 \n",
|
||||
"kate_siegel 1.416447 \n",
|
||||
"paul_rudd 0.985438 \n",
|
||||
"shea_whigham 0.000000 \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]\n",
|
||||
"print(pd.DataFrame(dists, columns=names, index=names))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: c791f4d9146d73a68ae3aa4bbb11589f
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
|
@ -0,0 +1,522 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### facenet-pytorch LFW evaluation\n",
|
||||
"This notebook demonstrates how to evaluate performance against the LFW dataset."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training, extract_face\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler\n",
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"import numpy as np\n",
|
||||
"import os"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_dir = 'data/lfw/lfw'\n",
|
||||
"pairs_path = 'data/lfw/pairs.txt'\n",
|
||||
"\n",
|
||||
"batch_size = 16\n",
|
||||
"epochs = 15\n",
|
||||
"workers = 0 if os.name == 'nt' else 8"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Running on device: cuda:0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"print('Running on device: {}'.format(device))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mtcnn = MTCNN(\n",
|
||||
" image_size=160,\n",
|
||||
" margin=14,\n",
|
||||
" device=device,\n",
|
||||
" selection_method='center_weighted_size'\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the data loader for the input set of images\n",
|
||||
"orig_img_ds = datasets.ImageFolder(data_dir, transform=None)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# overwrites class labels in dataset with path so path can be used for saving output in mtcnn batches\n",
|
||||
"orig_img_ds.samples = [\n",
|
||||
" (p, p)\n",
|
||||
" for p, _ in orig_img_ds.samples\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"loader = DataLoader(\n",
|
||||
" orig_img_ds,\n",
|
||||
" num_workers=workers,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" collate_fn=training.collate_pil\n",
|
||||
")\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"crop_paths = []\n",
|
||||
"box_probs = []\n",
|
||||
"\n",
|
||||
"for i, (x, b_paths) in enumerate(loader):\n",
|
||||
" crops = [p.replace(data_dir, data_dir + '_cropped') for p in b_paths]\n",
|
||||
" mtcnn(x, save_path=crops)\n",
|
||||
" crop_paths.extend(crops)\n",
|
||||
" print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Remove mtcnn to reduce GPU memory usage\n",
|
||||
"del mtcnn\n",
|
||||
"torch.cuda.empty_cache()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# create dataset and data loaders from cropped images output from MTCNN\n",
|
||||
"\n",
|
||||
"trans = transforms.Compose([\n",
|
||||
" np.float32,\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" fixed_image_standardization\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
|
||||
"\n",
|
||||
"embed_loader = DataLoader(\n",
|
||||
" dataset,\n",
|
||||
" num_workers=workers,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" sampler=SequentialSampler(dataset)\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load pretrained resnet model\n",
|
||||
"resnet = InceptionResnetV1(\n",
|
||||
" classify=False,\n",
|
||||
" pretrained='vggface2'\n",
|
||||
").to(device)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classes = []\n",
|
||||
"embeddings = []\n",
|
||||
"resnet.eval()\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for xb, yb in embed_loader:\n",
|
||||
" xb = xb.to(device)\n",
|
||||
" b_embeddings = resnet(xb)\n",
|
||||
" b_embeddings = b_embeddings.to('cpu').numpy()\n",
|
||||
" classes.extend(yb.numpy())\n",
|
||||
" embeddings.extend(b_embeddings)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings_dict = dict(zip(crop_paths,embeddings))\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
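  {
   "cell_type": "markdown",
   "source": [
    "An added consistency check: the `zip` above relies on `crop_paths` (built in original-dataset order) and the cropped `ImageFolder` enumerating the same files in the same sorted order, which holds provided every image produced a crop."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Added sanity check (assumes every image produced a crop)\n",
    "assert len(crop_paths) == len(embeddings)\n",
    "assert [p for p, _ in dataset.samples] == crop_paths"
   ],
   "metadata": {
    "collapsed": false
   }
  },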
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#### Evaluate embeddings by using distance metrics to perform verification on the official LFW test set.\n",
|
||||
"\n",
|
||||
"The functions in the next block are copy pasted from `facenet.src.lfw`. Unfortunately that module has an absolute import from `facenet`, so can't be imported from the submodule\n",
|
||||
"\n",
|
||||
"added functionality to return false positive and false negatives"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.model_selection import KFold\n",
|
||||
"from scipy import interpolate\n",
|
||||
"\n",
|
||||
"# LFW functions taken from David Sandberg's FaceNet implementation\n",
|
||||
"def distance(embeddings1, embeddings2, distance_metric=0):\n",
|
||||
" if distance_metric==0:\n",
|
||||
" # Euclidian distance\n",
|
||||
" diff = np.subtract(embeddings1, embeddings2)\n",
|
||||
" dist = np.sum(np.square(diff),1)\n",
|
||||
" elif distance_metric==1:\n",
|
||||
" # Distance based on cosine similarity\n",
|
||||
" dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)\n",
|
||||
" norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)\n",
|
||||
" similarity = dot / norm\n",
|
||||
" dist = np.arccos(similarity) / math.pi\n",
|
||||
" else:\n",
|
||||
" raise 'Undefined distance metric %d' % distance_metric\n",
|
||||
"\n",
|
||||
" return dist\n",
|
||||
"\n",
|
||||
"def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||
" assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
|
||||
" assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
|
||||
" nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
|
||||
" nrof_thresholds = len(thresholds)\n",
|
||||
" k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
|
||||
"\n",
|
||||
" tprs = np.zeros((nrof_folds,nrof_thresholds))\n",
|
||||
" fprs = np.zeros((nrof_folds,nrof_thresholds))\n",
|
||||
" accuracy = np.zeros((nrof_folds))\n",
|
||||
"\n",
|
||||
" is_false_positive = []\n",
|
||||
" is_false_negative = []\n",
|
||||
"\n",
|
||||
" indices = np.arange(nrof_pairs)\n",
|
||||
"\n",
|
||||
" for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
|
||||
" if subtract_mean:\n",
|
||||
" mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
|
||||
" else:\n",
|
||||
" mean = 0.0\n",
|
||||
" dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
|
||||
"\n",
|
||||
" # Find the best threshold for the fold\n",
|
||||
" acc_train = np.zeros((nrof_thresholds))\n",
|
||||
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||
" _, _, acc_train[threshold_idx], _ ,_ = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])\n",
|
||||
" best_threshold_index = np.argmax(acc_train)\n",
|
||||
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||
" tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _, _, _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])\n",
|
||||
" _, _, accuracy[fold_idx], is_fp, is_fn = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])\n",
|
||||
"\n",
|
||||
" tpr = np.mean(tprs,0)\n",
|
||||
" fpr = np.mean(fprs,0)\n",
|
||||
" is_false_positive.extend(is_fp)\n",
|
||||
" is_false_negative.extend(is_fn)\n",
|
||||
"\n",
|
||||
" return tpr, fpr, accuracy, is_false_positive, is_false_negative\n",
|
||||
"\n",
|
||||
"def calculate_accuracy(threshold, dist, actual_issame):\n",
|
||||
" predict_issame = np.less(dist, threshold)\n",
|
||||
" tp = np.sum(np.logical_and(predict_issame, actual_issame))\n",
|
||||
" fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
|
||||
" tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))\n",
|
||||
" fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n",
|
||||
"\n",
|
||||
" is_fp = np.logical_and(predict_issame, np.logical_not(actual_issame))\n",
|
||||
" is_fn = np.logical_and(np.logical_not(predict_issame), actual_issame)\n",
|
||||
"\n",
|
||||
" tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)\n",
|
||||
" fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)\n",
|
||||
" acc = float(tp+tn)/dist.size\n",
|
||||
" return tpr, fpr, acc, is_fp, is_fn\n",
|
||||
"\n",
|
||||
"def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||
" assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
|
||||
" assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
|
||||
" nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
|
||||
" nrof_thresholds = len(thresholds)\n",
|
||||
" k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
|
||||
"\n",
|
||||
" val = np.zeros(nrof_folds)\n",
|
||||
" far = np.zeros(nrof_folds)\n",
|
||||
"\n",
|
||||
" indices = np.arange(nrof_pairs)\n",
|
||||
"\n",
|
||||
" for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
|
||||
" if subtract_mean:\n",
|
||||
" mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
|
||||
" else:\n",
|
||||
" mean = 0.0\n",
|
||||
" dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
|
||||
"\n",
|
||||
" # Find the threshold that gives FAR = far_target\n",
|
||||
" far_train = np.zeros(nrof_thresholds)\n",
|
||||
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||
" _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])\n",
|
||||
" if np.max(far_train)>=far_target:\n",
|
||||
" f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n",
|
||||
" threshold = f(far_target)\n",
|
||||
" else:\n",
|
||||
" threshold = 0.0\n",
|
||||
"\n",
|
||||
" val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])\n",
|
||||
"\n",
|
||||
" val_mean = np.mean(val)\n",
|
||||
" far_mean = np.mean(far)\n",
|
||||
" val_std = np.std(val)\n",
|
||||
" return val_mean, val_std, far_mean\n",
|
||||
"\n",
|
||||
"def calculate_val_far(threshold, dist, actual_issame):\n",
|
||||
" predict_issame = np.less(dist, threshold)\n",
|
||||
" true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n",
|
||||
" false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
|
||||
" n_same = np.sum(actual_issame)\n",
|
||||
" n_diff = np.sum(np.logical_not(actual_issame))\n",
|
||||
" val = float(true_accept) / float(n_same)\n",
|
||||
" far = float(false_accept) / float(n_diff)\n",
|
||||
" return val, far\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def evaluate(embeddings, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||
" # Calculate evaluation metrics\n",
|
||||
" thresholds = np.arange(0, 4, 0.01)\n",
|
||||
" embeddings1 = embeddings[0::2]\n",
|
||||
" embeddings2 = embeddings[1::2]\n",
|
||||
" tpr, fpr, accuracy, fp, fn = calculate_roc(thresholds, embeddings1, embeddings2,\n",
|
||||
" np.asarray(actual_issame), nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
|
||||
" thresholds = np.arange(0, 4, 0.001)\n",
|
||||
" val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,\n",
|
||||
" np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
|
||||
" return tpr, fpr, accuracy, val, val_std, far, fp, fn\n",
|
||||
"\n",
|
||||
"def add_extension(path):\n",
|
||||
" if os.path.exists(path+'.jpg'):\n",
|
||||
" return path+'.jpg'\n",
|
||||
" elif os.path.exists(path+'.png'):\n",
|
||||
" return path+'.png'\n",
|
||||
" else:\n",
|
||||
" raise RuntimeError('No file \"%s\" with extension png or jpg.' % path)\n",
|
||||
"\n",
|
||||
"def get_paths(lfw_dir, pairs):\n",
|
||||
" nrof_skipped_pairs = 0\n",
|
||||
" path_list = []\n",
|
||||
" issame_list = []\n",
|
||||
" for pair in pairs:\n",
|
||||
" if len(pair) == 3:\n",
|
||||
" path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
|
||||
" path1 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])))\n",
|
||||
" issame = True\n",
|
||||
" elif len(pair) == 4:\n",
|
||||
" path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
|
||||
" path1 = add_extension(os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])))\n",
|
||||
" issame = False\n",
|
||||
" if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist\n",
|
||||
" path_list += (path0,path1)\n",
|
||||
" issame_list.append(issame)\n",
|
||||
" else:\n",
|
||||
" nrof_skipped_pairs += 1\n",
|
||||
" if nrof_skipped_pairs>0:\n",
|
||||
" print('Skipped %d image pairs' % nrof_skipped_pairs)\n",
|
||||
"\n",
|
||||
" return path_list, issame_list\n",
|
||||
"\n",
|
||||
"def read_pairs(pairs_filename):\n",
|
||||
" pairs = []\n",
|
||||
" with open(pairs_filename, 'r') as f:\n",
|
||||
" for line in f.readlines()[1:]:\n",
|
||||
" pair = line.strip().split()\n",
|
||||
" pairs.append(pair)\n",
|
||||
" return np.array(pairs, dtype=object)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
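  {
   "cell_type": "markdown",
   "source": [
    "An added sanity check of the `distance` helper on toy embeddings, covering both metrics (squared Euclidean distance and angular distance)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Added sanity check for distance() on orthogonal unit vectors\n",
    "a = np.array([[1.0, 0.0]])\n",
    "b = np.array([[0.0, 1.0]])\n",
    "print(distance(a, b, 0))  # squared Euclidean -> [2.]\n",
    "print(distance(a, b, 1))  # arccos(0) / pi -> [0.5]"
   ],
   "metadata": {
    "collapsed": false
   }
  },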
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pairs = read_pairs(pairs_path)\n",
|
||||
"path_list, issame_list = get_paths(data_dir+'_cropped', pairs)\n",
|
||||
"embeddings = np.array([embeddings_dict[path] for path in path_list])\n",
|
||||
"\n",
|
||||
"tpr, fpr, accuracy, val, val_std, far, fp, fn = evaluate(embeddings, issame_list)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.995 0.995 0.99166667 0.99 0.99 0.99666667\n",
|
||||
" 0.99 0.995 0.99666667 0.99666667]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "0.9936666666666666"
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(accuracy)\n",
|
||||
"np.mean(accuracy)\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
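,
  {
   "cell_type": "markdown",
   "source": [
    "An added sketch of one way to inspect the verification trade-off: plotting the mean ROC curve (`tpr` vs `fpr` across thresholds) returned by `evaluate`. Assumes `matplotlib` is available."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Added sketch: visualize the mean ROC curve across folds\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(fpr, tpr)\n",
    "plt.xlabel('False positive rate')\n",
    "plt.ylabel('True positive rate')\n",
    "plt.title('LFW verification ROC')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  }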
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: e0374ec1d0b25b8cd871b0d2e93d2dcf
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
After Width: | Height: | Size: 12 KiB |
|
@ -0,0 +1,7 @@
|
|||
fileFormatVersion: 2
|
||||
guid: 903b193c5350d12a48fd9d7d208adfdb
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|