import os
import pickle
import urllib.request

import albumentations as alb
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as data_utils
import torchvision.transforms as transforms
from astropy.io import fits


class SunRegionDataset(data_utils.Dataset):
    def __init__(self, path_to_df_pkl, path_to_fits_folder, height, width,
                 only_first_class=False, transformations=None, logarithm=True, max=None):
        """
        Args:
            path_to_df_pkl (string): path or URL to a pickle file containing a
                pandas DataFrame with the labels
            path_to_fits_folder (string): path to the folder with FITS files
            height (int): image height
            width (int): image width
            only_first_class (bool): label samples with only the first letter of
                the three-letter McIntosh class
            transformations: PyTorch transforms for augmentation and tensor conversion
            logarithm (bool): apply a signed log transform before normalization
            max (float, optional): normalization constant; computed from the
                dataset when omitted
        """
        if path_to_df_pkl.startswith('http'):
            with urllib.request.urlopen(path_to_df_pkl) as pkl:
                self.sunspots = pickle.load(pkl)
        else:
            with open(path_to_df_pkl, 'rb') as pkl:
                self.sunspots = pickle.load(pkl)
        self.classes = np.asarray(self.sunspots.iloc[:, 2].unique())
        self.height = height
        self.width = width
        self.folder_path, self.dirs, self.files = next(os.walk(path_to_fits_folder))
        self.len = len(self.files)
        self.ind = list(range(self.len))
        self.transformations = transformations
        # Default augmentation pipeline, applied with probability 0.7; each
        # transform inside fires with its own probability.
        self.alb_transforms = alb.Compose([
            alb.RandomRotate90(p=0.1),
            alb.Rotate(75, p=0.1),
            alb.Resize(224, 224, p=0.1),
            alb.RandomCrop(200, 200, p=0.1),
            alb.HorizontalFlip(),
            alb.VerticalFlip(),
            alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=0.75),
        ], p=0.7)
        self.to_tensor = transforms.ToTensor()
        self.only_first_class = only_first_class
        self.logarithm = logarithm
        self.first_classes = {class_[0] for class_ in self.sunspots['class'].unique()}
        self.second_classes = {class_[1] for class_ in self.sunspots['class'].unique()}
        self.third_classes = {class_[2] for class_ in self.sunspots['class'].unique()}
        if max is None:
            self.max = self.find_max_dataset()
        else:
            self.max = max

    def __getitem__(self, index):
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        if self.transformations is None:
            if self.logarithm:
                data = self.log_normalize(data)
            data = self.normalize_data(data)
            data = self.aug()(image=data)['image']
            data = self.to_tensor(data).float()
            # data = data.repeat(3, 1, 1)  # convert to 3 channels for pretrained models
        else:
            data = self.transformations(data)
        mc_class = self.get_attr_region(self.files[index], self.sunspots, self.only_first_class)
        # Encode the first McIntosh letter as an integer label; mc_class is
        # either a single letter or the full three-letter class string.
        num_class = sorted(self.first_classes).index(mc_class[0])
        return data, num_class, mc_class

    def __len__(self):
        return self.len

    def show_region(self, index):
        '''Plot the region at the given index of the dataset.

        index: int, index of the sample in the dataset
        '''
        date, region = self.files[index].split('.')[1:3]
        file_path = os.path.join(self.folder_path, self.files[index])
        with fits.open(file_path) as fits_file:
            data = fits_file[0].data
        class_, size, location, number_ss = self.get_attr_region(self.files[index],
                                                                 self.sunspots,
                                                                 only_first_class=False,
                                                                 only_class=False)
        ax = plt.axes()
        ax.set_title(
            'Region {} on date {} with class {} at location {} with size {} and number_of_ss {}'
            .format(region, date, class_, location, size, number_ss))
        ax.imshow(data)

    def get_attr_region(self, filename, df, only_first_class=False, only_class=True):
        date, region = filename.split('.')[1:3]
        reg_attr = df.loc[date[:-7], int(region[2:])]
        if only_first_class and only_class:
            return reg_attr['class'][0]
        elif only_first_class:
            class_, size, location, number_ss = \
                reg_attr[['class', 'size', 'location', 'number_of_ss']]
            return class_[0], size, location, number_ss
        elif not only_class:
            return reg_attr[['class', 'size', 'location', 'number_of_ss']]
        else:
            return reg_attr['class']

    def log_normalize(self, data):
        # Signed log transform: compresses dynamic range while preserving sign.
        return np.sign(data) * np.log1p(np.abs(data))
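    # For example, the signed log maps the (hypothetical) values
    #   np.array([0.0, np.e - 1, -(np.e - 1)])  ->  array([ 0.,  1., -1.]),
    # since log1p(e - 1) == 1 and the sign is kept.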

    def normalize_data(self, data):
        return data / self.max

    def find_max_dataset(self):
        # Scan the whole dataset for the largest absolute pixel value.
        m = []
        for file in self.files:
            with fits.open(os.path.join(self.folder_path, file)) as ff:
                m.append(np.nanmax(np.abs(ff[0].data)))
        return np.max(m)

    def aug(self):
        return self.alb_transforms

    def split_dataset(self, val_size=None, test_size=None):
        '''Split the dataset into train, validation and test subsets.

        val_size (optional): float from 0 to 1, fraction of samples for validation
        test_size (optional): float from 0 to 1, fraction of samples for testing
        Returns the datasets in the order (train, valid, test).
        '''
        len_all = self.len
        test_split_size = int(np.floor(test_size * len_all)) if test_size else 0
        val_split_size = int(np.floor(val_size * len_all)) if val_size else 0
        train_split_size = len_all - test_split_size - val_split_size
        return data_utils.random_split(self, [train_split_size, val_split_size, test_split_size])
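

# A minimal usage sketch, assuming hypothetical paths 'sunspots_df.pkl' and
# 'fits/' (neither ships with this module):
if __name__ == '__main__':
    dataset = SunRegionDataset('sunspots_df.pkl', 'fits/', height=224, width=224,
                               only_first_class=True)
    train_ds, val_ds, test_ds = dataset.split_dataset(val_size=0.1, test_size=0.1)
    # batch_size=1 because the augmentation pipeline can change the spatial size
    # per sample, which the default collate function cannot stack into a batch.
    train_loader = data_utils.DataLoader(train_ds, batch_size=1, shuffle=True)
    image, label, mc_class = next(iter(train_loader))
    print(image.shape, label.item(), mc_class[0])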