utils.py

import copy
import os
import pickle
import time
import urllib.request
import warnings

import matplotlib.pyplot as plt
import numpy as np
# import pretrainedmodels
from astropy.io import fits
# from PIL import Image
from skimage.transform import rescale, resize
# from sklearn.metrics import f1_score
# from sklearn.utils import class_weight
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.transforms as transforms
import albumentations as alb
from torch.optim import lr_scheduler
from torchvision import datasets, models

class SunRegionDataset(data_utils.Dataset):

    def __init__(self, path_to_df_pkl, path_to_fits_folder, height, width,
                 only_first_class=False, transformations=None, logarithm=True, max=None):
        """
        Args:
            path_to_df_pkl (string): path or URL to a pkl file containing a pandas dataframe with labels
            path_to_fits_folder (string): path to the folder with FITS files
            height (int): image height
            width (int): image width
            only_first_class (bool): build the dataset with only the single letter that
                represents the first layer of the McIntosh classes
            transformations: pytorch transforms for augmentation and tensor conversion
            logarithm (bool): apply a signed log transform before normalization
            max (float, optional): precomputed normalization maximum; computed from the
                dataset if None (note: this argument shadows the builtin `max`)
        """
        if path_to_df_pkl.startswith('http'):
            with urllib.request.urlopen(path_to_df_pkl) as pkl:
                self.sunspots = pickle.load(pkl)
        else:
            # pickle.load expects a file object, not a path
            with open(path_to_df_pkl, 'rb') as pkl:
                self.sunspots = pickle.load(pkl)
        self.classes = np.asarray(self.sunspots.iloc[:, 2].unique())
        self.height = height
        self.width = width
        self.folder_path, self.dirs, self.files = next(os.walk(path_to_fits_folder))
        self.len = len(self.files)
        self.ind = list(range(self.len))
        self.transformations = transformations
        self.alb_transforms = alb.Compose([
            alb.RandomRotate90(p=0.1),
            alb.Rotate(75, p=0.1),
            alb.Resize(224, 224, p=0.1),
            alb.RandomCrop(200, 200, p=0.1),
            alb.HorizontalFlip(),
            # alb.Transpose(),
            alb.VerticalFlip(),
            alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        ], p=0.7)
        self.to_tensor = transforms.ToTensor()
        self.only_first_class = only_first_class
        self.logarithm = logarithm
        self.first_classes = set(class_[0] for class_ in self.sunspots['class'].unique())
        self.second_classes = set(class_[1] for class_ in self.sunspots['class'].unique())
        self.third_classes = set(class_[2] for class_ in self.sunspots['class'].unique())
        if max is None:
            self.max = self.find_max_dataset()
        else:
            self.max = max
    def __getitem__(self, index):
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        if self.transformations is None:
            if self.logarithm:
                data = self.log_normalize(data)
            data = self.normalize_data(data)
            # data = resize(data, (self.height, self.width), anti_aliasing=True)
            data = self.aug()(image=data)['image']
            data = self.to_tensor(data).float()  # cast to float32
            # data = data.repeat(3, 1, 1)  # convert to 3 channels to use pretrained models
        else:
            data = self.transformations(data)
        mc_class = self.get_attr_region(self.files[index], self.sunspots, self.only_first_class)
        num_class = None  # stays None unless mc_class matches a first-layer letter
        for ind, letter in enumerate(sorted(self.first_classes)):
            if letter == mc_class:
                num_class = ind
        return (data, num_class, mc_class)
    def __len__(self):
        return self.len
    def show_region(self, index):
        '''Plot the region at the given index from the dataset.
        index: int, index of a sample from the dataset
        '''
        date, region = self.files[index].split('.')[1:3]
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        class_, size, location, number_ss = self.get_attr_region(self.files[index],
                                                                 self.sunspots,
                                                                 only_first_class=False,
                                                                 only_class=False)
        ax = plt.axes()
        ax.set_title(
            'Region {} on date {} with class {} at location {} with size {} and number_of_ss {}'
            .format(region, date, class_, location, size, number_ss))
        ax.imshow(data)
        # ax.annotate((24, 12))
    def get_attr_region(self, filename, df, only_first_class=False, only_class=True):
        date, region = filename.split('.')[1:3]
        reg_attr = df.loc[date[:-7], int(region[2:])]
        # check only_class as well, otherwise the branch below it is unreachable
        if only_first_class and only_class:
            return reg_attr['class'][0]
        elif (not only_class) and only_first_class:
            class_, size, location, number_ss = reg_attr[
                ['class', 'size', 'location', 'number_of_ss']]
            return class_[0], size, location, number_ss
        elif (not only_class) and (not only_first_class):
            return reg_attr[['class', 'size', 'location', 'number_of_ss']]
        else:
            return reg_attr['class']
    def log_normalize(self, data):
        # signed log transform: preserves the sign while compressing magnitudes
        return np.sign(data) * np.log1p(np.abs(data))

    def normalize_data(self, data):
        return data / self.max

    def find_max_dataset(self):
        m = []
        for file in self.files:
            with fits.open(os.path.join(self.folder_path, file)) as ff:
                m.append(np.nanmax(np.abs(ff[0].data)))
        return np.max(m)

    def aug(self):
        return self.alb_transforms
    def split_dataset(self, val_size=None, test_size=None):
        '''Split the dataset into optional train, val, and test datasets.
        val_size (optional): float from 0 to 1.
        test_size (optional): float from 0 to 1.
        Returns datasets in the order (train, valid, test).
        '''
        len_all = self.len
        test_split_size = int(np.floor(test_size * len_all)) if test_size else 0
        val_split_size = int(np.floor(val_size * len_all)) if val_size else 0
        train_split_size = len_all - test_split_size - val_split_size
        return data_utils.random_split(self, [train_split_size, val_split_size, test_split_size])
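

# Minimal usage sketch: the file paths below are hypothetical placeholders and
# the split fractions and batch size are arbitrary examples.
if __name__ == '__main__':
    dataset = SunRegionDataset('sunspots_df.pkl', 'fits_regions/',
                               height=224, width=224, only_first_class=True)
    train_ds, val_ds, test_ds = dataset.split_dataset(val_size=0.1, test_size=0.1)
    loader = data_utils.DataLoader(train_ds, batch_size=32, shuffle=True)
    # each batch yields (image tensors, numeric labels, McIntosh class strings)
    data, num_class, mc_class = next(iter(loader))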