{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Face tracking pipeline\n",
"\n",
"The following example illustrates how to use the `facenet_pytorch` python package to perform face detection and tracking on an image dataset using MTCNN."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from facenet_pytorch import MTCNN\n",
"import torch\n",
"import numpy as np\n",
"import mmcv, cv2\n",
"from PIL import Image, ImageDraw\n",
"from IPython import display"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Determine if an nvidia GPU is available"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on device: cuda:0\n"
]
}
],
"source": [
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"print('Running on device: {}'.format(device))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define MTCNN module\n",
"\n",
"Note that, since MTCNN is a collection of neural nets and other code, the device must be passed in the following way to enable copying of objects when needed internally.\n",
"\n",
"See `help(MTCNN)` for more details."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mtcnn = MTCNN(keep_all=True, device=device)"
]
},
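{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is an optional, illustrative sketch of a few constructor arguments `MTCNN` accepts (run `help(MTCNN)` for the authoritative list). The values shown are assumptions chosen for demonstration, not recommendations, and the `mtcnn_tuned` name is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: detection behaviour can be tuned via the MTCNN constructor\n",
"mtcnn_tuned = MTCNN(\n",
"    keep_all=True,                # return all detected faces, not just the largest\n",
"    min_face_size=40,             # ignore faces smaller than ~40 px\n",
"    thresholds=[0.6, 0.7, 0.7],   # P-Net, R-Net, O-Net confidence thresholds\n",
"    factor=0.709,                 # image pyramid scaling factor\n",
"    device=device\n",
")"
]
},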
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Get a sample video\n",
"\n",
"We begin by loading a video with some faces in it. The `mmcv` PyPI package by mmlabs is used to read the video frames (it can be installed with `pip install mmcv`). Frames are then converted to PIL images."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<video src=\"video.mp4\" controls width=\"640\" >\n",
" Your browser does not support the <code>video</code> element.\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.Video object>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video = mmcv.VideoReader('video.mp4')\n",
"frames = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in video]\n",
"\n",
"display.Video('video.mp4', width=640)"
]
},
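{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `mmcv` is not available, the frames can also be read with plain OpenCV. The cell below is a minimal sketch of that alternative; it assumes the same `video.mp4` file, and `frames_cv` is just a placeholder name for the resulting list of RGB PIL images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional alternative frame reader using OpenCV only (no mmcv dependency)\n",
"cap = cv2.VideoCapture('video.mp4')\n",
"frames_cv = []\n",
"while True:\n",
"    ret, frame = cap.read()\n",
"    if not ret:\n",
"        break\n",
"    # OpenCV returns BGR arrays; convert to RGB before wrapping in a PIL image\n",
"    frames_cv.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))\n",
"cap.release()"
]
},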
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Run video through MTCNN\n",
"\n",
"We iterate through each frame, detect faces, and draw their bounding boxes on the video frames."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracking frame: 105\n",
"Done\n"
]
}
],
"source": [
"frames_tracked = []\n",
"for i, frame in enumerate(frames):\n",
" print('\\rTracking frame: {}'.format(i + 1), end='')\n",
" \n",
" # Detect faces\n",
" boxes, _ = mtcnn.detect(frame)\n",
" \n",
" # Draw faces\n",
" frame_draw = frame.copy()\n",
" draw = ImageDraw.Draw(frame_draw)\n",
" for box in boxes:\n",
" draw.rectangle(box.tolist(), outline=(255, 0, 0), width=6)\n",
" \n",
" # Add to frame list\n",
" frames_tracked.append(frame_draw.resize((640, 360), Image.BILINEAR))\n",
"print('\\nDone')"
]
},
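{
"cell_type": "markdown",
"metadata": {},
"source": [
"`MTCNN.detect` can also return detection probabilities and 5-point facial landmarks when called with `landmarks=True`. The cell below is a small optional sketch that draws both for a single frame; the drawing details (point radius, colours) and the `sample` variable name are arbitrary choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: boxes, probabilities and landmarks for one frame\n",
"sample = frames[0]\n",
"boxes, probs, points = mtcnn.detect(sample, landmarks=True)\n",
"\n",
"sample_draw = sample.copy()\n",
"draw = ImageDraw.Draw(sample_draw)\n",
"if boxes is not None:\n",
"    for box, prob, landmarks in zip(boxes, probs, points):\n",
"        draw.rectangle(box.tolist(), outline=(255, 0, 0), width=6)\n",
"        # Draw the five landmark points (eyes, nose, mouth corners)\n",
"        for x, y in landmarks:\n",
"            draw.ellipse([x - 4, y - 4, x + 4, y + 4], fill=(0, 255, 0))\n",
"sample_draw.resize((640, 360), Image.BILINEAR)"
]
},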
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Display detections"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAFoCAIAAABIUN0GAAEAAElEQVR4nOz9Wa9s2ZEmiJnZWnvw2c987jxFMAYymAPJrGqoStVCoasBAXpQFyBA0E8SWoAA/QAJehHU6ofSQ0GAJKi7gOqqTJJBxsCY7jyf+fjse1jLTA+21nY/594YmMzMSmbl4uWNc4+7b997DTZ89pkZ/lf/6n9b1/XBwQEAtLq9P/0v/sk//5f/8r133z95dfJ/+m//D5/8+td5ajqdnMV1Bt3/xb/6l5s7O0VVA9F8PndV3W23/8v/8l/82Z9/mGUECACCgOLkq989rRbuz//knnj+D//Tx599/tnJ2dloOvn4l786ePXKe0eERCZvd4cb25s7u4JS1zV4RmAiAmRjsCqr87OT5XLh2ANgv9vtdDpHJ8dk0v5w63/1v/5v/uk//WedvPt//N//t//vf/tv0wS7W51//b/73/z8L35xfn7unBeRunKuqhOT+NrVdTmfTYxgnue/+fjjL7/4whjz0z/96U//5KOvvrr/q1/+8srVvffff//+N/dns2Wn0+n3+51u++XL58fHx4iUZTmzbO/uX7l2Nc9zZgZBZvBemKXVahERgKRpkqapMQYAnKsNQprlNknn82lVVSxIRAAAhIhkbZIaC4RZluZJy1oiBGMNpbhYlk8ePl7MF4wMiMaavb2rnU7P1SwCAKJ/BBjiICIiQJDR6PTp0yfe14iQpr2bt+622m0RDxIGMDOwiKAgGiAgBDg+PXr5/KmvHCIBIFmzs7O7u3cFCZuvABFgYQ5fSoAGEBFfvHx+enrEwB7EmHRna/fqlatEICAMIiKIKCL6KSNAqDMAi/nsxYsX89lCBIwxw42N3f29NM3XHwoIUQAAEAEREIBAyJD3brGYn5+fz5dLFgFAgzTsbwyHwyRJiEgQGZExfBZAEAFQAAAxPBQxwfpAnVgIyxTmee11NPEHBIDmoWBtIb5joFz6hTD7tYu8+X5Z/34BELzwBrr8UQRA+Lsbl+fn+9998e4uPzvKxXdcfpxL72/WMb4KzIJAiBZAPHsWNsaISLuV/+xPf/rej+61WokwGAsgIADCgAhVXf/lLz/9+NNPKyfeszWGBH78wQf/7H/253meCgChEAEKAsJyWf2//j//7otvHoAHFEDE995777/+r/5ZnlkgQPIADCDM9pv7T//7f/Nvi9ITUq+T/uidd7udjWrpnj558uTJUyFwvkDvAPxkMprPF1mWEREiMLu6rrwAIBmTGDTLZbFcLvNWqjut3W7XdV1XdbvVRsSiKERkOBz2er00TRsRVDsuinK5XE6nU+89IOukMTMCWmsFJEkSay0R5XmbIKE0GwyHe3t7Ozs73cGgNxgMBhv3H9z/7/8f/7d6Mc2QLCWIBo0x1iZJamxmszzL87zd2t7Z3tzabnW77W631WrleW6tzXJ7fnr+f/0//1+OXr9KkVCcMYAEhqy1mbGpSRJjUmPTJM2SLNna3B4Mhp1ut93tZFmekJ2Mzl8+fFQdnR0/ePTo098sJqO9ve2Pfv4ntpV9/ttPXj5+WlTLKVejerngOk1aO5s7/U4/S1uDzc2824E88YbEYOX5fDTrDDZ+/JOPHj95cnpy0m+308xkQMN2r5e2/Hh6/PT5y6dPF4vZxvbm+x9+uH/jxuz8/Hcf//Lw8SPvlqUvvKuNCHpJjN3Z2trb3TXGzKazyXgM7IkARTxI5SovngUExAugMcYmgMAIXgRFoHSGwSJiECUq2JhRBAGtIYOEAAiCIAjIzHVVO+e89yiinxoMBtevX81SunQqjDHOl81pERFhRhWf8RyJCCG2221jjGMHIOx9USyYGZDTNBH2+qmwY/BtkgUBDel1jTH7+/vD4fDs/BwRRQCJ9OPWWhGHSN57Y8zVq1cX8/nLly8/++wzAb575x34+c8//91nT548ee/9987PxoeHR/fv309SmyQmSRIiU1VlVfn5fLZcLtM0reu6WJZV5YhskqStVstaq8LIOVdVVVmW3rs8TYhMp9tF7E4m07KqvfdEiGAAgJk9eO8EQCwmhhIwqBOo90+IQCRB0HyXsHv75PxeQ2V9vAwipkny3ZdFREIqy3K51FUDRMzSrNfrJ0nC7HQD6EV0EQEAkRCBRZj5/Px8sVgwMyLleb65uZllmT77992sVFU1Ho+LomBhEUzTpNfuDAZ9a+3fwGz8kQ68vEXeUPn/wEbzeBfUdpJYEeGaQYLOns9nn376KSG8/947SWIBVx8VgSRJPvzgR8fno6fPX3jPiCjMz549fXxl573375K5sJ3yPHn3R/eevng9n81BgJlfvXp1dHR288YuCIhQFKi4tbU53Bi+fn0KIM755XLZyjrj8fjg4NV8Pm112yACCK52y+USERObIKGIMDMH0xWIxHnvnDPGJEnqfU0U7Pg0TYlo3SjRna8HzVqbZVmaZq1Wyzk3n88h2sHSDJDmkBKhNZaMpTiMMdZawH+Am0hEhMVamySJYVHdRETWmDzPsyzZ2dnpd/tJkpIxiKrgVI4BASGJMfQD5Qwi6JqSMTrhFO1qG9+BAsLC4lmcJwFDSAZU7IsAILD3zjnxnr1vvngwHHZ7HXzj2BOhMDMDNb5CeO7guCGCgACCGl/iBVWck3HOea6dq4W9Z5ao7N/+bACknhEIM6s1h2uvhsdD/QtYhIi63e67774LAJ9/8dm///f//vmzlzdv3tzc3Hz16tXGxubVq9c3N7fqun7+4hmitFotaxMAuHX7xt6V62mWtVotEUlTSZJMBImMGp7O1cy+qqrlcnl6esre7+xsZnmrKMo8T1ut3Hl2znnPIIJIAL5m8cJ1bRNKE0tEBgDizUL87/fs/8aKenPQD1ZF3nvvffhqJGNMludvruwFFwRBAGbzWVEUAAgoSNhqtzudTrhpvHzvuLoOzOeL8XjsvQdAa+1gMOh0OkQkEj62ZrFdelwoiuL8/Hw2mzEzGrLWdrvdjeFGajP4TnMkWoAXbuaPaOB3/lPkj/KhfvCQi0gDXtphuu7GmF6vh0iz+bQoC9VGRGYymXz88W+yNL1777axBIqpEOiO7g+6H3304/F0dnR0orbvdDr74suvrlzdGW70128CEW/euHL12tXHDx57dgAwnU4fPnq6v7+dZtR47YjY7XauXLlyeHjOzN678XhsyJycHp+cHnrHALneflWVzrl2u2NtIgICTMRERlgUf/LeO+darRYhehAMVu8ysYmwMLNzzlqrHqeiTSo2CVEQ8zzf3t5GxPliqpqgeRBjTJZl7Xa73W532p1W3s873W6/3+v1BoNBq93O0sz8wRZt0PXNqv11FDoKADOH6yCEx/wBsvGySRrwMBAQIrLWkvMQVAPmWba5sZHmye7+3qDfTy45IfH7kBCR5E0I59tGMATlkkS0uk6IiALs2XtflRUhIAoikKpLYUQAEWbWDzf6fGNzwybWMSj4CQAiAAKGjHA0tQAgagK168ITSdA0a/pS1GUVYQH23nnvRAQJjTFvVTPGGmuDGeG9n
01ni8VCZTo28JUAMzOzwlsikiRJr9frdrvMMJmOR6NPv/7mq729nTRNP//8s/F4cvPGjb29ncVydnJyPB5PAGA4HH74wYc//bM//+b+/dOTU/aeSNrtbl36oiwNIYEQorF2Np0evj44PDrKszRNDSItlsvhsA+AaZqIsHOOvRfxIgCeHXN8ahAR7z3FNQkaCJHordYWBiW1PoUX5RESIaDI5YW/cBVEEKnqSkAwWnkKZK0Drbrjm9+EmxMpiiI4DQDWJr1eN02tAK9269oXUUCRQdiPR6PlshARItPudDY3N621/P0nCpz3k8lkMhk75xDREHU63c3NzSxJQd4wGSJ+gIgIEo9jOEdvns6/e2P/DTT7298JF1bxO9b0j2U0W/dbHl9lxWrnf/skSbwgGGta7Xa73W61s9F4vFgs6roWEECazeYf/+YTQHrnnZtJolGzII4t4a2be+PJh//+f/qrsijU6Tk4OPjii6//yT/5c2svgHztVnbv3p3XL17P3UzFy+Mnzz784Ed7+4PGShC
"text/plain": [
"<PIL.Image.Image image mode=RGB size=640x360 at 0x7F9174732978>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"d = display.display(frames_tracked[0], display_id=True)\n",
"i = 1\n",
"try:\n",
" while True:\n",
" d.update(frames_tracked[i % len(frames_tracked)])\n",
" i += 1\n",
"except KeyboardInterrupt:\n",
" pass"
]
},
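{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional alternative to the interactive display above, the tracked frames can also be written out as an animated GIF with Pillow. This is just a sketch; the output filename and per-frame duration are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save the tracked frames as an animated GIF\n",
"frames_tracked[0].save(\n",
"    'video_tracked.gif',\n",
"    save_all=True,\n",
"    append_images=frames_tracked[1:],\n",
"    duration=40,  # milliseconds per frame (~25 fps)\n",
"    loop=0        # loop forever\n",
")"
]
},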
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Save tracked video"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"dim = frames_tracked[0].size\n",
"fourcc = cv2.VideoWriter_fourcc(*'FMP4') \n",
"video_tracked = cv2.VideoWriter('video_tracked.mp4', fourcc, 25.0, dim)\n",
"for frame in frames_tracked:\n",
" video_tracked.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))\n",
"video_tracked.release()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}