utils.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
"""Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the per-channel mean and std values of a dataset.
- init_params: initialize network layer parameters.
- progress_bar: a text progress bar mimicking xlua.progress.
- format_time: format a duration in seconds as a short string.
"""
  6. import os
  7. import sys
  8. import time
  9. import math
  10. import torch
  11. from shutil import get_terminal_size
  12. import torch.nn as nn
  13. import torch.nn.init as init
  14. def get_mean_and_std(dataset):
  15. '''Compute the mean and std value of dataset.'''
  16. dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  17. mean = torch.zeros(3)
  18. std = torch.zeros(3)
  19. print('==> Computing mean and std..')
  20. for inputs, targets in dataloader:
  21. for i in range(3):
  22. mean[i] += inputs[:, i, :, :].mean()
  23. std[i] += inputs[:, i, :, :].std()
  24. mean.div_(len(dataset))
  25. std.div_(len(dataset))
  26. return mean, std
  27. def init_params(net):
  28. '''Init layer parameters.'''
  29. for m in net.modules():
  30. if isinstance(m, nn.Conv2d):
  31. init.kaiming_normal(m.weight, mode='fan_out')
  32. if m.bias:
  33. init.constant(m.bias, 0)
  34. elif isinstance(m, nn.BatchNorm2d):
  35. init.constant(m.weight, 1)
  36. init.constant(m.bias, 0)
  37. elif isinstance(m, nn.Linear):
  38. init.normal(m.weight, std=1e-3)
  39. if m.bias:
  40. init.constant(m.bias, 0)
  41. term_width, _ = get_terminal_size()
  42. TOTAL_BAR_LENGTH = 65.
  43. last_time = time.time()
  44. begin_time = last_time
  45. def progress_bar(current, total, msg=None):
  46. global last_time, begin_time
  47. if current == 0:
  48. begin_time = time.time() # Reset for new bar.
  49. cur_len = int(TOTAL_BAR_LENGTH * current / total)
  50. rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
  51. sys.stdout.write(' [')
  52. for i in range(cur_len):
  53. sys.stdout.write('=')
  54. sys.stdout.write('>')
  55. for i in range(rest_len):
  56. sys.stdout.write('.')
  57. sys.stdout.write(']')
  58. cur_time = time.time()
  59. step_time = cur_time - last_time
  60. last_time = cur_time
  61. tot_time = cur_time - begin_time
  62. L = []
  63. L.append(' Step: %s' % format_time(step_time))
  64. L.append(' | Tot: %s' % format_time(tot_time))
  65. if msg:
  66. L.append(' | ' + msg)
  67. msg = ''.join(L)
  68. sys.stdout.write(msg)
  69. for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
  70. sys.stdout.write(' ')
  71. # Go back to the center of the bar.
  72. for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
  73. sys.stdout.write('\b')
  74. sys.stdout.write(' %d/%d ' % (current + 1, total))
  75. if current < total - 1:
  76. sys.stdout.write('\r')
  77. else:
  78. sys.stdout.write('\n')
  79. sys.stdout.flush()
  80. def format_time(seconds):
  81. days = int(seconds / 3600 / 24)
  82. seconds = seconds - days * 3600 * 24
  83. hours = int(seconds / 3600)
  84. seconds = seconds - hours * 3600
  85. minutes = int(seconds / 60)
  86. seconds = seconds - minutes * 60
  87. secondsf = int(seconds)
  88. seconds = seconds - secondsf
  89. millis = int(seconds * 1000)
  90. f = ''
  91. i = 1
  92. if days > 0:
  93. f += str(days) + 'D'
  94. i += 1
  95. if hours > 0 and i <= 2:
  96. f += str(hours) + 'h'
  97. i += 1
  98. if minutes > 0 and i <= 2:
  99. f += str(minutes) + 'm'
  100. i += 1
  101. if secondsf > 0 and i <= 2:
  102. f += str(secondsf) + 's'
  103. i += 1
  104. if millis > 0 and i <= 2:
  105. f += str(millis) + 'ms'
  106. i += 1
  107. if f == '':
  108. f = '0ms'
  109. return f