| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- """Some helper functions for PyTorch, including:
- - get_mean_and_std: calculate the mean and std value of dataset.
- - msr_init: net parameter initialization.
- - progress_bar: progress bar mimic xlua.progress.
- """
- import os
- import sys
- import time
- import math
- import torch
- from shutil import get_terminal_size
- import torch.nn as nn
- import torch.nn.init as init
def get_mean_and_std(dataset, num_workers=2):
    """Compute the per-channel mean and std of an image dataset.

    Samples are loaded one at a time (``batch_size=1``). For every image, the
    mean and unbiased standard deviation of each channel are computed, and
    these per-image values are then averaged over the dataset. The returned
    std is therefore the *average of per-image stds*, not the pooled std over
    all pixels.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs whose
            images collate to a 4-D ``(1, C, H, W)`` batch; ``target`` is
            ignored.
        num_workers: number of worker processes passed to ``DataLoader``.

    Returns:
        ``(mean, std)``: two 1-D float tensors of length ``C``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError('get_mean_and_std() requires a non-empty dataset')
    # Sample order is irrelevant for an average; shuffle=False skips a useless
    # permutation and keeps the float accumulation order reproducible.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        # inputs is (1, C, H, W); flatten the spatial dims so each row holds
        # one channel's pixels.
        per_channel = inputs[0].flatten(1)
        if mean is None:
            # Size the accumulators from the data instead of assuming RGB.
            mean = torch.zeros(per_channel.size(0))
            std = torch.zeros(per_channel.size(0))
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
def init_params(net):
    """Initialize the parameters of ``net``'s layers in place.

    Every sub-module yielded by ``net.modules()`` is visited:

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * ``nn.BatchNorm2d``: weight filled with 1, bias with 0.
    * ``nn.Linear``: normal weights with std 1e-3, zero bias.

    Other module types are left untouched. Parameters that a layer does not
    have (``bias=False`` layers, ``BatchNorm2d(affine=False)``) are skipped.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Test for presence with ``is not None``. Truthiness raises for a
            # multi-element tensor and would treat a 0.0 scalar bias as absent.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # Non-affine BatchNorm has weight and bias set to None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
# Shared configuration and timing state for progress_bar().
# Width of the output terminal in columns, as reported by shutil.
term_width = get_terminal_size().columns
# Length of the bar in characters (a float; progress_bar() converts with int()).
TOTAL_BAR_LENGTH = 65.
# Timestamps of the previous progress_bar() call and of the current bar's start.
begin_time = last_time = time.time()
def progress_bar(current, total, msg=None):
    """Draw (or redraw) a one-line text progress bar on stdout.

    Mimics xlua.progress. The line contains:

    * a ``[====>.....]`` bar filled in proportion to ``current / total``;
    * the time since the previous call ("Step");
    * the time since the bar was last reset ("Tot");
    * an optional message;
    * a ``current+1/total`` counter written back over the middle of the bar.

    The line ends with a carriage return so the next call overwrites it. From
    ``current == total - 1`` onward it ends with a newline instead.

    Args:
        current: zero-based index of the step just completed; ``0`` resets
            the total-time clock.
        total: total number of steps; must be non-zero.
        msg: optional extra text appended after the timings.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    bar_len = int(TOTAL_BAR_LENGTH)
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    # Clamp so an out-of-range ``current`` cannot widen the bar; for
    # 0 <= current < total this never changes cur_len.
    cur_len = min(max(cur_len, 0), bar_len - 1)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    info = ' Step: %s' % format_time(step_time)
    info += ' | Tot: %s' % format_time(tot_time)
    if msg:
        info += ' | ' + msg

    # Assemble the whole line and write it once. Repeating a string a
    # negative number of times yields '', matching an empty range() loop.
    parts = [
        ' [', '=' * cur_len, '>', '.' * rest_len, ']',
        info,
        # Pad with spaces to the terminal width.
        ' ' * (term_width - bar_len - len(info) - 3),
        # Go back to the center of the bar.
        '\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2),
        ' %d/%d ' % (current + 1, total),
        '\r' if current < total - 1 else '\n',
    ]
    sys.stdout.write(''.join(parts))
    sys.stdout.flush()
def format_time(seconds):
    """Render a duration in seconds as a compact string of at most two units.

    The duration is split into days ('D'), hours ('h'), minutes ('m'), whole
    seconds ('s') and milliseconds ('ms'), each truncated with ``int``. Only
    the first two non-zero components, from largest to smallest, are kept.
    For example, ``3725`` -> ``'1h2m'`` and ``1.5`` -> ``'1s500ms'``. A
    duration with no positive component renders as ``'0ms'``.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_secs = int(remaining)
    millis = int((remaining - whole_secs) * 1000)

    components = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_secs, 's'),
        (millis, 'ms'),
    )
    shown = [str(value) + unit for value, unit in components if value > 0]
    return ''.join(shown[:2]) or '0ms'
|