metya 6 роки тому
батько
коміт
cdbdb022df
2 змінених файлів з 35 додано та 10 видалено
  1. 23 1
      README.md
  2. 12 9
      main.py

+ 23 - 1
README.md

@@ -1 +1,23 @@
-KalugaRes
+# PreAct ResNet with some tests on cifar10
+
+To run train without tests just do
+
+```shell script
+python main.py --train --epochs num
+```
+
+To run train alongside with some unit tests just do 
+```shell script
+python main.py --test --train --epochs num
+``` 
+
+If you want to run just tests, then do
+```shell script
+python main.py --test
+```
+
+###TODO
+
+- Regression metrics tests
+- Gradient explosion and gradient vanishing check
+- Pythonistic Unit Tests

+ 12 - 9
main.py

@@ -9,10 +9,11 @@ import torchvision
 import torchvision.transforms as transforms
 import torchtest
 from torchsummary import summary
-from pytorch_lightning import Trainer
+# from pytorch_lightning import Trainer
 
 import os
 import argparse
+from tqdm import tqdm
 
 from models.PreactResNet import PreActResNet18
 from models.utils import progress_bar
@@ -128,9 +129,9 @@ def test(epoch, testloader, verbose=True):
 
 
 def overfit_test():
-    for it in range(500):
+    for it in tqdm(range(500)):
         train_acc = train(it, test_dl, verbose=False)
-    test_acc = test(it, val_dl)
+    test_acc = test(it, val_dl, verbose=False)
     print(f'train_acc = {train_acc}')
     print(f'test_acc = {test_acc}')
     if train_acc >= 80:
@@ -145,6 +146,7 @@ if __name__ == "__main__":
     parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
     parser.add_argument('--epochs', default=1, type=int, help='number of epochs for training')
     parser.add_argument('--test', action='store_true', help='testing model and train process though unit tests')
+    parser.add_argument('--train', action='store_true', help='train model')
     args = parser.parse_args()
 
     device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
@@ -193,9 +195,10 @@ if __name__ == "__main__":
 
         print('==> All test are passed! Let is train whole network.')
 
-    print('==> Let is TRAIN begin!')
-    best_acc = 0  # best test accuracy
-    for epoch in range(start_epoch, start_epoch + args.epochs):
-        train(epoch, trainloader)
-        test(epoch, testloader)
-    print("==> Train is finished")
+    if args.train or args.resume:
+        print('==> Let is TRAIN begin!')
+        best_acc = 0  # best test accuracy
+        for epoch in range(start_epoch, start_epoch + args.epochs):
+            train(epoch, trainloader)
+            test(epoch, testloader)
+        print("==> Train is finished")