# utils.py

import os
import albumentations as alb
# import telepot
import pickle
import urllib.request
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torchvision.transforms as transforms
# from box_convolution import BoxConv2d
import pretrainedmodels as pm
from kekas.modules import Flatten, AdaptiveConcatPool2d
from skimage.transform import resize  # , rescale
from tqdm import tqdm_notebook
# from tg_tqdm import tg_tqdm
# import cv2


class SunRegionDataset(data_utils.Dataset):

    def __init__(self, path_to_df_pkl, path_to_fits_folder, height, width,
                 only_first_class=False, transformations=None, logarithm=True, max=None):
        """
        Args:
            path_to_df_pkl (string): path or URL to a pkl file holding a pandas dataframe with labels
            path_to_fits_folder (string): path to the folder with FITS files
            height (int): image height
            width (int): image width
            only_first_class (bool): build the dataset with only the first letter of the McIntosh class
            transformations: pytorch transforms for augmentation and tensor conversion
            logarithm (bool): apply a signed log transform to the pixel values
            max (float, optional): maximum pixel value used for normalization; computed over the dataset if None
        """
        if path_to_df_pkl.startswith('http'):
            with urllib.request.urlopen(path_to_df_pkl) as pkl:
                self.sunspots = pickle.load(pkl)
        else:
            with open(path_to_df_pkl, 'rb') as pkl:
                self.sunspots = pickle.load(pkl)
        self.classes = np.asarray(self.sunspots.iloc[:, 2].unique())
        self.height = height
        self.width = width
        self.folder_path, self.dirs, self.files = next(os.walk(path_to_fits_folder))
        self.len = len(self.files)
        self.ind = list(range(self.len))
        self.transformations = transformations
        self.alb_transforms = alb.Compose([
            alb.RandomRotate90(p=0.1),
            alb.Rotate(75, p=0.1),
            # alb.Resize(224, 224, p=1),  # default 0.1
            # alb.RandomCrop(200, 200, p=0.1),
            alb.HorizontalFlip(),
            # alb.Transpose(),
            alb.VerticalFlip(),
            # alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        ], p=0.7)  # default 0.7
        self.to_tensor = transforms.ToTensor()
        self.only_first_class = only_first_class
        self.logarithm = logarithm
        self.first_classes = set(class_[0] for class_ in self.sunspots['class'].unique())
        self.second_classes = set(class_[1] for class_ in self.sunspots['class'].unique())
        self.third_classes = set(class_[2] for class_ in self.sunspots['class'].unique())
        if max is None:
            self.max = self.find_max_dataset()
        else:
            self.max = max

    def __getitem__(self, index):
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        if self.transformations is None:
            if self.logarithm:
                data = self.log_normalize(data)
            data = self.normalize_data(data)
            # data = data.reshape(1, data.shape[0], data.shape[1]).repeat(3, axis=0)
            data = resize(data, (self.height, self.width), anti_aliasing=True)
            data = self.aug()(image=data)['image']  # augmentation
            data = self.to_tensor(data).float()  # uncomment for float
            data = data.repeat(3, 1, 1)
        else:
            data = self.transformations(data)
        mc_class = self.get_attr_region(self.files[index], self.sunspots, self.only_first_class)
        for ind, letter in enumerate(sorted(self.first_classes)):
            if letter == mc_class[0]:  # match on the first McIntosh letter
                num_class = ind
        # return (data, num_class, mc_class)
        return {"image": data, "label": num_class, "letter_label": mc_class}

    def __len__(self):
        return self.len

    def show_region(self, index):
        '''Plot a region by its index in the dataset
        index: int, index of the sample in the dataset
        '''
        date, region = self.files[index].split('.')[1:3]
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        class_, size, location, number_ss = self.get_attr_region(self.files[index],
                                                                 self.sunspots,
                                                                 only_first_class=False,
                                                                 only_class=False)
        ax = plt.axes()
        ax.set_title(
            'Region {} on date {} with class {} on location {} with size {} and number_of_ss {}'
            .format(region, date, class_, location, size, number_ss))
        ax.imshow(data)
        # ax.annotate((24, 12))

    def get_attr_region(self, filename, df, only_first_class=False, only_class=True):
        '''Get labels for a region
        '''
        date, region = filename.split('.')[1:3]
        reg_attr = df.loc[date[:-7], int(region[2:])]
        if only_first_class and only_class:
            return reg_attr['class'][0]
        elif (not only_class) and only_first_class:
            class_, \
                size, \
                location, \
                number_ss = reg_attr[['class', 'size', 'location', 'number_of_ss']]
            return class_[0], size, location, number_ss
        elif (not only_class) and (not only_first_class):
            return reg_attr[['class', 'size', 'location', 'number_of_ss']]
        else:
            return reg_attr['class']

    def log_normalize(self, data):
        return np.sign(data) * np.log1p(np.abs(data))

    def normalize_data(self, data):
        return data / self.max

    def find_max_dataset(self):
        '''Find the maximum absolute pixel value over the whole dataset
        '''
        m = []
        print('find max all over dataset')
        for file in tqdm_notebook(self.files):
            with fits.open(os.path.join(self.folder_path, file)) as ff:
                m.append(np.nanmax(np.abs(ff[0].data)))
        return np.max(m)

    def aug(self):
        return self.alb_transforms

    def split_dataset(self, val_size=None, test_size=None):
        '''Split the dataset into optional train, val and test datasets
        val_size (optional): float from 0 to 1.
        test_size (optional): float from 0 to 1.
        Returns datasets in order (train, valid, test)
        '''
        len_all = self.len
        test_split_size = int(np.floor(test_size * len_all)) if test_size else 0
        val_split_size = int(np.floor(val_size * len_all)) if val_size else 0
        train_split_size = len_all - test_split_size - val_split_size
        return data_utils.random_split(self, [train_split_size, val_split_size, test_split_size])
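
# Example usage of SunRegionDataset (a minimal sketch; the paths, image size and
# batch size below are hypothetical and not part of this module):
#
#     dataset = SunRegionDataset("labels.pkl", "fits/", height=224, width=224,
#                                only_first_class=True)
#     train_ds, val_ds, test_ds = dataset.split_dataset(val_size=0.1, test_size=0.1)
#     train_loader = data_utils.DataLoader(train_ds, batch_size=32, shuffle=True)
#     batch = next(iter(train_loader))  # {"image": ..., "label": ..., "letter_label": ...}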


class Net(nn.Module):

    def __init__(
            self,
            num_classes: int = 7,
            p: float = 0.2,
            pooling_size: int = 2,
            last_conv_size: int = 1664,
            arch: str = "densenet169",
            pretrained: str = "imagenet") -> None:
        """A model to finetune.

        Args:
            num_classes: the number of target classes, the size of the last layer's output
            p: dropout probability
            pooling_size: the size of the resulting feature map after the adaptive pooling layer
            last_conv_size: number of channels of the last backbone conv layer
            arch: the name of the architecture from pretrainedmodels
            pretrained: the mode for the pretrained model from pretrainedmodels
        """
        super().__init__()
        net = pm.__dict__[arch](pretrained=pretrained)
        modules = list(net.children())[:-1]  # delete the last layer
        # size of the flattened head input: channels are doubled by the concat pooling
        head_size = last_conv_size * 2 * pooling_size ** 2
        # add a custom head
        modules += [nn.Sequential(
            # AdaptiveConcatPool2d is a concat of AdaptiveMaxPooling and AdaptiveAveragePooling
            AdaptiveConcatPool2d(size=pooling_size),
            Flatten(),
            nn.BatchNorm1d(head_size),
            nn.Dropout(p),
            nn.Linear(head_size, num_classes)
        )]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        logits = self.net(x)
        return logits
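
# Example usage of Net (a minimal sketch; the input resolution and batch size are
# hypothetical, and pretrained=None skips downloading ImageNet weights):
#
#     model = Net(num_classes=7, pretrained=None)
#     images = torch.randn(4, 3, 224, 224)
#     logits = model(images)  # expected shape: (4, 7)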


def step_fn(model: torch.nn.Module,
            batch: torch.Tensor) -> torch.Tensor:
    """Determine what your model will do with your data.

    Args:
        model: the pytorch module to pass the input to
        batch: the batch of data from the DataLoader

    Returns:
        The model's forward-pass results
    """
    inp = batch["image"]  # here we get an "image" from our dataset
    return model(inp)
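

# A quick self-check of Net together with step_fn (a hedged sketch: the batch
# below only mimics the dict produced by SunRegionDataset, and pretrained=None
# avoids downloading ImageNet weights).
if __name__ == "__main__":
    _model = Net(pretrained=None)
    _batch = {"image": torch.randn(2, 3, 224, 224), "label": torch.tensor([0, 1])}
    _logits = step_fn(_model, _batch)
    print(_logits.shape)  # expected: torch.Size([2, 7])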