# training.py

# import telepot
import pickle
import urllib.request
# from astropy.io import fits
# from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# import torch.optim as optim
import torch.utils.data as data_utils
# import torchvision
# from torchvision import datasets, models, transforms
# from torch.optim import lr_scheduler
# from torchvision.models.resnet import BasicBlock
# from box_convolution import BoxConv2d
# import pretrainedmodels as pm
from kekas import Keker, DataOwner  # , DataKek
# from kekas.transformations import Transformer, to_torch, normalize
from kekas.metrics import accuracy  # , accuracy_score
# from kekas.modules import Flatten, AdaptiveConcatPool2d
# from kekas.callbacks import Callback, Callbacks, DebuggerCallback
from adabound import AdaBound
from sklearn.utils import class_weight
# from tqdm import tqdm_notebook
# from tg_tqdm import tg_tqdm
import warnings
# import cv2
from utils import SunRegionDataset, Net, step_fn

plt.ion()
warnings.filterwarnings("ignore")
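
# Telegram notification settings (placeholders; only used by the commented-out
# telepot / tg_tqdm progress reporting)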
tg_token = 'TOKEN'
tg_chat_id = 1234
ik_chat_id = 1234
sun_group_id = -1234
# Paths: sunspot catalogue pickle, rescaled FITS folder, log/checkpoint dirs
url_pkl = 'https://raw.githubusercontent.com/iknyazeva/FitsProcessing/master/sunspot_1996_2017.pkl'
dataset_folder = 'ALLrescaled/'
path_to_save = ''
logdir = "logs"
lrlogdir = "lrlogs"
checkdir = 'check'
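
# Download the sunspot catalogue (a pickled pandas DataFrame) and peek at it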
with urllib.request.urlopen(url_pkl) as pkl:
    sunspots = pickle.load(pkl)
print(sunspots.tail(5))

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('\ndevice:', device)
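
# Training hyperparameters. max_value_pixel is the dataset-wide pixel maximum,
# passed to SunRegionDataset as `max` (presumably for intensity normalisation);
# the adam_*/sgd_* and step_size values look like leftovers from earlier
# optimiser experiments and are not used below.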
max_value_pixel = 3977.35
batch_size = 64
adam_lr = 0.003
sgd_lr = 0.01
sgd_wd = 0.000
adam_wd = 0.0000
step_size = 30
num_epochs = 100
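
# Build 100x100 region crops; only_first_class=True presumably keeps just the
# leading letter of each 'class' string (the class-weight computation below
# uses class_[0] the same way), then split off 10% validation and 10% test.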
regions_dataset = SunRegionDataset(path_to_df_pkl=url_pkl, path_to_fits_folder=dataset_folder,
                                   height=100, width=100, only_first_class=True,
                                   logarithm=False, max=max_value_pixel)
train_dataset, val_dataset, test_dataset = regions_dataset.split_dataset(0.1, 0.1)
# Optionally cache/restore the splits so runs are reproducible:
# with open('train_dataset.pkl', 'wb') as train:
#     pickle.dump(train_dataset, train)
# with open('val_dataset.pkl', 'wb') as val:
#     pickle.dump(val_dataset, val)
# with open('test_dataset.pkl', 'wb') as test:
#     pickle.dump(test_dataset, test)
# with open('train_dataset.pkl', 'rb') as train:
#     train_dataset = pickle.load(train)
# with open('val_dataset.pkl', 'rb') as val:
#     val_dataset = pickle.load(val)
# with open('test_dataset.pkl', 'rb') as test:
#     test_dataset = pickle.load(test)
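
# Wrap the splits in DataLoaders and hand all three to kekas' DataOwner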
train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = data_utils.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataowner = DataOwner(train_loader, val_loader, test_loader)

# Weight the loss by inverse class frequency; classes are keyed by the first
# letter of each 'class' string in the catalogue
label_wts = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique([class_[0] for class_ in sunspots['class']]),
    y=[class_[0] for class_ in sunspots['class']])
label_wts = torch.Tensor(label_wts).to(device)
w_criterion = nn.CrossEntropyLoss(weight=label_wts)
criterion = nn.CrossEntropyLoss()  # unweighted variant (not used below)
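
# Net comes from utils.py; keker.freeze(model_attr='net') below targets the
# backbone submodule it stores in its 'net' attribute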
model = Net()

# We use the kekas framework for training (https://github.com/belskikh/kekas/)
keker = Keker(model=model,
              dataowner=dataowner,
              criterion=w_criterion,
              step_fn=step_fn,
              target_key="label",
              metrics={"acc": accuracy},
              # opt=torch.optim.Adam,
              # opt=torch.optim.SGD,
              # opt_params={"weight_decay": 1e-5},
              # opt_params={"momentum": 0.99},
              opt=AdaBound,
              opt_params={'final_lr': 0.01,
                          'weight_decay': 5e-4})
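
# Stage 1: train with the backbone frozen, using a one-cycle LR schedule and
# keeping the 3 best checkpoints by validation accuracy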
keker.freeze(model_attr='net')
keker.kek_one_cycle(max_lr=1e-6,
                    cycle_len=90,
                    momentum_range=(0.95, 0.85),
                    div_factor=10,
                    increase_fraction=0.3,
                    logdir=logdir,
                    cp_saver_params={
                        "savedir": checkdir,
                        "metric": "acc",
                        "n_best": 3,
                        "prefix": "check",
                        "mode": "max"})
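# Reload the best checkpoint before fine-tuning / prediction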
keker.load(checkdir + '/' + 'check.best.h5')

# Stage 2 (optional): unfreeze the backbone and fine-tune all parameters
# keker.unfreeze(model_attr='net')
# keker.kek_one_cycle(max_lr=1e-6,
#                     cycle_len=90,
#                     momentum_range=(0.95, 0.85),
#                     div_factor=10,
#                     increase_fraction=0.3,
#                     logdir=logdir,
#                     cp_saver_params={
#                         "savedir": checkdir,
#                         "metric": "acc",
#                         "n_best": 3,
#                         "prefix": "check",
#                         "mode": "max"})
# keker.load(checkdir + '/' + 'check.best.h5')
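
# Run inference on the test loader and save the raw predictions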
keker.predict(savepath="predicts")
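
# A minimal sketch for checking test accuracy from the saved predictions.
# Assumptions (not confirmed by this script): keker.predict dumps a NumPy
# array of per-class scores to `savepath`, one row per test sample in loader
# order, and dataset items are dicts with a "label" key (as target_key
# suggests). Note that shuffle=True on test_loader would break the row/label
# alignment, so rebuild it with shuffle=False before predicting if you use this.
# preds = np.load("predicts")
# pred_labels = preds.argmax(axis=1)
# true_labels = np.array([sample["label"] for sample in test_dataset])
# print("test accuracy:", (pred_labels == true_labels).mean())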