main.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. """Train PreActResNet on CIFAR10 with PyTorch."""
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import torch.nn.functional as F
  6. import torch.backends.cudnn as cudnn
  7. import torchvision
  8. import torchvision.transforms as transforms
  9. import torchtest
  10. from torchsummary import summary
  11. # from pytorch_lightning import Trainer
  12. import os
  13. import argparse
  14. from tqdm import tqdm
  15. from models.PreactResNet import PreActResNet18
  16. from models.utils import progress_bar
  17. def create_dataloaders():
  18. print('==> Preparing data..')
  19. transform_train = transforms.Compose([
  20. transforms.RandomCrop(32, padding=4),
  21. transforms.RandomHorizontalFlip(),
  22. transforms.ToTensor(),
  23. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  24. ])
  25. transform_test = transforms.Compose([
  26. transforms.ToTensor(),
  27. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  28. ])
  29. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
  30. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
  31. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
  32. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
  33. test_ds = iter(trainloader).next()
  34. test_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100)), batch_size=100)
  35. val_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100, 200)), batch_size=10)
  36. classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  37. return trainloader, testloader, test_ds, test_dl, val_dl, classes
  38. def create_model(args):
  39. print('==> Building model..')
  40. net = PreActResNet18()
  41. net = net.to(device)
  42. criterion = nn.CrossEntropyLoss()
  43. optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
  44. if device == 'cuda':
  45. net = torch.nn.DataParallel(net)
  46. cudnn.benchmark = True
  47. summary(net, (3, 32, 32))
  48. return net, criterion, optimizer
  49. # Training
  50. def train(epoch, trainloader, verbose=True):
  51. if verbose:
  52. print('\nEpoch: %d' % epoch)
  53. net.train()
  54. train_loss = 0
  55. correct = 0
  56. total = 0
  57. for batch_idx, (inputs, targets) in enumerate(trainloader):
  58. inputs, targets = inputs.to(device), targets.to(device)
  59. optimizer.zero_grad()
  60. outputs = net(inputs)
  61. loss = criterion(outputs, targets)
  62. loss.backward()
  63. optimizer.step()
  64. train_loss += loss.item()
  65. _, predicted = outputs.max(1)
  66. total += targets.size(0)
  67. correct += predicted.eq(targets).sum().item()
  68. if verbose:
  69. progress_bar(batch_idx, len(trainloader), f'Loss: {train_loss / (batch_idx + 1)} | '
  70. f'Acc: {100. * correct / total}'
  71. f' ({correct}/{total})')
  72. return 100. * correct / total
  73. def test(epoch, testloader, verbose=True):
  74. global best_acc
  75. net.eval()
  76. test_loss = 0
  77. correct = 0
  78. total = 0
  79. with torch.no_grad():
  80. for batch_idx, (inputs, targets) in enumerate(testloader):
  81. inputs, targets = inputs.to(device), targets.to(device)
  82. outputs = net(inputs)
  83. loss = criterion(outputs, targets)
  84. test_loss += loss.item()
  85. _, predicted = outputs.max(1)
  86. total += targets.size(0)
  87. correct += predicted.eq(targets).sum().item()
  88. if verbose:
  89. progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
  90. % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
  91. # Save checkpoint.
  92. acc = 100. * correct / total
  93. if acc > best_acc:
  94. print('Saving..')
  95. state = {
  96. 'net': net.state_dict(),
  97. 'acc': acc,
  98. 'epoch': epoch,
  99. }
  100. if not os.path.isdir('checkpoint'):
  101. os.mkdir('checkpoint')
  102. torch.save(state, './checkpoint/ckpt.pth')
  103. best_acc = acc
  104. return acc
  105. def overfit_test():
  106. for it in tqdm(range(500)):
  107. train_acc = train(it, test_dl, verbose=False)
  108. test_acc = test(it, val_dl, verbose=False)
  109. print(f'train_acc = {train_acc}')
  110. print(f'test_acc = {test_acc}')
  111. if train_acc >= 80:
  112. print('==> Overfit is Over and success!')
  113. else:
  114. raise AssertionError('Overfiting test not passed')
  115. if __name__ == "__main__":
  116. parser = argparse.ArgumentParser(description='PyTorch with PreActResNet CIFAR10 Training')
  117. parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
  118. parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
  119. parser.add_argument('--epochs', default=1, type=int, help='number of epochs for training')
  120. parser.add_argument('--test', action='store_true', help='testing model and train process though unit tests')
  121. parser.add_argument('--train', action='store_true', help='train model')
  122. args = parser.parse_args()
  123. device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
  124. best_acc = 0 # best test accuracy
  125. start_epoch = 0 # start from epoch 0 or last checkpoint epoch
  126. # Data
  127. trainloader, testloader, test_ds, test_dl, val_dl, classes = create_dataloaders()
  128. # Model
  129. net, criterion, optimizer = create_model(args)
  130. if args.resume:
  131. # Load checkpoint.
  132. print('==> Resuming from checkpoint..')
  133. assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  134. checkpoint = torch.load('./checkpoint/ckpt.pth')
  135. net.load_state_dict(checkpoint['net'])
  136. best_acc = checkpoint['acc']
  137. start_epoch = checkpoint['epoch']
  138. if args.test:
  139. # testing model
  140. print('==> Testing model and train process...')
  141. torchtest.assert_vars_change(
  142. model=net,
  143. loss_fn=criterion,
  144. optim=optimizer,
  145. batch=test_ds,
  146. device=device)
  147. torchtest.test_suite(
  148. model=net,
  149. loss_fn=criterion,
  150. optim=optimizer,
  151. batch=test_ds,
  152. device=device,
  153. test_nan_vals=True,
  154. test_vars_change=True,
  155. # non_train_vars=None,
  156. test_inf_vals=True
  157. )
  158. overfit_test()
  159. print('==> All test are passed! Let is train whole network.')
  160. if args.train or args.resume:
  161. print('==> Let is TRAIN begin!')
  162. best_acc = 0 # best test accuracy
  163. for epoch in range(start_epoch, start_epoch + args.epochs):
  164. train(epoch, trainloader)
  165. test(epoch, testloader)
  166. print("==> Train is finished")