main.py

  1. """Train PreActResNet on CIFAR10 with PyTorch."""
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import torch.nn.functional as F
  6. import torch.backends.cudnn as cudnn
  7. import torchvision
  8. import torchvision.transforms as transforms
  9. import torchtest
  10. from torchsummary import summary
  11. from pytorch_lightning import Trainer
  12. import os
  13. import argparse
  14. from models.PreactResNet import PreActResNet18
  15. from models.utils import progress_bar
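
# Example invocations (based on the CLI flags defined in the main block below):
#   python main.py --lr 0.1 --epochs 100    # train from scratch
#   python main.py --resume                 # resume from ./checkpoint/ckpt.pth
#   python main.py --test                   # run the torchtest / overfitting sanity checks, then train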


def create_dataloaders():
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
    # A single training batch, used by the torchtest sanity checks.
    test_ds = next(iter(trainloader))
    # Tiny train/validation subsets, used by the overfitting sanity check.
    test_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100)), batch_size=100)
    val_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100, 200)), batch_size=10)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader, testloader, test_ds, test_dl, val_dl, classes


def create_model(args):
    print('==> Building model..')
    net = PreActResNet18()
    net = net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    if device.type == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    # Per-layer summary; some torchsummary releases default to device='cuda',
    # so passing device='cpu' here may be needed on a CPU-only machine.
    summary(net, (3, 32, 32))
    return net, criterion, optimizer


# Training
def train(epoch, trainloader, verbose=True):
    if verbose:
        print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if verbose:
            progress_bar(batch_idx, len(trainloader),
                         f'Loss: {train_loss / (batch_idx + 1):.3f} | '
                         f'Acc: {100. * correct / total:.3f}%'
                         f' ({correct}/{total})')
    return 100. * correct / total


def test(epoch, testloader, verbose=True):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if verbose:
                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                             % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
    return acc


def overfit_test():
    """Sanity check: a correctly wired training loop should overfit a 100-sample subset."""
    for it in range(500):
        train_acc = train(it, test_dl, verbose=False)
    test_acc = test(it, val_dl)
    print(f'train_acc = {train_acc}')
    print(f'test_acc = {test_acc}')
    if train_acc >= 80:
        print('==> Overfitting test passed!')
    else:
        raise AssertionError('Overfitting test failed')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch with PreActResNet CIFAR10 Training')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    parser.add_argument('--epochs', default=1, type=int, help='number of epochs for training')
    parser.add_argument('--test', action='store_true', help='test the model and the training process through unit tests')
    args = parser.parse_args()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data
    trainloader, testloader, test_ds, test_dl, val_dl, classes = create_dataloaders()

    # Model
    net, criterion, optimizer = create_model(args)

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
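
    # Note: when create_model wraps the network in DataParallel, the keys of
    # net.state_dict() gain a 'module.' prefix, so a checkpoint saved on GPU
    # will not load directly into an unwrapped (CPU-only) model. A minimal
    # sketch of one workaround (strip_module_prefix is a hypothetical helper,
    # not part of this repository):
    #
    #     def strip_module_prefix(state_dict):
    #         return {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    #
    #     net.load_state_dict(strip_module_prefix(checkpoint['net']))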

    if args.test:
        # Sanity-check the model and the training process before the full run.
        print('==> Testing model and train process...')
        # Check that trainable parameters actually change after one optimization step.
        torchtest.assert_vars_change(
            model=net,
            loss_fn=criterion,
            optim=optimizer,
            batch=test_ds,
            device=device)
        # Full torchtest suite: parameters change, and no NaN/Inf values appear.
        torchtest.test_suite(
            model=net,
            loss_fn=criterion,
            optim=optimizer,
            batch=test_ds,
            device=device,
            test_nan_vals=True,
            test_vars_change=True,
            # non_train_vars=None,
            test_inf_vals=True
        )
        overfit_test()
        print("==> All tests passed! Let's train the whole network.")

    print('==> Let the training begin!')
    best_acc = 0  # best test accuracy
    for epoch in range(start_epoch, start_epoch + args.epochs):
        train(epoch, trainloader)
        test(epoch, testloader)
    print('==> Training finished')
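
# A minimal inference sketch (not part of the training script), assuming a
# checkpoint produced by test() above and the same DataParallel wrapping:
#
#     net = PreActResNet18().to(device)
#     if device.type == 'cuda':
#         net = torch.nn.DataParallel(net)
#     checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
#     net.load_state_dict(checkpoint['net'])
#     net.eval()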