|
|
@@ -9,10 +9,11 @@ import torchvision
|
|
|
import torchvision.transforms as transforms
|
|
|
import torchtest
|
|
|
from torchsummary import summary
|
|
|
-from pytorch_lightning import Trainer
|
|
|
+# from pytorch_lightning import Trainer
|
|
|
|
|
|
import os
|
|
|
import argparse
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
from models.PreactResNet import PreActResNet18
|
|
|
from models.utils import progress_bar
|
|
|
@@ -128,9 +129,9 @@ def test(epoch, testloader, verbose=True):
|
|
|
|
|
|
|
|
|
def overfit_test():
|
|
|
- for it in range(500):
|
|
|
+ for it in tqdm(range(500)):
|
|
|
train_acc = train(it, test_dl, verbose=False)
|
|
|
- test_acc = test(it, val_dl)
|
|
|
+ test_acc = test(it, val_dl, verbose=False)
|
|
|
print(f'train_acc = {train_acc}')
|
|
|
print(f'test_acc = {test_acc}')
|
|
|
if train_acc >= 80:
|
|
|
@@ -145,6 +146,7 @@ if __name__ == "__main__":
|
|
|
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
|
|
|
parser.add_argument('--epochs', default=1, type=int, help='number of epochs for training')
|
|
|
parser.add_argument('--test', action='store_true', help='testing model and train process though unit tests')
|
|
|
+ parser.add_argument('--train', action='store_true', help='train model')
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
|
|
|
@@ -193,9 +195,10 @@ if __name__ == "__main__":
|
|
|
|
|
|
print('==> All test are passed! Let is train whole network.')
|
|
|
|
|
|
- print('==> Let is TRAIN begin!')
|
|
|
- best_acc = 0 # best test accuracy
|
|
|
- for epoch in range(start_epoch, start_epoch + args.epochs):
|
|
|
- train(epoch, trainloader)
|
|
|
- test(epoch, testloader)
|
|
|
- print("==> Train is finished")
|
|
|
+ if args.train or args.resume:
|
|
|
+ print('==> Let is TRAIN begin!')
|
|
|
+ best_acc = 0 # best test accuracy
|
|
|
+ for epoch in range(start_epoch, start_epoch + args.epochs):
|
|
|
+ train(epoch, trainloader)
|
|
|
+ test(epoch, testloader)
|
|
|
+ print("==> Train is finished")
|