import torch
import numpy as np
import time


class Logger(object):
    """Console logger that prints the loss and metrics for each batch on a single
    line, overwriting it in place. If calculate_mean is True, the values passed in
    are treated as running sums and are divided by the batch count before display.
    """

    def __init__(self, mode, length, calculate_mean=False):
        self.mode = mode
        self.length = length
        self.calculate_mean = calculate_mean
        if self.calculate_mean:
            # Inputs are running sums; convert to means over the i + 1 batches seen.
            self.fn = lambda x, i: x / (i + 1)
        else:
            self.fn = lambda x, i: x

    def __call__(self, loss, metrics, i):
        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
        print(track_str + loss_str + metric_str + ' ', end='')
        if i + 1 == self.length:
            # Move to a new line once the epoch's final batch has been printed.
            print('')
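

# Illustrative usage sketch (not part of the original module): with
# calculate_mean=True, Logger expects running sums and prints their means.
#
#   logger = Logger('Train', length=len(loader), calculate_mean=True)
#   running_loss = 0
#   for i_batch, batch in enumerate(loader):
#       running_loss += loss_for(batch)  # loss_for is a hypothetical helper
#       logger(running_loss, {}, i_batch)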


class BatchTimer(object):
    """Batch timing class.

    Use this class for tracking training and testing time/rate per batch or per sample.

    Keyword Arguments:
        rate {bool} -- Whether to report a rate (batches or samples per second) or a time
            (seconds per batch or sample). (default: {True})
        per_sample {bool} -- Whether to report times or rates per sample or per batch.
            (default: {True})
    """

    def __init__(self, rate=True, per_sample=True):
        self.start = time.time()
        self.end = None
        self.rate = rate
        self.per_sample = per_sample

    def __call__(self, y_pred, y):
        self.end = time.time()
        elapsed = self.end - self.start
        # Restart the timer so the next call measures the next interval.
        self.start = self.end
        self.end = None

        if self.per_sample:
            elapsed /= len(y_pred)
        if self.rate:
            elapsed = 1 / elapsed

        return torch.tensor(elapsed)
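

# Illustrative usage sketch (model and loader are assumed): with the defaults
# (rate=True, per_sample=True) the timer reports samples per second since the
# previous call.
#
#   timer = BatchTimer()
#   for x, y in loader:
#       y_pred = model(x)
#       print(timer(y_pred, y))  # e.g. tensor(512.3) samples/sec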


def accuracy(logits, y):
    """Fraction of samples whose argmax prediction matches the label."""
    _, preds = torch.max(logits, 1)
    return (preds == y).float().mean()
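

# Quick sanity check (illustrative values): two of three predictions match.
#
#   logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.4]])
#   y = torch.tensor([0, 1, 1])
#   accuracy(logits, y)  # tensor(0.6667)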


def pass_epoch(
    model, loss_fn, loader, optimizer=None, scheduler=None,
    batch_metrics=None, show_running=True, device='cpu', writer=None
):
    """Train or evaluate over a data epoch. Whether the model is trained or
    evaluated is determined by model.training.

    Arguments:
        model {torch.nn.Module} -- Pytorch model.
        loss_fn {callable} -- A function to compute (scalar) loss.
        loader {torch.utils.data.DataLoader} -- A pytorch data loader.

    Keyword Arguments:
        optimizer {torch.optim.Optimizer} -- A pytorch optimizer; required when the model
            is in training mode. (default: {None})
        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler, stepped once
            per epoch. (default: {None})
        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The
            default is a simple timer. A progressive average of these metrics, along with
            the average loss, is printed every batch. (default: {{'time': BatchTimer()}})
        show_running {bool} -- Whether to print rolling averages of the loss and metrics
            (True) or the values for the current batch only (False). (default: {True})
        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter,
            expected to carry `iteration` and `interval` attributes. (default: {None})

    Returns:
        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of
            average metric values across the epoch.
    """
    if batch_metrics is None:
        # Avoid a mutable (and stateful) default argument: create a fresh timer per call.
        batch_metrics = {'time': BatchTimer()}

    mode = 'Train' if model.training else 'Valid'
    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
    loss = 0
    metrics = {}

    for i_batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss_batch = loss_fn(y_pred, y)

        if model.training:
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Evaluate batch metrics and accumulate running sums for the epoch averages.
        metrics_batch = {}
        for metric_name, metric_fn in batch_metrics.items():
            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]

        if writer is not None and model.training:
            if writer.iteration % writer.interval == 0:
                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
                for metric_name, metric_batch in metrics_batch.items():
                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
            writer.iteration += 1

        loss_batch = loss_batch.detach().cpu()
        loss += loss_batch
        if show_running:
            logger(loss, metrics, i_batch)
        else:
            logger(loss_batch, metrics_batch, i_batch)

    if model.training and scheduler is not None:
        scheduler.step()

    # Convert running sums to epoch averages.
    loss = loss / (i_batch + 1)
    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}

    if writer is not None and not model.training:
        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
        for metric_name, metric in metrics.items():
            writer.add_scalars(metric_name, {mode: metric}, writer.iteration)

    return loss, metrics
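

# Illustrative training-loop sketch; the model, loaders, and hyperparameters are
# assumptions, not part of this module. Note that pass_epoch expects the writer
# to carry `iteration` and `interval` attributes, which SummaryWriter does not
# define by default.
#
#   model = MyModel().to(device)  # hypothetical model
#   optimizer = torch.optim.Adam(model.parameters())
#   writer = SummaryWriter()  # from torch.utils.tensorboard
#   writer.iteration, writer.interval = 0, 10
#   metrics = {'acc': accuracy, 'fps': BatchTimer()}
#   loss_fn = torch.nn.CrossEntropyLoss()
#
#   for epoch in range(num_epochs):
#       model.train()
#       pass_epoch(model, loss_fn, train_loader, optimizer,
#                  batch_metrics=metrics, device=device, writer=writer)
#       model.eval()
#       pass_epoch(model, loss_fn, val_loader,
#                  batch_metrics=metrics, device=device, writer=writer)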


def collate_pil(x):
    """Collate function that keeps samples as lists instead of stacking them into
    tensors, e.g. for batches of PIL images of varying size.
    """
    out_x, out_y = [], []
    for xx, yy in x:
        out_x.append(xx)
        out_y.append(yy)
    return out_x, out_y
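

# Illustrative usage sketch: pass collate_pil as the DataLoader's collate_fn so
# that samples (e.g. PIL images of varying size) are batched as plain lists
# instead of being stacked into tensors (the dataset name is assumed):
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_pil)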