metya committed 6 years ago
Parent commit 22576ddbc9
6 changed files with 462 additions and 0 deletions
  1. .gitignore  +8 -0
  2. main.py  +201 -0
  3. models/PreactResNet.py  +117 -0
  4. models/__init__.py  +0 -0
  5. models/utils.py  +128 -0
  6. requirements.txt  +8 -0

+ 8 - 0
.gitignore

@@ -0,0 +1,8 @@
+
+.idea/
+
+data/
+
+models/__pycache__/
+
+checkpoint/

+ 201 - 0
main.py

@@ -0,0 +1,201 @@
+"""Train PreActResNet on CIFAR10 with PyTorch."""
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+import torchvision
+import torchvision.transforms as transforms
+import torchtest
+from torchsummary import summary
+from pytorch_lightning import Trainer
+
+import os
+import argparse
+
+from models.PreactResNet import PreActResNet18
+from models.utils import progress_bar
+
+
+def create_dataloaders():
+    print('==> Preparing data..')
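+    # Standard CIFAR-10 recipe: padded random crops and horizontal flips for
+    # training; the Normalize constants are the dataset's per-channel means/stds.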
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+
+    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
+
+    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)
+
+    test_ds = next(iter(trainloader))
+    test_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100)), batch_size=100)
+    val_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(trainset, range(100, 200)), batch_size=10)
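+    # test_ds is one full batch for the torchtest checks; test_dl and val_dl are
+    # tiny fixed subsets of the training data used by overfit_test() as a sanity check.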
+
+    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+    return trainloader, testloader, test_ds, test_dl, val_dl, classes
+
+
+def create_model(args):
+    print('==> Building model..')
+    net = PreActResNet18()
+    net = net.to(device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+    if device.type == 'cuda':
+        net = torch.nn.DataParallel(net)
+        cudnn.benchmark = True
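+        # benchmark mode lets cuDNN autotune convolution algorithms for the fixed 32x32 inputs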
+
+    summary(net, (3, 32, 32), device=device.type)
+
+    return net, criterion, optimizer
+
+
+# Training
+def train(epoch, trainloader, verbose=True):
+    if verbose:
+        print('\nEpoch: %d' % epoch)
+    net.train()
+    train_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(trainloader):
+        inputs, targets = inputs.to(device), targets.to(device)
+        optimizer.zero_grad()
+        outputs = net(inputs)
+        loss = criterion(outputs, targets)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.item()
+        _, predicted = outputs.max(1)
+        total += targets.size(0)
+        correct += predicted.eq(targets).sum().item()
+        if verbose:
+            progress_bar(batch_idx, len(trainloader), f'Loss: {train_loss / (batch_idx + 1):.3f} | '
+                                                      f'Acc: {100. * correct / total:.3f}%'
+                                                      f' ({correct}/{total})')
+
+    return 100. * correct / total
+
+
+def test(epoch, testloader, verbose=True):
+    global best_acc
+    net.eval()
+    test_loss = 0
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for batch_idx, (inputs, targets) in enumerate(testloader):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = net(inputs)
+            loss = criterion(outputs, targets)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+            if verbose:
+                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+                             % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
+
+    # Save checkpoint.
+    acc = 100. * correct / total
+    if acc > best_acc:
+        print('Saving..')
+        state = {
+            'net': net.state_dict(),
+            'acc': acc,
+            'epoch': epoch,
+        }
+        if not os.path.isdir('checkpoint'):
+            os.mkdir('checkpoint')
+        torch.save(state, './checkpoint/ckpt.pth')
+        best_acc = acc
+
+    return acc
+
+
+def overfit_test():
+    for it in range(500):
+        train_acc = train(it, test_dl, verbose=False)
+    test_acc = test(it, val_dl)
+    print(f'train_acc = {train_acc}')
+    print(f'test_acc = {test_acc}')
+    if train_acc >= 80:
+        print('==> Overfit test passed!')
+    else:
+        raise AssertionError('Overfitting test not passed')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='PyTorch with PreActResNet CIFAR10 Training')
+    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
+    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
+    parser.add_argument('--epochs', default=1, type=int, help='number of epochs for training')
+    parser.add_argument('--test', action='store_true', help='test the model and training process through unit tests')
+    args = parser.parse_args()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    best_acc = 0  # best test accuracy
+    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
+
+    # Data
+    trainloader, testloader, test_ds, test_dl, val_dl, classes = create_dataloaders()
+
+    # Model
+    net, criterion, optimizer = create_model(args)
+
+    if args.resume:
+        # Load checkpoint.
+        print('==> Resuming from checkpoint..')
+        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+        checkpoint = torch.load('./checkpoint/ckpt.pth')
+        net.load_state_dict(checkpoint['net'])
+        best_acc = checkpoint['acc']
+        start_epoch = checkpoint['epoch']
+
+    if args.test:
+        # testing model
+        print('==> Testing model and train process...')
+
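+        # assert_vars_change runs one optimization step and asserts that every
+        # trainable parameter changed; test_suite repeats that check and also
+        # screens the outputs for NaN/Inf values.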
+        torchtest.assert_vars_change(
+            model=net,
+            loss_fn=criterion,
+            optim=optimizer,
+            batch=test_ds,
+            device=device)
+
+        torchtest.test_suite(
+            model=net,
+            loss_fn=criterion,
+            optim=optimizer,
+            batch=test_ds,
+            device=device,
+            test_nan_vals=True,
+            test_vars_change=True,
+            # non_train_vars=None,
+            test_inf_vals=True
+        )
+
+        overfit_test()
+
+        print('==> All tests passed! Time to train the whole network.')
+
+    print('==> Let the training begin!')
+    if args.test:
+        best_acc = 0  # reset after the unit tests so real training starts from a clean best accuracy
+    for epoch in range(start_epoch, start_epoch + args.epochs):
+        train(epoch, trainloader)
+        test(epoch, testloader)
+    print("==> Train is finished")

+ 117 - 0
models/PreactResNet.py

@@ -0,0 +1,117 @@
+"""Pre-activation ResNet in PyTorch.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
+
+Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PreActBlock(nn.Module):
+    """Pre-activation version of the BasicBlock."""
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(PreActBlock, self).__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+
+        if stride != 1 or in_planes != self.expansion*planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(x))
+        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+        out = self.conv1(out)
+        out = self.conv2(F.relu(self.bn2(out)))
+        out += shortcut
+        return out
+
+
+class PreActBottleneck(nn.Module):
+    """Pre-activation version of the original Bottleneck module."""
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(PreActBottleneck, self).__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
+
+        if stride != 1 or in_planes != self.expansion*planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(x))
+        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
+        out = self.conv1(out)
+        out = self.conv2(F.relu(self.bn2(out)))
+        out = self.conv3(F.relu(self.bn3(out)))
+        out += shortcut
+        return out
+
+
+class PreActResNet(nn.Module):
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(PreActResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512*block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1]*(num_blocks-1)
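+        # Only the first block of a layer may downsample (stride > 1); the
+        # remaining num_blocks - 1 blocks keep stride 1.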
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
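+        # layers 2-4 each halve the spatial size, so 32x32 inputs arrive here as 4x4 maps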
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def PreActResNet18():
+    return PreActResNet(PreActBlock, [2, 2, 2, 2])
+
+def PreActResNet34():
+    return PreActResNet(PreActBlock, [3, 4, 6, 3])
+
+def PreActResNet50():
+    return PreActResNet(PreActBottleneck, [3, 4, 6, 3])
+
+def PreActResNet101():
+    return PreActResNet(PreActBottleneck, [3, 4, 23, 3])
+
+def PreActResNet152():
+    return PreActResNet(PreActBottleneck, [3, 8, 36, 3])
+
+
+def test():
+    net = PreActResNet18()
+    y = net(torch.randn(1, 3, 32, 32))
+    print(y.size())
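
A quick shape check for the factories above: every variant should map a CIFAR-sized batch to (batch, 10) logits. A minimal sketch using only names defined in this file:

    import torch
    from models.PreactResNet import PreActResNet18, PreActResNet50

    for build in (PreActResNet18, PreActResNet50):
        net = build()
        out = net(torch.randn(2, 3, 32, 32))
        assert out.shape == (2, 10)  # 10 CIFAR10 classes
        print(build.__name__, 'ok')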

+ 0 - 0
models/__init__.py


+ 128 - 0
models/utils.py

@@ -0,0 +1,128 @@
+"""Some helper functions for PyTorch, including:
+    - get_mean_and_std: calculate the mean and std of a dataset.
+    - init_params: net parameter initialization.
+    - progress_bar: progress bar that mimics xlua.progress.
+"""
+import os
+import sys
+import time
+import math
+import torch
+from shutil import get_terminal_size
+import torch.nn as nn
+import torch.nn.init as init
+
+
+def get_mean_and_std(dataset):
+    '''Compute the mean and std value of dataset.'''
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
+    mean = torch.zeros(3)
+    std = torch.zeros(3)
+    print('==> Computing mean and std..')
+    for inputs, targets in dataloader:
+        for i in range(3):
+            mean[i] += inputs[:, i, :, :].mean()
+            std[i] += inputs[:, i, :, :].std()
+    mean.div_(len(dataset))
+    std.div_(len(dataset))
+    return mean, std
+
+
+def init_params(net):
+    '''Init layer parameters.'''
+    for m in net.modules():
+        if isinstance(m, nn.Conv2d):
+            init.kaiming_normal_(m.weight, mode='fan_out')
+            if m.bias is not None:
+                init.constant_(m.bias, 0)
+        elif isinstance(m, nn.BatchNorm2d):
+            init.constant_(m.weight, 1)
+            init.constant_(m.bias, 0)
+        elif isinstance(m, nn.Linear):
+            init.normal_(m.weight, std=1e-3)
+            if m.bias is not None:
+                init.constant_(m.bias, 0)
+
+
+term_width, _ = get_terminal_size()
+
+TOTAL_BAR_LENGTH = 65.
+last_time = time.time()
+begin_time = last_time
+
+
+def progress_bar(current, total, msg=None):
+    global last_time, begin_time
+    if current == 0:
+        begin_time = time.time()  # Reset for new bar.
+
+    cur_len = int(TOTAL_BAR_LENGTH * current / total)
+    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
+
+    sys.stdout.write(' [')
+    for i in range(cur_len):
+        sys.stdout.write('=')
+    sys.stdout.write('>')
+    for i in range(rest_len):
+        sys.stdout.write('.')
+    sys.stdout.write(']')
+
+    cur_time = time.time()
+    step_time = cur_time - last_time
+    last_time = cur_time
+    tot_time = cur_time - begin_time
+
+    L = []
+    L.append('  Step: %s' % format_time(step_time))
+    L.append(' | Tot: %s' % format_time(tot_time))
+    if msg:
+        L.append(' | ' + msg)
+
+    msg = ''.join(L)
+    sys.stdout.write(msg)
+    for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
+        sys.stdout.write(' ')
+
+    # Go back to the center of the bar.
+    for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
+        sys.stdout.write('\b')
+    sys.stdout.write(' %d/%d ' % (current + 1, total))
+
+    if current < total - 1:
+        sys.stdout.write('\r')
+    else:
+        sys.stdout.write('\n')
+    sys.stdout.flush()
+
+
+def format_time(seconds):
+    days = int(seconds / 3600 / 24)
+    seconds = seconds - days * 3600 * 24
+    hours = int(seconds / 3600)
+    seconds = seconds - hours * 3600
+    minutes = int(seconds / 60)
+    seconds = seconds - minutes * 60
+    secondsf = int(seconds)
+    seconds = seconds - secondsf
+    millis = int(seconds * 1000)
+
+    f = ''
+    i = 1
+    if days > 0:
+        f += str(days) + 'D'
+        i += 1
+    if hours > 0 and i <= 2:
+        f += str(hours) + 'h'
+        i += 1
+    if minutes > 0 and i <= 2:
+        f += str(minutes) + 'm'
+        i += 1
+    if secondsf > 0 and i <= 2:
+        f += str(secondsf) + 's'
+        i += 1
+    if millis > 0 and i <= 2:
+        f += str(millis) + 'ms'
+        i += 1
+    if f == '':
+        f = '0ms'
+    return f
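
For reference, the Normalize constants hard-coded in main.py can be re-derived with get_mean_and_std above. A minimal sketch (it downloads CIFAR10 into ./data, as the trainer does; expect values close to the (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) used there, since those come from this per-image averaging method):

    import torchvision
    import torchvision.transforms as transforms
    from models.utils import get_mean_and_std

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transforms.ToTensor())
    mean, std = get_mean_and_std(trainset)
    print(mean, std)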

+ 8 - 0
requirements.txt

@@ -0,0 +1,8 @@
+# Requirements automatically generated by pigar.
+# https://github.com/damnever/pigar
+
+pytorch_lightning == 0.4.3
+torch == 1.2.0
+torchsummary == 1.5.1
+torchtest == 0.5
+torchvision == 0.4.0