280 lines
6.9 KiB
Plaintext
280 lines
6.9 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Face detection and recognition training pipeline\n",
|
||
|
"\n",
|
||
|
"The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard PyTorch training patterns."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n",
|
||
|
"import torch\n",
|
||
|
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
|
||
|
"from torch import optim\n",
|
||
|
"from torch.optim.lr_scheduler import MultiStepLR\n",
|
||
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||
|
"from torchvision import datasets, transforms\n",
|
||
|
"import numpy as np\n",
|
||
|
"import os"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Define run parameters\n",
|
||
|
"\n",
|
||
|
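"The dataset should follow the VGGFace2/ImageNet-style directory layout, i.e. one subdirectory per identity containing that person's images (the layout expected by `torchvision.datasets.ImageFolder`). Set `data_dir` to the location of the dataset you wish to fine-tune on."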
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"data_dir = '../data/test_images'\n",
|
||
|
"\n",
|
||
|
"batch_size = 32\n",
|
||
|
"epochs = 8\n",
|
||
|
"workers = 0 if os.name == 'nt' else 8"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Determine if an NVIDIA GPU is available"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||
|
"print('Running on device: {}'.format(device))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Define MTCNN module\n",
|
||
|
"\n",
|
||
|
"See `help(MTCNN)` for more details."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"mtcnn = MTCNN(\n",
|
||
|
" image_size=160, margin=0, min_face_size=20,\n",
|
||
|
" thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n",
|
||
|
" device=device\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Perform MTCNN facial detection\n",
|
||
|
"\n",
|
||
|
"Iterate through the DataLoader object, detect the face in each image, and save the cropped faces to a mirrored directory tree at `data_dir + '_cropped'`."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n",
|
||
|
"dataset.samples = [\n",
|
||
|
" (p, p.replace(data_dir, data_dir + '_cropped'))\n",
|
||
|
" for p, _ in dataset.samples\n",
|
||
|
"]\n",
|
||
|
" \n",
|
||
|
"loader = DataLoader(\n",
|
||
|
" dataset,\n",
|
||
|
" num_workers=workers,\n",
|
||
|
" batch_size=batch_size,\n",
|
||
|
" collate_fn=training.collate_pil\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"for i, (x, y) in enumerate(loader):\n",
|
||
|
" mtcnn(x, save_path=y)\n",
|
||
|
" print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')\n",
|
||
|
" \n",
|
||
|
"# Remove mtcnn to reduce GPU memory usage\n",
|
||
|
"del mtcnn"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Define Inception Resnet V1 module\n",
|
||
|
"\n",
|
||
|
"See `help(InceptionResnetV1)` for more details."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"resnet = InceptionResnetV1(\n",
|
||
|
" classify=True,\n",
|
||
|
" pretrained='vggface2',\n",
|
||
|
" num_classes=len(dataset.class_to_idx)\n",
|
||
|
").to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Define optimizer, scheduler, dataset, and dataloader"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n",
|
||
|
"scheduler = MultiStepLR(optimizer, [5, 10])\n",
|
||
|
"\n",
|
||
|
"trans = transforms.Compose([\n",
|
||
|
" np.float32,\n",
|
||
|
" transforms.ToTensor(),\n",
|
||
|
" fixed_image_standardization\n",
|
||
|
"])\n",
|
||
|
"dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
|
||
|
"img_inds = np.arange(len(dataset))\n",
|
||
|
"np.random.shuffle(img_inds)\n",
|
||
|
"train_inds = img_inds[:int(0.8 * len(img_inds))]\n",
|
||
|
"val_inds = img_inds[int(0.8 * len(img_inds)):]\n",
|
||
|
"\n",
|
||
|
"train_loader = DataLoader(\n",
|
||
|
" dataset,\n",
|
||
|
" num_workers=workers,\n",
|
||
|
" batch_size=batch_size,\n",
|
||
|
" sampler=SubsetRandomSampler(train_inds)\n",
|
||
|
")\n",
|
||
|
"val_loader = DataLoader(\n",
|
||
|
" dataset,\n",
|
||
|
" num_workers=workers,\n",
|
||
|
" batch_size=batch_size,\n",
|
||
|
" sampler=SubsetRandomSampler(val_inds)\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Define loss and evaluation functions"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"loss_fn = torch.nn.CrossEntropyLoss()\n",
|
||
|
"metrics = {\n",
|
||
|
" 'fps': training.BatchTimer(),\n",
|
||
|
" 'acc': training.accuracy\n",
|
||
|
"}"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Train model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"writer = SummaryWriter()\n",
|
||
|
"writer.iteration, writer.interval = 0, 10\n",
|
||
|
"\n",
|
||
|
"print('\\n\\nInitial')\n",
|
||
|
"print('-' * 10)\n",
|
||
|
"resnet.eval()\n",
|
||
|
"training.pass_epoch(\n",
|
||
|
" resnet, loss_fn, val_loader,\n",
|
||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||
|
" writer=writer\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"for epoch in range(epochs):\n",
|
||
|
" print('\\nEpoch {}/{}'.format(epoch + 1, epochs))\n",
|
||
|
" print('-' * 10)\n",
|
||
|
"\n",
|
||
|
" resnet.train()\n",
|
||
|
" training.pass_epoch(\n",
|
||
|
" resnet, loss_fn, train_loader, optimizer, scheduler,\n",
|
||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||
|
" writer=writer\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" resnet.eval()\n",
|
||
|
" training.pass_epoch(\n",
|
||
|
" resnet, loss_fn, val_loader,\n",
|
||
|
" batch_metrics=metrics, show_running=True, device=device,\n",
|
||
|
" writer=writer\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"writer.close()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.3"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|