update wake up python
|
@ -9,6 +9,6 @@ MonoBehaviour:
|
||||||
m_GameObject: {fileID: 0}
|
m_GameObject: {fileID: 0}
|
||||||
m_Enabled: 1
|
m_Enabled: 1
|
||||||
m_EditorHideFlags: 0
|
m_EditorHideFlags: 0
|
||||||
m_Script: {fileID: 11500000, guid: 2c1390943a3e4db69adfdfa5c2bfb8ed, type: 3}
|
m_Script: {fileID: 11500000, guid: 835f0e9a0e0b4ceeac7446d5f7384fb3, type: 3}
|
||||||
m_Name: GenerateImageSuccess
|
m_Name: GenerateImageSuccess
|
||||||
m_EditorClassIdentifier:
|
m_EditorClassIdentifier:
|
||||||
|
|
|
@ -38,7 +38,7 @@ RenderSettings:
|
||||||
m_ReflectionIntensity: 1
|
m_ReflectionIntensity: 1
|
||||||
m_CustomReflection: {fileID: 0}
|
m_CustomReflection: {fileID: 0}
|
||||||
m_Sun: {fileID: 0}
|
m_Sun: {fileID: 0}
|
||||||
m_IndirectSpecularColor: {r: 0.44657826, g: 0.49641263, b: 0.57481676, a: 1}
|
m_IndirectSpecularColor: {r: 0.44657898, g: 0.4964133, b: 0.5748178, a: 1}
|
||||||
m_UseRadianceAmbientProbe: 0
|
m_UseRadianceAmbientProbe: 0
|
||||||
--- !u!157 &3
|
--- !u!157 &3
|
||||||
LightmapSettings:
|
LightmapSettings:
|
||||||
|
@ -708,7 +708,7 @@ MonoBehaviour:
|
||||||
m_Name:
|
m_Name:
|
||||||
m_EditorClassIdentifier:
|
m_EditorClassIdentifier:
|
||||||
m_Material: {fileID: 0}
|
m_Material: {fileID: 0}
|
||||||
m_Color: {r: 1, g: 1, b: 1, a: 0.39215687}
|
m_Color: {r: 1, g: 1, b: 1, a: 0.078431375}
|
||||||
m_RaycastTarget: 1
|
m_RaycastTarget: 1
|
||||||
m_RaycastPadding: {x: 0, y: 0, z: 0, w: 0}
|
m_RaycastPadding: {x: 0, y: 0, z: 0, w: 0}
|
||||||
m_Maskable: 1
|
m_Maskable: 1
|
||||||
|
|
|
@ -12,7 +12,7 @@ namespace GadGame.Scripts.Coffee
|
||||||
public class CoffeeController : MonoBehaviour
|
public class CoffeeController : MonoBehaviour
|
||||||
{
|
{
|
||||||
[SerializeField] private VoidEvent _engageReadyEvent;
|
[SerializeField] private VoidEvent _engageReadyEvent;
|
||||||
[SerializeField] private VoidEvent _generateImageSuccessEvent;
|
[SerializeField] private StringEvent _generateImageSuccessEvent;
|
||||||
[SerializeField] private BoolEvent _playPassByAnimEvent;
|
[SerializeField] private BoolEvent _playPassByAnimEvent;
|
||||||
[SerializeField] private BoolEvent _playVideoEvent;
|
[SerializeField] private BoolEvent _playVideoEvent;
|
||||||
[SerializeField] private FloatEvent _readyCountDownEvent;
|
[SerializeField] private FloatEvent _readyCountDownEvent;
|
||||||
|
@ -81,15 +81,17 @@ namespace GadGame.Scripts.Coffee
|
||||||
_hintText.text = _texts[0];
|
_hintText.text = _texts[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
private void OnGenerateImageSuccess()
|
|
||||||
|
private void OnGenerateImageSuccess(string desc)
|
||||||
{
|
{
|
||||||
|
_isLoading = false;
|
||||||
_loading.DOFade(0.255f, 0.5f);
|
_loading.DOFade(0.255f, 0.5f);
|
||||||
_hintText.text = "Mô tả";
|
_hintText.text = desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void OnGetEncodeImage(string encode)
|
private void OnGetEncodeImage(string filePath)
|
||||||
{
|
{
|
||||||
_userImager.LoadImage(encode);
|
_userImager.LoadImage(filePath);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void OnEngageReady()
|
private void OnEngageReady()
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
using System;
|
using System;
|
||||||
using UnityEngine;
|
using UnityEngine;
|
||||||
using UnityEngine.UI;
|
using UnityEngine.UI;
|
||||||
|
using System.IO;
|
||||||
|
using GadGame.Network;
|
||||||
|
|
||||||
namespace GadGame.Scripts
|
namespace GadGame.Scripts
|
||||||
{
|
{
|
||||||
[RequireComponent(typeof(Image))]
|
[RequireComponent(typeof(RawImage))]
|
||||||
public class LoadImageEncoded : MonoBehaviour
|
public class LoadImageEncoded : MonoBehaviour
|
||||||
{
|
{
|
||||||
[SerializeField] private bool _preserveAspect;
|
[SerializeField] private bool _preserveAspect;
|
||||||
|
@ -19,13 +21,15 @@ namespace GadGame.Scripts
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void LoadImage(string streamingData)
|
||||||
public void LoadImage(string encodeString)
|
|
||||||
{
|
{
|
||||||
//Decode the Base64 string to a byte array
|
//Decode the Base64 string to a byte array
|
||||||
byte[] imageBytes = Convert.FromBase64String(encodeString);
|
byte[] imageBytes = UdpSocket.Instance.DataReceived.GenerateImageSuccess ? File.ReadAllBytes(streamingData) : Convert.FromBase64String(streamingData);
|
||||||
|
|
||||||
|
// byte[] imageBytes = Convert.FromBase64String(encodeString);
|
||||||
|
|
||||||
_texture.LoadImage(imageBytes); // Automatically resizes the texture dimensions
|
_texture.LoadImage(imageBytes); // Automatically resizes the texture dimensions
|
||||||
|
// _texture.Apply();
|
||||||
var sprite = Sprite.Create(_texture, new Rect(0, 0, _texture.width, _texture.height), new Vector2(0.5f, 0.5f), 100);
|
var sprite = Sprite.Create(_texture, new Rect(0, 0, _texture.width, _texture.height), new Vector2(0.5f, 0.5f), 100);
|
||||||
_image.sprite = sprite;
|
_image.sprite = sprite;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ namespace GadGame
|
||||||
public BoolEvent PlayPassByAnim;
|
public BoolEvent PlayPassByAnim;
|
||||||
public BoolEvent PlayVideo;
|
public BoolEvent PlayVideo;
|
||||||
public FloatEvent ReadyCountDown;
|
public FloatEvent ReadyCountDown;
|
||||||
public VoidEvent GenerateImageSuccess;
|
public StringEvent GenerateImageSuccess;
|
||||||
public StringEvent EncodeImage;
|
public StringEvent EncodeImage;
|
||||||
|
|
||||||
protected override async void Awake()
|
protected override async void Awake()
|
||||||
|
|
|
@ -22,5 +22,6 @@ namespace GadGame.Network
|
||||||
public Vector2[] PosPoints;
|
public Vector2[] PosPoints;
|
||||||
public string StreamingData;
|
public string StreamingData;
|
||||||
[FormerlySerializedAs("Success")] public bool GenerateImageSuccess;
|
[FormerlySerializedAs("Success")] public bool GenerateImageSuccess;
|
||||||
|
public string Description;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -6,6 +6,9 @@ using System.Threading;
|
||||||
using GadGame.Singleton;
|
using GadGame.Singleton;
|
||||||
using Newtonsoft.Json;
|
using Newtonsoft.Json;
|
||||||
using UnityEngine;
|
using UnityEngine;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using Debug = UnityEngine.Debug;
|
||||||
|
using Unity.VisualScripting;
|
||||||
|
|
||||||
namespace GadGame.Network
|
namespace GadGame.Network
|
||||||
{
|
{
|
||||||
|
@ -23,9 +26,28 @@ namespace GadGame.Network
|
||||||
private UdpClient _client;
|
private UdpClient _client;
|
||||||
private IPEndPoint _remoteEndPoint;
|
private IPEndPoint _remoteEndPoint;
|
||||||
|
|
||||||
|
private Process _process;
|
||||||
private void Start()
|
private void Start()
|
||||||
{
|
{
|
||||||
|
_process = new Process();
|
||||||
|
_process.StartInfo.FileName = "/bin/sh";
|
||||||
|
_process.StartInfo.Arguments = $"{Application.streamingAssetsPath}/MergeFace/run.sh";
|
||||||
|
_process.StartInfo.WorkingDirectory = $"{Application.streamingAssetsPath}/MergeFace";
|
||||||
|
|
||||||
|
_process.StartInfo.RedirectStandardOutput = true;
|
||||||
|
_process.StartInfo.RedirectStandardError = true;
|
||||||
|
|
||||||
|
_process.StartInfo.CreateNoWindow = false;
|
||||||
|
_process.StartInfo.UseShellExecute = false;
|
||||||
|
_process.OutputDataReceived += (sender, a) => {
|
||||||
|
Debug.Log(a.Data);
|
||||||
|
};
|
||||||
|
_process.ErrorDataReceived += (sender, a) => {
|
||||||
|
Debug.LogError(a.Data);
|
||||||
|
};
|
||||||
|
_process.Start();
|
||||||
|
|
||||||
|
|
||||||
// Create remote endpoint
|
// Create remote endpoint
|
||||||
_remoteEndPoint = new IPEndPoint(IPAddress.Parse(_ip), _sendPort);
|
_remoteEndPoint = new IPEndPoint(IPAddress.Parse(_ip), _sendPort);
|
||||||
|
|
||||||
|
@ -87,6 +109,10 @@ namespace GadGame.Network
|
||||||
_receiveThread.Abort();
|
_receiveThread.Abort();
|
||||||
|
|
||||||
_client.Close();
|
_client.Close();
|
||||||
|
|
||||||
|
_process.Close();
|
||||||
|
_process.CloseMainWindow();
|
||||||
|
_process.WaitForExit();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -20,6 +20,7 @@ namespace GadGame.State.MainFlowState
|
||||||
|
|
||||||
public override void Update(float time)
|
public override void Update(float time)
|
||||||
{
|
{
|
||||||
|
Runner.EncodeImage.Raise(UdpSocket.Instance.DataReceived.StreamingData);
|
||||||
if(!UdpSocket.Instance.DataReceived.PassBy) {
|
if(!UdpSocket.Instance.DataReceived.PassBy) {
|
||||||
Runner.SetState<IdleState>();
|
Runner.SetState<IdleState>();
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -6,7 +6,7 @@ namespace GadGame.State.MainFlowState
|
||||||
{
|
{
|
||||||
public override void Enter()
|
public override void Enter()
|
||||||
{
|
{
|
||||||
Runner.GenerateImageSuccess.Raise();
|
Runner.GenerateImageSuccess.Raise(UdpSocket.Instance.DataReceived.Description);
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void Update(float time)
|
public override void Update(float time)
|
||||||
|
|
After Width: | Height: | Size: 394 KiB |
|
@ -0,0 +1,140 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 72c164f587bf6158ebb542742fefd4d9
|
||||||
|
TextureImporter:
|
||||||
|
internalIDToNameTable: []
|
||||||
|
externalObjects: {}
|
||||||
|
serializedVersion: 12
|
||||||
|
mipmaps:
|
||||||
|
mipMapMode: 0
|
||||||
|
enableMipMap: 0
|
||||||
|
sRGBTexture: 1
|
||||||
|
linearTexture: 0
|
||||||
|
fadeOut: 0
|
||||||
|
borderMipMap: 0
|
||||||
|
mipMapsPreserveCoverage: 0
|
||||||
|
alphaTestReferenceValue: 0.5
|
||||||
|
mipMapFadeDistanceStart: 1
|
||||||
|
mipMapFadeDistanceEnd: 3
|
||||||
|
bumpmap:
|
||||||
|
convertToNormalMap: 0
|
||||||
|
externalNormalMap: 0
|
||||||
|
heightScale: 0.25
|
||||||
|
normalMapFilter: 0
|
||||||
|
flipGreenChannel: 0
|
||||||
|
isReadable: 0
|
||||||
|
streamingMipmaps: 0
|
||||||
|
streamingMipmapsPriority: 0
|
||||||
|
vTOnly: 0
|
||||||
|
ignoreMipmapLimit: 0
|
||||||
|
grayScaleToAlpha: 0
|
||||||
|
generateCubemap: 6
|
||||||
|
cubemapConvolution: 0
|
||||||
|
seamlessCubemap: 0
|
||||||
|
textureFormat: 1
|
||||||
|
maxTextureSize: 2048
|
||||||
|
textureSettings:
|
||||||
|
serializedVersion: 2
|
||||||
|
filterMode: 1
|
||||||
|
aniso: 1
|
||||||
|
mipBias: 0
|
||||||
|
wrapU: 1
|
||||||
|
wrapV: 1
|
||||||
|
wrapW: 0
|
||||||
|
nPOTScale: 0
|
||||||
|
lightmap: 0
|
||||||
|
compressionQuality: 50
|
||||||
|
spriteMode: 1
|
||||||
|
spriteExtrude: 1
|
||||||
|
spriteMeshType: 1
|
||||||
|
alignment: 0
|
||||||
|
spritePivot: {x: 0.5, y: 0.5}
|
||||||
|
spritePixelsToUnits: 100
|
||||||
|
spriteBorder: {x: 0, y: 0, z: 0, w: 0}
|
||||||
|
spriteGenerateFallbackPhysicsShape: 1
|
||||||
|
alphaUsage: 1
|
||||||
|
alphaIsTransparency: 1
|
||||||
|
spriteTessellationDetail: -1
|
||||||
|
textureType: 8
|
||||||
|
textureShape: 1
|
||||||
|
singleChannelComponent: 0
|
||||||
|
flipbookRows: 1
|
||||||
|
flipbookColumns: 1
|
||||||
|
maxTextureSizeSet: 0
|
||||||
|
compressionQualitySet: 0
|
||||||
|
textureFormatSet: 0
|
||||||
|
ignorePngGamma: 0
|
||||||
|
applyGammaDecoding: 0
|
||||||
|
swizzle: 50462976
|
||||||
|
cookieLightType: 0
|
||||||
|
platformSettings:
|
||||||
|
- serializedVersion: 3
|
||||||
|
buildTarget: DefaultTexturePlatform
|
||||||
|
maxTextureSize: 2048
|
||||||
|
resizeAlgorithm: 0
|
||||||
|
textureFormat: -1
|
||||||
|
textureCompression: 1
|
||||||
|
compressionQuality: 50
|
||||||
|
crunchedCompression: 0
|
||||||
|
allowsAlphaSplitting: 0
|
||||||
|
overridden: 0
|
||||||
|
ignorePlatformSupport: 0
|
||||||
|
androidETC2FallbackOverride: 0
|
||||||
|
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||||
|
- serializedVersion: 3
|
||||||
|
buildTarget: Standalone
|
||||||
|
maxTextureSize: 2048
|
||||||
|
resizeAlgorithm: 0
|
||||||
|
textureFormat: -1
|
||||||
|
textureCompression: 1
|
||||||
|
compressionQuality: 50
|
||||||
|
crunchedCompression: 0
|
||||||
|
allowsAlphaSplitting: 0
|
||||||
|
overridden: 0
|
||||||
|
ignorePlatformSupport: 0
|
||||||
|
androidETC2FallbackOverride: 0
|
||||||
|
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||||
|
- serializedVersion: 3
|
||||||
|
buildTarget: Android
|
||||||
|
maxTextureSize: 2048
|
||||||
|
resizeAlgorithm: 0
|
||||||
|
textureFormat: -1
|
||||||
|
textureCompression: 1
|
||||||
|
compressionQuality: 50
|
||||||
|
crunchedCompression: 0
|
||||||
|
allowsAlphaSplitting: 0
|
||||||
|
overridden: 0
|
||||||
|
ignorePlatformSupport: 0
|
||||||
|
androidETC2FallbackOverride: 0
|
||||||
|
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||||
|
- serializedVersion: 3
|
||||||
|
buildTarget: Server
|
||||||
|
maxTextureSize: 2048
|
||||||
|
resizeAlgorithm: 0
|
||||||
|
textureFormat: -1
|
||||||
|
textureCompression: 1
|
||||||
|
compressionQuality: 50
|
||||||
|
crunchedCompression: 0
|
||||||
|
allowsAlphaSplitting: 0
|
||||||
|
overridden: 0
|
||||||
|
ignorePlatformSupport: 0
|
||||||
|
androidETC2FallbackOverride: 0
|
||||||
|
forceMaximumCompressionQuality_BC6H_BC7: 0
|
||||||
|
spriteSheet:
|
||||||
|
serializedVersion: 2
|
||||||
|
sprites: []
|
||||||
|
outline: []
|
||||||
|
physicsShape: []
|
||||||
|
bones: []
|
||||||
|
spriteID: 5e97eb03825dee720800000000000000
|
||||||
|
internalID: 0
|
||||||
|
vertices: []
|
||||||
|
indices:
|
||||||
|
edges: []
|
||||||
|
weights: []
|
||||||
|
secondaryTextures: []
|
||||||
|
nameFileIdTable: {}
|
||||||
|
mipmapLimitGroupName:
|
||||||
|
pSDRemoveMatte: 0
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 560a88da2bbc70140bed167f0ba7fe37
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: fb01be13d6e88ca488dda82150319bfc
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 117dcc671050f5247bd8743b91ecaab7
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 9db2f9566d5996ffbac9835756d8f951
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 88c2d151c368129198308dbd30d9d750
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,126 @@
|
||||||
|
from Facenet.models.mtcnn import MTCNN
|
||||||
|
from AgeNet.models import Model
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms as T
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
age_range_list = ["1-10", "11-20", "21-30", "31-40", "41-50", "51-60", "61-70", "71-80", "81-90", "91-100"]
|
||||||
|
|
||||||
|
|
||||||
|
class AgeEstimator():
|
||||||
|
def __init__(self, face_size=224, weights=None, device='cpu', tpx=500):
|
||||||
|
self.thickness_per_pixels = tpx
|
||||||
|
|
||||||
|
if isinstance(face_size, int):
|
||||||
|
self.face_size = (face_size, face_size)
|
||||||
|
else:
|
||||||
|
self.face_size = face_size
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
self.device = device
|
||||||
|
if isinstance(device, str):
|
||||||
|
if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
|
||||||
|
self.device = torch.device(device)
|
||||||
|
else:
|
||||||
|
self.device = torch.device('cpu')
|
||||||
|
|
||||||
|
self.facenet_model = MTCNN(device=self.device)
|
||||||
|
|
||||||
|
self.model = Model().to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
if weights:
|
||||||
|
self.model.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))
|
||||||
|
# print('Weights loaded successfully from path:', weights)
|
||||||
|
# print('====================================================')
|
||||||
|
|
||||||
|
def transform(self, image):
|
||||||
|
return T.Compose(
|
||||||
|
[
|
||||||
|
T.Resize(self.face_size),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
]
|
||||||
|
)(image)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def preprocess_image(image, face_size=(224, 224)):
|
||||||
|
|
||||||
|
# Resize the image to the target size
|
||||||
|
resized_transform = T.Resize(face_size)
|
||||||
|
|
||||||
|
resized_image = resized_transform(image)
|
||||||
|
|
||||||
|
return resized_image
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def padding_face(box, padding=20):
|
||||||
|
return [
|
||||||
|
box[0] - padding,
|
||||||
|
box[1] - padding,
|
||||||
|
box[2] + padding,
|
||||||
|
box[3] + padding
|
||||||
|
]
|
||||||
|
|
||||||
|
def predict_from_frame(self, frame, min_prob=0.9):
|
||||||
|
image = Image.fromarray(frame)
|
||||||
|
|
||||||
|
ndarray_image = np.array(image)
|
||||||
|
image_shape = ndarray_image.shape
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
bboxes, prob = self.facenet_model.detect(image)
|
||||||
|
bboxes = bboxes[prob > min_prob]
|
||||||
|
|
||||||
|
face_images = []
|
||||||
|
|
||||||
|
for box in bboxes:
|
||||||
|
box = np.clip(box, 0, np.inf).astype(np.uint32)
|
||||||
|
|
||||||
|
padding = max(image_shape) * 5 / self.thickness_per_pixels
|
||||||
|
padding = int(max(padding, 10))
|
||||||
|
box = self.padding_face(box, padding)
|
||||||
|
|
||||||
|
face = image.crop(box)
|
||||||
|
transformed_face = self.transform(face)
|
||||||
|
face_images.append(transformed_face)
|
||||||
|
|
||||||
|
face_images = torch.stack(face_images, dim=0).to(self.device)
|
||||||
|
|
||||||
|
genders, ages = self.model(face_images)
|
||||||
|
ages = torch.round(ages).long()
|
||||||
|
|
||||||
|
for i, box in enumerate(bboxes):
|
||||||
|
box = np.clip(box, 0, np.inf).astype(np.uint32)
|
||||||
|
thickness = max(image_shape) / 400
|
||||||
|
thickness = int(max(np.ceil(thickness), 1))
|
||||||
|
age_range = age_range_list[int(ages[i].item() / 10)]
|
||||||
|
gender = round(genders[i].item(), 2)
|
||||||
|
|
||||||
|
return [gender, age_range]
|
||||||
|
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def Prediction(frame, weights="weights.pt", face_size=64, device='cuda'):
|
||||||
|
# Initialize the AgeEstimator
|
||||||
|
model = AgeEstimator(weights=weights, face_size=face_size, device=device)
|
||||||
|
|
||||||
|
# Open a connection to the camera (use 0 for default camera)
|
||||||
|
|
||||||
|
# Convert the frame to RGB format (OpenCV uses BGR by default)
|
||||||
|
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# Predict age and gender from the frame
|
||||||
|
predicted_image = model.predict_from_frame(rgb_frame)
|
||||||
|
|
||||||
|
return predicted_image
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: dee94761336911fe091215ccd6f1527c
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 29274157d02d7d616ac2fc3558f8efcb
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: a6af8aa688a0742cab6a5c8e00e9f00a
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 1cb595d88a310b7729f9182bcd8e2ae6
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,65 @@
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms as T
|
||||||
|
|
||||||
|
class UTKFaceDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, transform = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.filename_list = os.listdir(root_dir)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.filename_list)
|
||||||
|
|
||||||
|
def age_to_class(self, age):
|
||||||
|
range_ages = [0, 4, 9, 15, 25, 35, 45, 60, 75]
|
||||||
|
if age > max(range_ages):
|
||||||
|
return len(range_ages) - 1
|
||||||
|
for i in range(len(range_ages) - 1):
|
||||||
|
if range_ages[i] <= age <= range_ages[i + 1]:
|
||||||
|
return i
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
filename = self.filename_list[idx]
|
||||||
|
|
||||||
|
info = filename.split("_")
|
||||||
|
age = int(info[0])
|
||||||
|
age_label = self.age_to_class(age)
|
||||||
|
gender = int(info[1])
|
||||||
|
|
||||||
|
filename = os.path.join(self.root_dir, filename)
|
||||||
|
|
||||||
|
image = Image.open(filename)
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
image = self.transform(image)
|
||||||
|
|
||||||
|
return image, gender, age_label, age
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
image_size = (64,64)
|
||||||
|
root_dir = '/kaggle/input/utkface-new/UTKFace'
|
||||||
|
batch_size = 512
|
||||||
|
num_workers = os.cpu_count()
|
||||||
|
|
||||||
|
train_transform = T.Compose(
|
||||||
|
[
|
||||||
|
T.Resize(image_size),
|
||||||
|
T.RandomHorizontalFlip(0.5),
|
||||||
|
T.RandomRotation(10),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = UTKFaceDataset(root_dir, transform = train_transform)
|
||||||
|
trainloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True)
|
||||||
|
|
||||||
|
for images, genders, age_labels, ages in trainloader:
|
||||||
|
print(images.shape, genders.shape, ages.shape)
|
||||||
|
break
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: a9132692d4a48bed7872f706c57c2e91
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,66 @@
|
||||||
|
from torch import nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def evaluate(gender_model, age_range_model, age_estimation_model, val_dataloader, device = 'cpu', verbose = 0):
|
||||||
|
# set device
|
||||||
|
if isinstance(device, str):
|
||||||
|
if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
|
||||||
|
device = torch.device(device)
|
||||||
|
else:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
|
||||||
|
# Loss function
|
||||||
|
AgeRangeLoss = nn.CrossEntropyLoss()
|
||||||
|
GenderLoss = nn.BCELoss()
|
||||||
|
AgeEstimationLoss = nn.L1Loss()
|
||||||
|
|
||||||
|
gender_model = gender_model.to(device)
|
||||||
|
age_range_model = age_range_model.to(device)
|
||||||
|
age_estimation_model = age_estimation_model.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
age_range_model.eval()
|
||||||
|
gender_model.eval()
|
||||||
|
age_estimation_model.eval()
|
||||||
|
age_accuracy = 0
|
||||||
|
gender_accuracy = 0
|
||||||
|
total_age_loss = 0
|
||||||
|
total_gender_loss = 0
|
||||||
|
total_age_estimation_loss = 0
|
||||||
|
if verbose == 1:
|
||||||
|
val_dataloader = tqdm(val_dataloader, desc = 'Evaluate: ', ncols = 100)
|
||||||
|
|
||||||
|
for images, genders, age_labels, ages in val_dataloader:
|
||||||
|
batch_size = images.shape[0]
|
||||||
|
|
||||||
|
images, genders, age_labels, ages = images.to(device), genders.to(device), age_labels.to(device), ages.to(device)
|
||||||
|
|
||||||
|
pred_genders = gender_model(images).view(-1)
|
||||||
|
pred_age_labels = age_range_model(images)
|
||||||
|
|
||||||
|
age_loss = AgeRangeLoss(pred_age_labels, age_labels.long())
|
||||||
|
gender_loss = GenderLoss(pred_genders, genders.float())
|
||||||
|
|
||||||
|
total_age_loss += age_loss.item()
|
||||||
|
total_gender_loss += gender_loss.item()
|
||||||
|
|
||||||
|
gender_acc = torch.sum(torch.round(pred_genders) == genders)/batch_size
|
||||||
|
age_acc = torch.sum(torch.argmax(pred_age_labels, dim = 1) == age_labels)/batch_size
|
||||||
|
|
||||||
|
age_accuracy += age_acc
|
||||||
|
gender_accuracy += gender_acc
|
||||||
|
|
||||||
|
estimated_ages = age_estimation_model(images, age_labels).view(-1)
|
||||||
|
age_estimation_loss = AgeEstimationLoss(ages, estimated_ages)
|
||||||
|
|
||||||
|
total_age_estimation_loss += age_estimation_loss.item()
|
||||||
|
|
||||||
|
val_age_loss = total_age_loss / len(val_dataloader)
|
||||||
|
val_gender_loss = total_gender_loss / len(val_dataloader)
|
||||||
|
val_age_accuracy = age_accuracy / len(val_dataloader)
|
||||||
|
val_gender_accuracy = gender_accuracy / len(val_dataloader)
|
||||||
|
val_age_estimation_loss = total_age_estimation_loss / len(val_dataloader)
|
||||||
|
|
||||||
|
return val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: b1bc431225702cd32bc5814aa87bfcba
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,196 @@
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
class GenderClassificationModel(nn.Module):
|
||||||
|
"VGG-Face"
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv_1_1 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
|
||||||
|
self.conv_1_2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
|
||||||
|
self.conv_2_1 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
|
||||||
|
self.conv_2_2 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
|
||||||
|
self.conv_3_1 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
|
||||||
|
self.conv_3_2 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
|
||||||
|
self.conv_3_3 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
|
||||||
|
self.conv_4_1 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
|
||||||
|
self.conv_4_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
|
||||||
|
self.conv_4_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
|
||||||
|
self.conv_5_1 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
|
||||||
|
self.conv_5_2 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
|
||||||
|
self.conv_5_3 = nn.Conv2d(512, 512, 3, stride=1, padding=1)
|
||||||
|
self.fc6 = nn.Linear(2048, 4096)
|
||||||
|
self.fc7 = nn.Linear(4096, 4096)
|
||||||
|
self.fc8 = nn.Linear(4096, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
""" Pytorch forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: input image (224x224)
|
||||||
|
|
||||||
|
Returns: class logits
|
||||||
|
|
||||||
|
"""
|
||||||
|
x = F.relu(self.conv_1_1(x))
|
||||||
|
x = F.relu(self.conv_1_2(x))
|
||||||
|
x = F.max_pool2d(x, 2, 2)
|
||||||
|
x = F.relu(self.conv_2_1(x))
|
||||||
|
x = F.relu(self.conv_2_2(x))
|
||||||
|
x = F.max_pool2d(x, 2, 2)
|
||||||
|
x = F.relu(self.conv_3_1(x))
|
||||||
|
x = F.relu(self.conv_3_2(x))
|
||||||
|
x = F.relu(self.conv_3_3(x))
|
||||||
|
x = F.max_pool2d(x, 2, 2)
|
||||||
|
x = F.relu(self.conv_4_1(x))
|
||||||
|
x = F.relu(self.conv_4_2(x))
|
||||||
|
x = F.relu(self.conv_4_3(x))
|
||||||
|
x = F.max_pool2d(x, 2, 2)
|
||||||
|
x = F.relu(self.conv_5_1(x))
|
||||||
|
x = F.relu(self.conv_5_2(x))
|
||||||
|
x = F.relu(self.conv_5_3(x))
|
||||||
|
x = F.max_pool2d(x, 2, 2)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = F.relu(self.fc6(x))
|
||||||
|
x = F.dropout(x, 0.5, self.training)
|
||||||
|
x = F.relu(self.fc7(x))
|
||||||
|
x = F.dropout(x, 0.5, self.training)
|
||||||
|
return F.sigmoid(self.fc8(x))
|
||||||
|
|
||||||
|
class AgeRangeModel(nn.Module):
|
||||||
|
def __init__(self, in_channels = 3, backbone = 'resnet50', pretrained = False, num_classes = 9):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.Conv1 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, 64, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Conv2 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 256, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Conv3 = nn.Sequential(
|
||||||
|
nn.Conv2d(256, 512, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.adap = nn.AdaptiveAvgPool2d((2,2))
|
||||||
|
|
||||||
|
self.out_age = nn.Sequential(
|
||||||
|
nn.Linear(2048, num_classes)
|
||||||
|
# nn.Softmax(dim = 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
x = self.Conv1(x)
|
||||||
|
x = self.Conv2(x)
|
||||||
|
x = self.Conv3(x)
|
||||||
|
|
||||||
|
x = self.adap(x)
|
||||||
|
|
||||||
|
x = x.view(batch_size, -1)
|
||||||
|
|
||||||
|
x = self.out_age(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class AgeEstimationModel(nn.Module):
|
||||||
|
"VGG-Face"
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_layer = nn.Embedding(9, 64)
|
||||||
|
|
||||||
|
self.Conv1 = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 64, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Conv2 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 256, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Conv3 = nn.Sequential(
|
||||||
|
nn.Conv2d(256, 512, 3, 1, padding = 1),
|
||||||
|
nn.BatchNorm2d(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout2d(0.3),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.adap = nn.AdaptiveAvgPool2d((2,2))
|
||||||
|
|
||||||
|
self.out_age = nn.Sequential(
|
||||||
|
nn.Linear(2048 + 64, 1),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
x = self.Conv1(x)
|
||||||
|
x = self.Conv2(x)
|
||||||
|
x = self.Conv3(x)
|
||||||
|
|
||||||
|
x = self.adap(x)
|
||||||
|
|
||||||
|
x = x.view(batch_size, -1)
|
||||||
|
|
||||||
|
y = self.embedding_layer(y)
|
||||||
|
|
||||||
|
x = torch.cat([x,y], dim = 1)
|
||||||
|
|
||||||
|
x = self.out_age(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
self.gender_model = GenderClassificationModel()
|
||||||
|
|
||||||
|
self.age_range_model = AgeRangeModel()
|
||||||
|
|
||||||
|
self.age_estimation_model = AgeEstimationModel()
|
||||||
|
|
||||||
|
def forward(self,x):
|
||||||
|
"""x: batch, 3, 64, 64"""
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[None, ...]
|
||||||
|
|
||||||
|
predicted_genders = self.gender_model(x)
|
||||||
|
|
||||||
|
age_ranges = self.age_range_model(x)
|
||||||
|
|
||||||
|
y = torch.argmax(age_ranges, dim = 1).view(-1,)
|
||||||
|
|
||||||
|
estimated_ages = self.age_estimation_model(x, y)
|
||||||
|
|
||||||
|
return predicted_genders, estimated_ages
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = Model()
|
||||||
|
x = torch.rand(2,3,64,64)
|
||||||
|
genders, ages = model(x) #
|
||||||
|
|
||||||
|
print(genders.shape, ages.shape)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: bfdd42c2ab7b1f10ab08c676ab819c1b
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,188 @@
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch import nn
|
||||||
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
|
||||||
|
from .utils import *
|
||||||
|
from .eval import *
|
||||||
|
from .models import *
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def train(train_data_dir, weights, device = 'cpu', image_size = 64, batch_size = 128, num_epochs = 100, steps_per_epoch = None,
          val_data_dir = None, validation_split = None, save_history = True):
    """Train the gender, age-range and age-estimation models on a UTKFace-style directory.

    Args:
        train_data_dir: root directory of the training images.
        weights: path the combined state_dict is saved to; falsy skips saving.
        device: 'cpu', 'cuda'/'gpu' (falls back to CPU when CUDA is unavailable),
            or an already-constructed torch.device.
        image_size, batch_size: forwarded to the dataloader.
        num_epochs: number of training epochs.
        steps_per_epoch: batches per epoch; defaults to one full pass over the data.
        val_data_dir: optional separate validation directory.
        validation_split: optional fraction of the training data held out for
            validation. Mutually exclusive with val_data_dir.
        save_history: when truthy, training curves are plotted/saved via
            visualize_history.
    """
    # val_data_dir and validation_split must not both be given at the same time
    assert not (val_data_dir is not None and validation_split is not None)

    if isinstance(device, str):
        if (device == 'cuda' or device == 'gpu') and torch.cuda.is_available():
            # BUG FIX: torch.device('gpu') is not a valid device string; both
            # accepted aliases must resolve to 'cuda'.
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

    # dataloader over the training directory
    train_data = get_dataloader(train_data_dir, image_size = image_size, batch_size = batch_size)

    # build the validation set: a separate directory, a split of train, or none
    if val_data_dir is not None:
        val_data = get_dataloader(val_data_dir, image_size = image_size, batch_size = batch_size)
    elif validation_split is not None:
        train_data, val_data = split_dataloader(train_data, validation_split)
    else:
        val_data = None

    if steps_per_epoch is None:
        steps_per_epoch = len(train_data)

    # manual iterator so an epoch may be shorter/longer than one dataset pass
    num_steps = len(train_data)
    iterator = iter(train_data)
    count_steps = 1

    # move models to the target device
    gender_model = GenderClassificationModel().to(device)
    age_range_model = AgeRangeModel().to(device)
    age_estimation_model = AgeEstimationModel().to(device)

    # Loss functions
    AgeRangeLoss = nn.CrossEntropyLoss()
    GenderLoss = nn.BCELoss()
    AgeEstimationLoss = nn.L1Loss()

    # Optimizers (one per sub-model, with task-specific learning rates)
    gen_optimizer = torch.optim.Adam(gender_model.parameters(), lr = 1e-4)
    age_range_optimizer = torch.optim.Adam(age_range_model.parameters(), lr = 5e-3)
    age_estimation_optimizer = torch.optim.Adam(age_estimation_model.parameters(), lr = 1e-3)

    # Schedulers: reduce LR when the corresponding validation loss plateaus
    age_range_scheduler = ReduceLROnPlateau(age_range_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)
    gender_scheduler = ReduceLROnPlateau(gen_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)
    age_estimation_scheduler = ReduceLROnPlateau(age_estimation_optimizer, mode = 'min', factor = 0.1, patience = 3, verbose = 1)

    # per-epoch metric history (consumed by visualize_history)
    history = {
        'train_gender_loss': [],
        'train_age_loss': [],
        'train_gender_acc': [],
        'train_age_acc': [],
        'val_gender_loss': [],
        'val_age_loss': [],
        'val_gender_acc': [],
        'val_age_acc': [],
    }

    for epoch in range(1, num_epochs + 1):
        total_gender_loss = 0
        total_age_loss = 0
        age_accuracy = 0
        gender_accuracy = 0

        total_age_estimation_loss = 0

        gender_model.train()
        age_range_model.train()
        age_estimation_model.train()

        for step in tqdm(range(steps_per_epoch), desc = f'Epoch {epoch}/{num_epochs}: ', ncols = 100):
            images, genders, age_labels, ages = next(iterator)
            batch_size = images.shape[0]

            images, genders, age_labels, ages = images.to(device), genders.to(device), age_labels.to(device), ages.to(device)

            pred_genders = gender_model(images).view(-1)
            pred_age_labels = age_range_model(images)

            age_loss = AgeRangeLoss(pred_age_labels, age_labels.long())

            gender_loss = GenderLoss(pred_genders, genders.float())

            total_age_loss += age_loss.item()
            total_gender_loss += gender_loss.item()

            gender_acc = torch.sum(torch.round(pred_genders) == genders)/batch_size
            age_acc = torch.sum(torch.argmax(pred_age_labels, dim = 1) == age_labels)/batch_size

            age_accuracy += age_acc
            gender_accuracy += gender_acc

            age_range_optimizer.zero_grad()
            age_loss.backward()
            age_range_optimizer.step()

            gen_optimizer.zero_grad()
            gender_loss.backward()
            gen_optimizer.step()

            # age estimation, conditioned on the ground-truth age bucket.
            # BUG FIX: nn losses take (input, target); the original passed the
            # prediction as the target (AgeEstimationLoss(ages, estimated_ages)).
            estimated_ages = age_estimation_model(images, age_labels).view(-1)
            age_estimation_loss = AgeEstimationLoss(estimated_ages, ages)

            age_estimation_optimizer.zero_grad()
            age_estimation_loss.backward()
            age_estimation_optimizer.step()

            total_age_estimation_loss += age_estimation_loss.item()

            # restart the iterator once the whole dataset has been consumed
            if count_steps == num_steps:
                iterator = iter(train_data)
                count_steps = 0
            count_steps += 1

        train_age_loss = total_age_loss / steps_per_epoch
        train_gender_loss = total_gender_loss / steps_per_epoch
        train_age_accuracy = age_accuracy / steps_per_epoch
        train_gender_accuracy = gender_accuracy / steps_per_epoch

        train_age_estimation_loss = total_age_estimation_loss/steps_per_epoch

        history['train_age_loss'].append(float(train_age_loss))
        history['train_gender_loss'].append(float(train_gender_loss))
        history['train_age_acc'].append(float(train_age_accuracy))
        history['train_gender_acc'].append(float(train_gender_accuracy))

        print(f'train_age_loss: {train_age_loss: .2f}, train_gender_loss: {train_gender_loss: .3f}, train_age_accuracy: {train_age_accuracy: .2f}, train_gender_accuracy: {train_gender_accuracy: .2f}, train_age_estimation_loss: {train_age_estimation_loss: .3f}')
        if val_data:
            val_age_loss, val_gender_loss, val_age_accuracy, val_gender_accuracy, val_age_estimation_loss = evaluate(gender_model, age_range_model, age_estimation_model, val_data, device = device)
            history['val_age_loss'].append(float(val_age_loss))
            history['val_gender_loss'].append(float(val_gender_loss))
            history['val_age_acc'].append(float(val_age_accuracy))
            history['val_gender_acc'].append(float(val_gender_accuracy))

            age_range_scheduler.step(np.round(val_age_loss, 3))
            gender_scheduler.step(np.round(val_gender_loss, 3))
            age_estimation_scheduler.step(np.round(val_age_estimation_loss, 3))
            print(f'val_age_loss: {val_age_loss: .2f}, val_gender_loss: {val_gender_loss: .3f}, val_age_accuracy: {val_age_accuracy: .2f}, val_gender_accuracy: {val_gender_accuracy: .2f}, val_age_estimation_loss: {val_age_estimation_loss : .3f}')

    if weights:
        # Wrap the three sub-models in one module so a single state_dict
        # (with gender_model./age_range_model./age_estimation_model. prefixes)
        # can be saved and later reloaded into the combined Model.
        class dummy_model(nn.Module):
            def __init__(self):
                super().__init__()

                self.gender_model = gender_model
                self.age_range_model = age_range_model
                self.age_estimation_model = age_estimation_model

            def forward(self, x):
                return

        model = dummy_model()
        torch.save(model.state_dict(), weights)
        print(f'Saved successfully last weights to:', weights)

    if save_history:
        visualize_history(history, save_history)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    batch_size = 128
    image_size = 64
    train_data_dir = '/kaggle/input/utkface-new/UTKFace'
    device = 'cuda'
    # BUG FIX: the original hard-coded 'weights\AgeGenderWeights.pt' with a
    # literal backslash, which is not a path separator on the Linux/Kaggle
    # target; build the path portably instead.
    weights = os.path.join('weights', 'AgeGenderWeights.pt')
    epochs = 100
    # BUG FIX: `epochs` was defined but never passed to train(); forward it.
    train(train_data_dir, weights, device = device, num_epochs = epochs, steps_per_epoch = None,
          validation_split = 0.2, image_size = image_size, batch_size = batch_size, save_history = True)
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: d70c80bb1ccf8c73297c6a434b5035dc
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,93 @@
|
||||||
|
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
|
||||||
|
from .dataset import UTKFaceDataset
|
||||||
|
from torchvision import transforms as T
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_dataloader(root_dir, image_size = 64, batch_size = 128, shuffle = True, num_workers = 1):
    """Build a training DataLoader over a UTKFace-style image directory.

    Args:
        root_dir: directory containing the UTKFace images.
        image_size: int or (H, W) target size; an int is expanded to a square.
        batch_size, shuffle, num_workers: forwarded to the DataLoader.

    Returns:
        A DataLoader yielding augmented, normalized image batches
        (incomplete final batches are dropped).
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    # Light augmentation plus normalization to roughly [-1, 1].
    augmentations = T.Compose([
        T.Resize(image_size),
        T.RandomHorizontalFlip(0.2),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
    ])

    dataset = UTKFaceDataset(root_dir, transform = augmentations)
    return DataLoader(dataset, batch_size = batch_size, shuffle = shuffle, num_workers = num_workers, drop_last = True)
|
||||||
|
|
||||||
|
def split_dataloader(train_data, validation_split = 0.2):
    """Split a DataLoader into train and validation DataLoaders.

    The underlying dataset is partitioned deterministically: the first
    (1 - validation_split) fraction of indices goes to the train loader,
    the remainder to the validation loader. Each loader shuffles within
    its own subset via a SubsetRandomSampler.

    Args:
        train_data: the original DataLoader (its dataset, batch_size and
            num_workers are reused).
        validation_split: fraction of samples held out for validation.

    Returns:
        (train_loader, val_loader) tuple of new DataLoaders (drop_last=True).
    """
    # Size of the training portion.
    dataset = train_data.dataset
    total = len(dataset)
    train_size = int((1 - validation_split) * total)

    # Head of the index list -> train, tail -> validation.
    all_indices = list(range(total))
    train_sampler = SubsetRandomSampler(all_indices[:train_size])
    val_sampler = SubsetRandomSampler(all_indices[train_size:])

    # Reuse the original loader's configuration.
    batch_size = train_data.batch_size
    num_workers = train_data.num_workers

    train_loader = DataLoader(dataset, batch_size = batch_size, sampler = train_sampler, num_workers = num_workers, drop_last = True)
    val_loader = DataLoader(dataset, batch_size = batch_size, sampler = val_sampler, num_workers = num_workers, drop_last = True)

    return train_loader, val_loader
|
||||||
|
|
||||||
|
def visualize_history(history, save_history=True):
    """Plot accuracy and loss curves from a training history dict.

    Draws three subplots (accuracies, age loss, gender loss) from the lists
    stored in `history`. When `save_history` is truthy, the figure is written
    to runs/train/expN/results.png, where N is one past the highest existing
    experiment number.

    Args:
        history: dict of per-epoch metric lists as produced by train().
        save_history: when truthy, save the figure to a fresh experiment dir.
    """
    plt.figure(figsize = (20,8))
    plt.subplot(131)
    plt.plot(range(1, len(history['train_age_acc']) + 1), history['train_age_acc'], label = 'train_age_acc', c = 'r')
    plt.plot(range(1, len(history['val_age_acc']) + 1), history['val_age_acc'], label = 'val_age_acc', c = 'g')
    plt.plot(range(1, len(history['train_gender_acc']) + 1), history['train_gender_acc'], label = 'train_gender_acc', c = 'b')
    plt.plot(range(1, len(history['val_gender_acc']) + 1), history['val_gender_acc'], label = 'val_gender_acc', c = 'y')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Gender And Age Prediction Accuracy')
    plt.legend()

    plt.subplot(132)
    plt.plot(range(1, len(history['train_age_loss']) + 1), history['train_age_loss'], label = 'train_age_loss', c = 'r')
    plt.plot(range(1, len(history['val_age_loss']) + 1), history['val_age_loss'], label = 'val_age_loss', c = 'g')
    plt.title('Age Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Age Loss')
    plt.legend()

    plt.subplot(133)
    plt.plot(range(1, len(history['train_gender_loss']) + 1), history['train_gender_loss'], label = 'train_gender_loss', c = 'b')
    plt.plot(range(1, len(history['val_gender_loss']) + 1), history['val_gender_loss'], label = 'val_gender_loss', c = 'y')
    plt.title('Gender Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Gender Loss')
    plt.legend()

    if save_history:
        train_dir = os.path.join('runs', 'train')
        os.makedirs(train_dir, exist_ok=True)

        # BUG FIX: the original took exp_list[-1] from os.listdir(), whose
        # order is arbitrary, so the "next" experiment number could collide
        # with (and overwrite) an existing one; it also crashed with
        # ValueError on any entry not named 'exp<digits>'. Use max() over
        # validly-named experiment directories instead.
        exp_nums = [
            int(name[3:])
            for name in os.listdir(train_dir)
            if name.startswith('exp') and name[3:].isdigit()
        ]
        next_num = max(exp_nums) + 1 if exp_nums else 1
        last_exp = os.path.join(train_dir, 'exp' + str(next_num))
        os.mkdir(last_exp)
        plt.savefig(os.path.join(last_exp, "results.png"))
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: b4f35a03a63ebf040b025624d23fe471
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 98e25faaf07e489a624e965684ee8de0756f3ebc
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: fc1c2c2dc78674ff5aa4d67db1ed5b52
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 8688372304eed0b45bb5efd33ebc3392
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2019 Timothy Esler
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: f630914c0f0b597b08fd828a65985720
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,11 @@
|
||||||
|
from .models.inception_resnet_v1 import InceptionResnetV1
|
||||||
|
from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization
|
||||||
|
from .models.utils.detect_face import extract_face
|
||||||
|
from .models.utils import training
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings(
|
||||||
|
action="ignore",
|
||||||
|
message="This overload of nonzero is deprecated:\n\tnonzero()",
|
||||||
|
category=UserWarning
|
||||||
|
)
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 5663da78e37b6f3a6857c1c6577c7cbd
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 67cd73539a2de632781235ca7d7aa494
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 4e61512c59f6bb79cbd37d2aea24fbeb
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: ba1c774d77bdcf8998b29d38e63a737c
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,3 @@
|
||||||
|
2018*
|
||||||
|
*.json
|
||||||
|
profile.txt
|
After Width: | Height: | Size: 195 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: ab8572daa090e73df85f1b9940917e3e
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 294 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: a4f62947b8e078759b49c81457a887a9
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 1.2 MiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: f4e477d7c34ebe2148591e73fea94996
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 4a4a65ea65f71d3ed9d91a8d4c72f251
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 142d74e0017340668b10f92b8e102eed
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: fa5170ee8c30cf9ae834223e383f10d2
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 1e13d7bf50f86d2b2a29f13492ce0a16
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 5cb20437f0c28e5729e29dac4c25e7ca
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 760 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: d8a521ed41d9eaaeda79409eb7012b7b
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 232f83bfc7690e5629ed698a372308f8
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 4.4 MiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 3d8e513bbbbd722d3afe95cc13f44c24
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: ac46b8e53a69d92b4a55cb1b5a93161d
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 767 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 77ab9c4d3c60fba85857bf9de03df810
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: a4bcc08969f47fae09852a8e9a1fac79
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 921 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 18e226964c759b93bbae123cf4c06a30
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: c9c38cb76534c388890108443f8a7777
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 2.6 MiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: d56516d445a8d2729bc5b6367905199a
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 648250ee8f06135a3ba628c9c1484bfb
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 4bcc4387ae55c265ea87679d87399f2c
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 75 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: ff6b67c1e9331c820a98c1248c0b07c4
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 7f63815e904b464cca4739266ce7caba
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 76 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: aca472f8bd3620e63ac4dc792da4f5f5
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 9ea2b9aead09af7929e7faa6f80a30fc
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 75 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 4e71cfcf1e433cfb5a49b8ea3ae5e5c7
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: bf52100662173c6c7affb1431041e3ab
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 76 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 040d973d841ceee05b7e5e5e499713fa
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 2e747d2b1743b4ac28ebd50040a94842
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 78 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 121246c20607f1fbfa66d525a9289025
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,8 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 3dcc96f5e2b21bb19a85662ada3056cb
|
||||||
|
folderAsset: yes
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 6af4b7dac79e10f35a9563fa58f5adb3
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,279 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Face detection and recognition training pipeline\n",
|
||||||
|
"\n",
|
||||||
|
"The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard pytorch training patterns."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
|
||||||
|
"from torch import optim\n",
|
||||||
|
"from torch.optim.lr_scheduler import MultiStepLR\n",
|
||||||
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||||
|
"from torchvision import datasets, transforms\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import os"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define run parameters\n",
|
||||||
|
"\n",
|
||||||
|
"The dataset should follow the VGGFace2/ImageNet-style directory layout. Modify `data_dir` to the location of the dataset you wish to fine-tune on."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data_dir = '../data/test_images'\n",
|
||||||
|
"\n",
|
||||||
|
"batch_size = 32\n",
|
||||||
|
"epochs = 8\n",
|
||||||
|
"workers = 0 if os.name == 'nt' else 8"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Determine if an nvidia GPU is available"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"print('Running on device: {}'.format(device))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define MTCNN module\n",
|
||||||
|
"\n",
|
||||||
|
"See `help(MTCNN)` for more details."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mtcnn = MTCNN(\n",
|
||||||
|
" image_size=160, margin=0, min_face_size=20,\n",
|
||||||
|
" thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
|
||||||
|
" device=device\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Perform MTCNN facial detection\n",
|
||||||
|
"\n",
|
||||||
|
"Iterate through the DataLoader object and obtain cropped faces."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n",
|
||||||
|
"dataset.samples = [\n",
|
||||||
|
" (p, p.replace(data_dir, data_dir + '_cropped'))\n",
|
||||||
|
" for p, _ in dataset.samples\n",
|
||||||
|
"]\n",
|
||||||
|
" \n",
|
||||||
|
"loader = DataLoader(\n",
|
||||||
|
" dataset,\n",
|
||||||
|
" num_workers=workers,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" collate_fn=training.collate_pil\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (x, y) in enumerate(loader):\n",
|
||||||
|
" mtcnn(x, save_path=y)\n",
|
||||||
|
" print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')\n",
|
||||||
|
" \n",
|
||||||
|
"# Remove mtcnn to reduce GPU memory usage\n",
|
||||||
|
"del mtcnn"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define Inception Resnet V1 module\n",
|
||||||
|
"\n",
|
||||||
|
"See `help(InceptionResnetV1)` for more details."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"resnet = InceptionResnetV1(\n",
|
||||||
|
" classify=True,\n",
|
||||||
|
" pretrained='vggface2',\n",
|
||||||
|
" num_classes=len(dataset.class_to_idx)\n",
|
||||||
|
").to(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define optimizer, scheduler, dataset, and dataloader"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n",
|
||||||
|
"scheduler = MultiStepLR(optimizer, [5, 10])\n",
|
||||||
|
"\n",
|
||||||
|
"trans = transforms.Compose([\n",
|
||||||
|
" np.float32,\n",
|
||||||
|
" transforms.ToTensor(),\n",
|
||||||
|
" fixed_image_standardization\n",
|
||||||
|
"])\n",
|
||||||
|
"dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
|
||||||
|
"img_inds = np.arange(len(dataset))\n",
|
||||||
|
"np.random.shuffle(img_inds)\n",
|
||||||
|
"train_inds = img_inds[:int(0.8 * len(img_inds))]\n",
|
||||||
|
"val_inds = img_inds[int(0.8 * len(img_inds)):]\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(\n",
|
||||||
|
" dataset,\n",
|
||||||
|
" num_workers=workers,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" sampler=SubsetRandomSampler(train_inds)\n",
|
||||||
|
")\n",
|
||||||
|
"val_loader = DataLoader(\n",
|
||||||
|
" dataset,\n",
|
||||||
|
" num_workers=workers,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" sampler=SubsetRandomSampler(val_inds)\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define loss and evaluation functions"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
||||||
|
"metrics = {\n",
|
||||||
|
" 'fps': training.BatchTimer(),\n",
|
||||||
|
" 'acc': training.accuracy\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Train model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"writer = SummaryWriter()\n",
|
||||||
|
"writer.iteration, writer.interval = 0, 10\n",
|
||||||
|
"\n",
|
||||||
|
"print('\\n\\nInitial')\n",
|
||||||
|
"print('-' * 10)\n",
|
||||||
|
"resnet.eval()\n",
|
||||||
|
"training.pass_epoch(\n",
|
||||||
|
" resnet, loss_fn, val_loader,\n",
|
||||||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||||
|
" writer=writer\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(epochs):\n",
|
||||||
|
" print('\\nEpoch {}/{}'.format(epoch + 1, epochs))\n",
|
||||||
|
" print('-' * 10)\n",
|
||||||
|
"\n",
|
||||||
|
" resnet.train()\n",
|
||||||
|
" training.pass_epoch(\n",
|
||||||
|
" resnet, loss_fn, train_loader, optimizer, scheduler,\n",
|
||||||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||||
|
" writer=writer\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" resnet.eval()\n",
|
||||||
|
" training.pass_epoch(\n",
|
||||||
|
" resnet, loss_fn, val_loader,\n",
|
||||||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||||||
|
" writer=writer\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"writer.close()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 7212bac47cb61dc798ee7316d53b8f88
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,245 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Face detection and recognition inference pipeline\n",
|
||||||
|
"\n",
|
||||||
|
"The following example illustrates how to use the `facenet_pytorch` python package to perform face detection and recogition on an image dataset using an Inception Resnet V1 pretrained on the VGGFace2 dataset.\n",
|
||||||
|
"\n",
|
||||||
|
"The following Pytorch methods are included:\n",
|
||||||
|
"* Datasets\n",
|
||||||
|
"* Dataloaders\n",
|
||||||
|
"* GPU/CPU processing"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from facenet_pytorch import MTCNN, InceptionResnetV1\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import DataLoader\n",
|
||||||
|
"from torchvision import datasets\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"workers = 0 if os.name == 'nt' else 4"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Determine if an nvidia GPU is available"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Running on device: cuda:0\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"print('Running on device: {}'.format(device))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define MTCNN module\n",
|
||||||
|
"\n",
|
||||||
|
"Default params shown for illustration, but not needed. Note that, since MTCNN is a collection of neural nets and other code, the device must be passed in the following way to enable copying of objects when needed internally.\n",
|
||||||
|
"\n",
|
||||||
|
"See `help(MTCNN)` for more details."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mtcnn = MTCNN(\n",
|
||||||
|
" image_size=160, margin=0, min_face_size=20,\n",
|
||||||
|
" thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
|
||||||
|
" device=device\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define Inception Resnet V1 module\n",
|
||||||
|
"\n",
|
||||||
|
"Set classify=True for pretrained classifier. For this example, we will use the model to output embeddings/CNN features. Note that for inference, it is important to set the model to `eval` mode.\n",
|
||||||
|
"\n",
|
||||||
|
"See `help(InceptionResnetV1)` for more details."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Define a dataset and data loader\n",
|
||||||
|
"\n",
|
||||||
|
"We add the `idx_to_class` attribute to the dataset to enable easy recoding of label indices to identity names later one."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def collate_fn(x):\n",
|
||||||
|
" return x[0]\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = datasets.ImageFolder('../data/test_images')\n",
|
||||||
|
"dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n",
|
||||||
|
"loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Perfom MTCNN facial detection\n",
|
||||||
|
"\n",
|
||||||
|
"Iterate through the DataLoader object and detect faces and associated detection probabilities for each. The `MTCNN` forward method returns images cropped to the detected face, if a face was detected. By default only a single detected face is returned - to have `MTCNN` return all detected faces, set `keep_all=True` when creating the MTCNN object above.\n",
|
||||||
|
"\n",
|
||||||
|
"To obtain bounding boxes rather than cropped face images, you can instead call the lower-level `mtcnn.detect()` function. See `help(mtcnn.detect)` for details."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Face detected with probability: 0.999957\n",
|
||||||
|
"Face detected with probability: 0.999927\n",
|
||||||
|
"Face detected with probability: 0.999662\n",
|
||||||
|
"Face detected with probability: 0.999873\n",
|
||||||
|
"Face detected with probability: 0.999991\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"aligned = []\n",
|
||||||
|
"names = []\n",
|
||||||
|
"for x, y in loader:\n",
|
||||||
|
" x_aligned, prob = mtcnn(x, return_prob=True)\n",
|
||||||
|
" if x_aligned is not None:\n",
|
||||||
|
" print('Face detected with probability: {:8f}'.format(prob))\n",
|
||||||
|
" aligned.append(x_aligned)\n",
|
||||||
|
" names.append(dataset.idx_to_class[y])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Calculate image embeddings\n",
|
||||||
|
"\n",
|
||||||
|
"MTCNN will return images of faces all the same size, enabling easy batch processing with the Resnet recognition module. Here, since we only have a few images, we build a single batch and perform inference on it. \n",
|
||||||
|
"\n",
|
||||||
|
"For real datasets, code should be modified to control batch sizes being passed to the Resnet, particularly if being processed on a GPU. For repeated testing, it is best to separate face detection (using MTCNN) from embedding or classification (using InceptionResnetV1), as calculation of cropped faces or bounding boxes can then be performed a single time and detected faces saved for future use."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"aligned = torch.stack(aligned).to(device)\n",
|
||||||
|
"embeddings = resnet(aligned).detach().cpu()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"#### Print distance matrix for classes"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" angelina_jolie bradley_cooper kate_siegel paul_rudd \\\n",
|
||||||
|
"angelina_jolie 0.000000 1.344806 0.781201 1.425579 \n",
|
||||||
|
"bradley_cooper 1.344806 0.000000 1.256238 0.922126 \n",
|
||||||
|
"kate_siegel 0.781201 1.256238 0.000000 1.366423 \n",
|
||||||
|
"paul_rudd 1.425579 0.922126 1.366423 0.000000 \n",
|
||||||
|
"shea_whigham 1.448495 0.891145 1.416447 0.985438 \n",
|
||||||
|
"\n",
|
||||||
|
" shea_whigham \n",
|
||||||
|
"angelina_jolie 1.448495 \n",
|
||||||
|
"bradley_cooper 0.891145 \n",
|
||||||
|
"kate_siegel 1.416447 \n",
|
||||||
|
"paul_rudd 0.985438 \n",
|
||||||
|
"shea_whigham 0.000000 \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]\n",
|
||||||
|
"print(pd.DataFrame(dists, columns=names, index=names))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: c791f4d9146d73a68ae3aa4bbb11589f
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
|
@ -0,0 +1,522 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### facenet-pytorch LFW evaluation\n",
|
||||||
|
"This notebook demonstrates how to evaluate performance against the LFW dataset."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training, extract_face\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler\n",
|
||||||
|
"from torchvision import datasets, transforms\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import os"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data_dir = 'data/lfw/lfw'\n",
|
||||||
|
"pairs_path = 'data/lfw/pairs.txt'\n",
|
||||||
|
"\n",
|
||||||
|
"batch_size = 16\n",
|
||||||
|
"epochs = 15\n",
|
||||||
|
"workers = 0 if os.name == 'nt' else 8"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Running on device: cuda:0\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
"print('Running on device: {}'.format(device))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mtcnn = MTCNN(\n",
|
||||||
|
" image_size=160,\n",
|
||||||
|
" margin=14,\n",
|
||||||
|
" device=device,\n",
|
||||||
|
" selection_method='center_weighted_size'\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Define the data loader for the input set of images\n",
|
||||||
|
"orig_img_ds = datasets.ImageFolder(data_dir, transform=None)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"# overwrites class labels in dataset with path so path can be used for saving output in mtcnn batches\n",
|
||||||
|
"orig_img_ds.samples = [\n",
|
||||||
|
" (p, p)\n",
|
||||||
|
" for p, _ in orig_img_ds.samples\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"loader = DataLoader(\n",
|
||||||
|
" orig_img_ds,\n",
|
||||||
|
" num_workers=workers,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" collate_fn=training.collate_pil\n",
|
||||||
|
")\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"crop_paths = []\n",
|
||||||
|
"box_probs = []\n",
|
||||||
|
"\n",
|
||||||
|
"for i, (x, b_paths) in enumerate(loader):\n",
|
||||||
|
" crops = [p.replace(data_dir, data_dir + '_cropped') for p in b_paths]\n",
|
||||||
|
" mtcnn(x, save_path=crops)\n",
|
||||||
|
" crop_paths.extend(crops)\n",
|
||||||
|
" print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Remove mtcnn to reduce GPU memory usage\n",
|
||||||
|
"del mtcnn\n",
|
||||||
|
"torch.cuda.empty_cache()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# create dataset and data loaders from cropped images output from MTCNN\n",
|
||||||
|
"\n",
|
||||||
|
"trans = transforms.Compose([\n",
|
||||||
|
" np.float32,\n",
|
||||||
|
" transforms.ToTensor(),\n",
|
||||||
|
" fixed_image_standardization\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
|
||||||
|
"\n",
|
||||||
|
"embed_loader = DataLoader(\n",
|
||||||
|
" dataset,\n",
|
||||||
|
" num_workers=workers,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" sampler=SequentialSampler(dataset)\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load pretrained resnet model\n",
|
||||||
|
"resnet = InceptionResnetV1(\n",
|
||||||
|
" classify=False,\n",
|
||||||
|
" pretrained='vggface2'\n",
|
||||||
|
").to(device)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"classes = []\n",
|
||||||
|
"embeddings = []\n",
|
||||||
|
"resnet.eval()\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for xb, yb in embed_loader:\n",
|
||||||
|
" xb = xb.to(device)\n",
|
||||||
|
" b_embeddings = resnet(xb)\n",
|
||||||
|
" b_embeddings = b_embeddings.to('cpu').numpy()\n",
|
||||||
|
" classes.extend(yb.numpy())\n",
|
||||||
|
" embeddings.extend(b_embeddings)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings_dict = dict(zip(crop_paths,embeddings))\n",
|
||||||
|
"\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"#### Evaluate embeddings by using distance metrics to perform verification on the official LFW test set.\n",
|
||||||
|
"\n",
|
||||||
|
"The functions in the next block are copy pasted from `facenet.src.lfw`. Unfortunately that module has an absolute import from `facenet`, so can't be imported from the submodule\n",
|
||||||
|
"\n",
|
||||||
|
"added functionality to return false positive and false negatives"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.model_selection import KFold\n",
|
||||||
|
"from scipy import interpolate\n",
|
||||||
|
"\n",
|
||||||
|
"# LFW functions taken from David Sandberg's FaceNet implementation\n",
|
||||||
|
"def distance(embeddings1, embeddings2, distance_metric=0):\n",
|
||||||
|
" if distance_metric==0:\n",
|
||||||
|
" # Euclidian distance\n",
|
||||||
|
" diff = np.subtract(embeddings1, embeddings2)\n",
|
||||||
|
" dist = np.sum(np.square(diff),1)\n",
|
||||||
|
" elif distance_metric==1:\n",
|
||||||
|
" # Distance based on cosine similarity\n",
|
||||||
|
" dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)\n",
|
||||||
|
" norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)\n",
|
||||||
|
" similarity = dot / norm\n",
|
||||||
|
" dist = np.arccos(similarity) / math.pi\n",
|
||||||
|
" else:\n",
|
||||||
|
" raise 'Undefined distance metric %d' % distance_metric\n",
|
||||||
|
"\n",
|
||||||
|
" return dist\n",
|
||||||
|
"\n",
|
||||||
|
"def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||||
|
" assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
|
||||||
|
" assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
|
||||||
|
" nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
|
||||||
|
" nrof_thresholds = len(thresholds)\n",
|
||||||
|
" k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
" tprs = np.zeros((nrof_folds,nrof_thresholds))\n",
|
||||||
|
" fprs = np.zeros((nrof_folds,nrof_thresholds))\n",
|
||||||
|
" accuracy = np.zeros((nrof_folds))\n",
|
||||||
|
"\n",
|
||||||
|
" is_false_positive = []\n",
|
||||||
|
" is_false_negative = []\n",
|
||||||
|
"\n",
|
||||||
|
" indices = np.arange(nrof_pairs)\n",
|
||||||
|
"\n",
|
||||||
|
" for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
|
||||||
|
" if subtract_mean:\n",
|
||||||
|
" mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
|
||||||
|
" else:\n",
|
||||||
|
" mean = 0.0\n",
|
||||||
|
" dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
|
||||||
|
"\n",
|
||||||
|
" # Find the best threshold for the fold\n",
|
||||||
|
" acc_train = np.zeros((nrof_thresholds))\n",
|
||||||
|
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||||
|
" _, _, acc_train[threshold_idx], _ ,_ = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])\n",
|
||||||
|
" best_threshold_index = np.argmax(acc_train)\n",
|
||||||
|
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||||
|
" tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _, _, _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])\n",
|
||||||
|
" _, _, accuracy[fold_idx], is_fp, is_fn = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])\n",
|
||||||
|
"\n",
|
||||||
|
" tpr = np.mean(tprs,0)\n",
|
||||||
|
" fpr = np.mean(fprs,0)\n",
|
||||||
|
" is_false_positive.extend(is_fp)\n",
|
||||||
|
" is_false_negative.extend(is_fn)\n",
|
||||||
|
"\n",
|
||||||
|
" return tpr, fpr, accuracy, is_false_positive, is_false_negative\n",
|
||||||
|
"\n",
|
||||||
|
"def calculate_accuracy(threshold, dist, actual_issame):\n",
|
||||||
|
" predict_issame = np.less(dist, threshold)\n",
|
||||||
|
" tp = np.sum(np.logical_and(predict_issame, actual_issame))\n",
|
||||||
|
" fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
|
||||||
|
" tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))\n",
|
||||||
|
" fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n",
|
||||||
|
"\n",
|
||||||
|
" is_fp = np.logical_and(predict_issame, np.logical_not(actual_issame))\n",
|
||||||
|
" is_fn = np.logical_and(np.logical_not(predict_issame), actual_issame)\n",
|
||||||
|
"\n",
|
||||||
|
" tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)\n",
|
||||||
|
" fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)\n",
|
||||||
|
" acc = float(tp+tn)/dist.size\n",
|
||||||
|
" return tpr, fpr, acc, is_fp, is_fn\n",
|
||||||
|
"\n",
|
||||||
|
"def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||||
|
" assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
|
||||||
|
" assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
|
||||||
|
" nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
|
||||||
|
" nrof_thresholds = len(thresholds)\n",
|
||||||
|
" k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
|
||||||
|
"\n",
|
||||||
|
" val = np.zeros(nrof_folds)\n",
|
||||||
|
" far = np.zeros(nrof_folds)\n",
|
||||||
|
"\n",
|
||||||
|
" indices = np.arange(nrof_pairs)\n",
|
||||||
|
"\n",
|
||||||
|
" for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
|
||||||
|
" if subtract_mean:\n",
|
||||||
|
" mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
|
||||||
|
" else:\n",
|
||||||
|
" mean = 0.0\n",
|
||||||
|
" dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
|
||||||
|
"\n",
|
||||||
|
" # Find the threshold that gives FAR = far_target\n",
|
||||||
|
" far_train = np.zeros(nrof_thresholds)\n",
|
||||||
|
" for threshold_idx, threshold in enumerate(thresholds):\n",
|
||||||
|
" _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])\n",
|
||||||
|
" if np.max(far_train)>=far_target:\n",
|
||||||
|
" f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n",
|
||||||
|
" threshold = f(far_target)\n",
|
||||||
|
" else:\n",
|
||||||
|
" threshold = 0.0\n",
|
||||||
|
"\n",
|
||||||
|
" val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])\n",
|
||||||
|
"\n",
|
||||||
|
" val_mean = np.mean(val)\n",
|
||||||
|
" far_mean = np.mean(far)\n",
|
||||||
|
" val_std = np.std(val)\n",
|
||||||
|
" return val_mean, val_std, far_mean\n",
|
||||||
|
"\n",
|
||||||
|
"def calculate_val_far(threshold, dist, actual_issame):\n",
|
||||||
|
" predict_issame = np.less(dist, threshold)\n",
|
||||||
|
" true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n",
|
||||||
|
" false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
|
||||||
|
" n_same = np.sum(actual_issame)\n",
|
||||||
|
" n_diff = np.sum(np.logical_not(actual_issame))\n",
|
||||||
|
" val = float(true_accept) / float(n_same)\n",
|
||||||
|
" far = float(false_accept) / float(n_diff)\n",
|
||||||
|
" return val, far\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def evaluate(embeddings, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
|
||||||
|
" # Calculate evaluation metrics\n",
|
||||||
|
" thresholds = np.arange(0, 4, 0.01)\n",
|
||||||
|
" embeddings1 = embeddings[0::2]\n",
|
||||||
|
" embeddings2 = embeddings[1::2]\n",
|
||||||
|
" tpr, fpr, accuracy, fp, fn = calculate_roc(thresholds, embeddings1, embeddings2,\n",
|
||||||
|
" np.asarray(actual_issame), nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
|
||||||
|
" thresholds = np.arange(0, 4, 0.001)\n",
|
||||||
|
" val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,\n",
|
||||||
|
" np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
|
||||||
|
" return tpr, fpr, accuracy, val, val_std, far, fp, fn\n",
|
||||||
|
"\n",
|
||||||
|
"def add_extension(path):\n",
|
||||||
|
" if os.path.exists(path+'.jpg'):\n",
|
||||||
|
" return path+'.jpg'\n",
|
||||||
|
" elif os.path.exists(path+'.png'):\n",
|
||||||
|
" return path+'.png'\n",
|
||||||
|
" else:\n",
|
||||||
|
" raise RuntimeError('No file \"%s\" with extension png or jpg.' % path)\n",
|
||||||
|
"\n",
|
||||||
|
"def get_paths(lfw_dir, pairs):\n",
|
||||||
|
" nrof_skipped_pairs = 0\n",
|
||||||
|
" path_list = []\n",
|
||||||
|
" issame_list = []\n",
|
||||||
|
" for pair in pairs:\n",
|
||||||
|
" if len(pair) == 3:\n",
|
||||||
|
" path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
|
||||||
|
" path1 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])))\n",
|
||||||
|
" issame = True\n",
|
||||||
|
" elif len(pair) == 4:\n",
|
||||||
|
" path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
|
||||||
|
" path1 = add_extension(os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])))\n",
|
||||||
|
" issame = False\n",
|
||||||
|
" if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist\n",
|
||||||
|
" path_list += (path0,path1)\n",
|
||||||
|
" issame_list.append(issame)\n",
|
||||||
|
" else:\n",
|
||||||
|
" nrof_skipped_pairs += 1\n",
|
||||||
|
" if nrof_skipped_pairs>0:\n",
|
||||||
|
" print('Skipped %d image pairs' % nrof_skipped_pairs)\n",
|
||||||
|
"\n",
|
||||||
|
" return path_list, issame_list\n",
|
||||||
|
"\n",
|
||||||
|
"def read_pairs(pairs_filename):\n",
|
||||||
|
" pairs = []\n",
|
||||||
|
" with open(pairs_filename, 'r') as f:\n",
|
||||||
|
" for line in f.readlines()[1:]:\n",
|
||||||
|
" pair = line.strip().split()\n",
|
||||||
|
" pairs.append(pair)\n",
|
||||||
|
" return np.array(pairs, dtype=object)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"pairs = read_pairs(pairs_path)\n",
|
||||||
|
"path_list, issame_list = get_paths(data_dir+'_cropped', pairs)\n",
|
||||||
|
"embeddings = np.array([embeddings_dict[path] for path in path_list])\n",
|
||||||
|
"\n",
|
||||||
|
"tpr, fpr, accuracy, val, val_std, far, fp, fn = evaluate(embeddings, issame_list)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[0.995 0.995 0.99166667 0.99 0.99 0.99666667\n",
|
||||||
|
" 0.99 0.995 0.99666667 0.99666667]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "0.9936666666666666"
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(accuracy)\n",
|
||||||
|
"np.mean(accuracy)\n",
|
||||||
|
"\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: e0374ec1d0b25b8cd871b0d2e93d2dcf
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|
After Width: | Height: | Size: 12 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
fileFormatVersion: 2
|
||||||
|
guid: 903b193c5350d12a48fd9d7d208adfdb
|
||||||
|
DefaultImporter:
|
||||||
|
externalObjects: {}
|
||||||
|
userData:
|
||||||
|
assetBundleName:
|
||||||
|
assetBundleVariant:
|