
and be the ligth

metya committed 6 years ago
parent
commit
a7fe0ea5c2
3 changed files with 1024 additions and 0 deletions
  1. fits_parse.py (+631, −0)
  2. training.py (+167, −0)
  3. utils.py (+226, −0)

+ 631 - 0
fits_parse.py

@@ -0,0 +1,631 @@
+import re
+import csv
+import logging
+import math
+import glob
+# import argparse
+import numpy as np
+import os
+import pandas as pd
+import time
+import datetime
+import drms
+import urllib
+# import json
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+import astropy.units as u
+import telegram_handler
+# import warnings
+import sunpy.wcs
+import sunpy.map
+import pickle
+import telepot
+from colorlog import ColoredFormatter
+from astropy.coordinates import SkyCoord
+# from astropy.io import fits
+# from astropy.time import Time
+# from datetime import timedelta
+# from sunpy.coordinates import frames
+# from astropy.coordinates import SkyCoord
+from tg_tqdm import tg_tqdm
+# from tqdm import tqdm
+# warnings.filterwarnings("ignore")
+
+
+# define constants
+EMAIL = 'emal@email.ru'
+SAVE_PATH = 'dataset'
+tg_bot_token = 'TOKEN'
+tm_chat_id = 1234
+ik_chat_id = 1234
+sun_group_id = -1234
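+# parsing-control constants (as used below):
+#   DATE_DELIMIT - dates before this are fetched as SOHO/MDI magnetograms, later dates as SDO/HMI
+#   PERIOD       - initial number of days requested per JSOC export batch
+#   SLEEP        - short pause before Telegram-logged messages (presumably to avoid Telegram rate limits)
+#   PROGRESS     - update the Telegram progress bar every PROGRESS processed days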
+DATE_DELIMIT = '2010-06-28'
+TG_LOGGER = False
+FILE_DELETE = False
+LOGGER_LEVEL = logging.WARNING
+# LOGGER_LEVEL = logging.DEBUG
+VERBOSE = True
+PERIOD = 300
+START_DATE = '1996-04-01'
+CROP_DATE = '2017-11-01'
+SLEEP = 0.1
+PROGRESS = 10
+
+
+# logging.basicConfig(filename='futs_parse.log', level=logging.INFO)
+
+
+def set_logger(level=logging.WARNING, name='logger', telegram=False):
+    """Return a logger with a default ColoredFormatter."""
+
+    file_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(funcName)s - %(message)s")
+    stream_formatter = ColoredFormatter(
+        "%(asctime)s [%(log_color)s%(levelname)-8s%(reset)s: %(funcName)s] %(white)s%(message)s",
+        datefmt=None,
+        reset=True,
+        log_colors={
+            'DEBUG': 'cyan',
+            'INFO': 'green',
+            'WARNING': 'yellow',
+            'ERROR': 'red',
+            'CRITICAL': 'red',
+        }
+    )
+
+    logger = logging.getLogger(name)
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(stream_formatter)
+    log_handler = logging.FileHandler("fits_parse.log")
+    log_handler.setFormatter(file_formatter)
+    logger.addHandler(stream_handler)
+    logger.addHandler(log_handler)
+
+    if telegram:
+        tg_handler = telegram_handler.TelegramHandler(tg_bot_token, sun_group_id)
+        tg_formatter = telegram_handler.HtmlFormatter()
+        tg_handler.setFormatter(tg_formatter)
+        logger.addHandler(tg_handler)
+
+    logger.setLevel(level)
+
+    return logger
+
+
+logger = set_logger(level=LOGGER_LEVEL, name='sun_logger', telegram=TG_LOGGER)
+
+
+def check_dataset_directory():
+
+    if not os.path.exists('HMIdataset/fragments'):
+        logger.warning('HMIdataset folders do not exist, creating them')
+        os.makedirs('HMIdataset/fragments')
+
+    if not os.path.exists('MDIdataset/fragments'):
+        logger.warning('MDIdataset folders do not exist, creating them')
+        os.makedirs('MDIdataset/fragments')
+
+    return True
+
+
+def clean_folder(path):
+    for file in os.listdir(path):
+        file_path = os.path.join(path, file)
+        if os.path.isfile(file_path):
+            os.remove(file_path)
+
+    return True
+
+
+def message_of_start(token=tg_bot_token, id=sun_group_id):
+    bot = telepot.Bot(token)
+    bot.sendMessage(id, 'Start parsing fits on remote server')
+
+
+def message_of_start_cropping(token=tg_bot_token, id=sun_group_id):
+    bot = telepot.Bot(token)
+    bot.sendMessage(id, '-' * 30)
+    bot.sendMessage(id, 'Start cropping regions')
+    bot.sendMessage(id, '-' * 30)
+
+
+def hook_for_download_fits(t):
+    """Wraps tqdm instance.
+    Don't forget to close() or __exit__()
+    the tqdm instance once you're done with it (easiest using `with` syntax).
+    Example
+    -------
+    >>> with tqdm(...) as t:
+    ...     reporthook = my_hook(t)
+    ...     urllib.urlretrieve(..., reporthook=reporthook)
+    """
+    last_b = [0]
+
+    def update_to(b=1, bsize=1, tsize=None):
+        """
+        b  : int, optional
+            Number of blocks transferred so far [default: 1].
+        bsize  : int, optional
+            Size of each block (in tqdm units) [default: 1].
+        tsize  : int, optional
+            Total size (in tqdm units). If [default: None] remains unchanged.
+        """
+        if tsize is not None:
+            t.total = tsize
+        t.update((b - last_b[0]) * bsize)
+        last_b[0] = b
+
+    return update_to
+
+
+def request_mfits_by_date_MDI(moment, email=EMAIL, path_to_save='MDIdataset', verbose=False):
+    """
+    Function for request fits from JSOC database
+    moment: pd.datetime object
+    return: filepath to the magnetogram
+    """
+
+    filename = 'mdi.fd_m_96m_lev182.' + moment.strftime('%Y%m%d_%H%M%S_TAI.data.fits')
+    filepath = os.path.join(path_to_save, filename)
+
+    if os.path.exists(filepath):
+        pass
+    else:
+
+        c = drms.Client(email=email, verbose=verbose)
+        str_for_query = 'mdi.fd_M_96m_lev182' + moment.strftime('[%Y.%m.%d_%H:%M:%S_TAI]')
+        logger.info('Magnetogram: {} will be downloaded ... '.format(str_for_query))
+        r = c.export(str_for_query, method='url', protocol='fits')
+        logger.debug(r)
+
+        try:
+            r.wait()
+            logger.info(r.request_url)
+        except Exception as e:
+            logger.warning('Could not wait for the export request, skipping it. Got exception: {}'.format(e))
+
+        try:
+            logger.info("Download data and save to path {}".format(filepath))
+            r.download(path_to_save, verbose=verbose)
+        except Exception as e:
+            logger.error('Got an error while trying to download: {}'.format(e))
+            logger.warning('Skipping this date')
+
+    return filepath
+
+
+def request_batch_mfits_by_date(moment,
+                                period_of_days=30, email=EMAIL,
+                                path_to_save='dataset',
+                                verbose=False,
+                                type_mag='MDI',
+                                token=tg_bot_token,
+                                chat_id=sun_group_id):
+    '''Request a batch of fits files covering a period of days and return:
+    the export request url,
+    the period of days that was actually applied,
+    the first and last dates of the batch,
+    and the number of files in the batch.
+    '''
+
+    c = drms.Client(email=email, verbose=verbose)
+
+    def set_str_for_query(period_of_days=period_of_days):
+        if type_mag == 'MDI':
+            str_for_query = 'mdi.fd_M_96m_lev182' + moment.strftime('[%Y.%m.%d_%H:%M:%S_TAI/{}d@24h]'.format(period_of_days))
+            filename_to_check = 'mdi.fd_m_96m_lev182.' + moment.strftime('%Y%m%d_%H%M%S_TAI.data.fits')
+            path_to_save = 'MDIdataset'
+        if type_mag == 'HMI':
+            str_for_query = 'hmi.m_720s' + moment.strftime('[%Y.%m.%d_%H:%M:%S_TAI/{}d@24h]'.format(period_of_days))
+            path_to_save = 'HMIdataset'
+            filename_to_check = 'hmi.m_720s.' + moment.strftime('%Y%m%d_%H%M%S_TAI.magnetogram.fits')
+
+        return str_for_query, path_to_save, filename_to_check
+
+    str_for_query, path_to_save, filename_to_check = set_str_for_query()
+    logger.debug('{}\n{}\n{}'.format(str_for_query, path_to_save, filename_to_check))
+    if os.path.exists(os.path.join(path_to_save, filename_to_check)):
+        period_of_days = 10
+        logger.info('Files already exist. Skipping download and advancing by a batch of {} days'.format(period_of_days))
+        return None, period_of_days, moment, moment + datetime.timedelta(days=period_of_days), period_of_days
+
+    logger.info('Magnetogram: {} will be downloaded ... '.format(str_for_query))
+
+    r = c.export(str_for_query, protocol='fits')
+    logger.debug(r)
+    logger.debug(r.has_failed())
+
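+    # back-off strategy: if the export request fails, shrink the requested span by a step
+    # that grows roughly with log(period)^2; once the span drops below half of that step,
+    # give up on this request and advance by a fixed 10 days instead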
+    threshold = round(math.log(period_of_days) ** 2 / 2)
+    while r.has_failed():
+        period_of_days -= round(threshold)
+        if period_of_days < round(threshold / 2):
+            logger.warning('Period of days is too small; giving up on this request and advancing by 10 days')
+            logger.warning('Export request was: {}'.format(str_for_query))
+            period_of_days = 10
+            return None, period_of_days, moment, moment + datetime.timedelta(days=period_of_days), period_of_days
+        time.sleep(1)
+        logger.info('Export request has failed. Reducing the number of days in it by {}. Days now in request: {}'.format(int(threshold), period_of_days))
+        str_for_query, _, _ = set_str_for_query(period_of_days=period_of_days)
+        logger.debug('Request string: {}'.format(str_for_query))
+        r = c.export(str_for_query, protocol='fits')
+
+    logger.debug(r)
+    logger.debug(len(r.data))
+
+    try:
+        r.wait(sleep=10, retries_notfound=10)
+    except Exception as e:
+        logger.error('Could not wait for the export request, skipping it. Got exception: {}'.format(e))
+
+    logger.info("Download data and save to path {}".format(path_to_save))
+
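+    # record strings look like 'mdi.fd_M_96m_lev182[2001.04.30_00:00:00_TAI]';
+    # pull out the date part and normalize it to YYYY-MM-DD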
+    first_date_batch = r.urls[0:]['record'].values[0].replace('[', ' ').split()[1].split('_')[0].replace('.', '-')
+    last_date_batch = r.urls[-1:]['record'].values[0].replace('[', ' ').split()[1].split('_')[0].replace('.', '-')
+
+    with tg_tqdm(r.urls.index, token=token, chat_id=chat_id, desc='DOWNLOAD BATCH',
+                 postfix='start_date = {}, end_date = {}'.format(first_date_batch, last_date_batch)) as batch_d:
+        for ind in batch_d:
+            try:
+                # file_name = '.'.join(r.urls.filename[ind].split('.')[:3] + r.urls.filename[ind].split('.')[4:])
+                urllib.request.urlretrieve(r.urls.url[ind], os.path.join(path_to_save, r.urls.filename[ind]))
+            except Exception as e:
+                logger.error('Got an error while trying to download {}: {}'.format(r.urls.url[ind], repr(e)))
+                logger.warning('Skipping this file')
+
+    len_batch = len(r.urls)
+
+    return r.request_url, period_of_days, first_date_batch, last_date_batch, len_batch
+
+
+def request_mfits_by_date_HMI(moment, email=EMAIL, path_to_save='HMIdataset', verbose=False):
+    """
+    Function for request fits from JSOC database
+    moment: pd.datetime object
+    return: filepath to the magnetogram
+    """
+
+    filename = 'hmi.m_720s.' + moment.strftime('%Y%m%d_%H%M%S_TAI.magnetogram.fits')
+    filepath = os.path.join(path_to_save, filename)
+
+    if os.path.exists(filepath):
+        pass
+    else:
+
+        c = drms.Client(email=email, verbose=verbose)
+        str_for_query = 'hmi.m_720s' + moment.strftime('[%Y.%m.%d_%H:%M:%S_TAI]{magnetogram}')
+        logger.info('Magnetogram: {} will be downloaded ... '.format(str_for_query))
+        r = c.export(str_for_query, method='url', protocol='fits')
+        logger.debug(r)
+
+        try:
+            r.wait()
+            logger.info(r.request_url)
+        except Exception as e:
+            logger.warning('Could not wait for the export request, skipping it. Got exception: {}'.format(e))
+
+        try:
+            logger.info("Download data and save to path {}".format(filepath))
+            r.download(path_to_save, verbose=verbose)
+        except Exception as e:
+            logger.error('Got an error while trying to download: {}'.format(e))
+            logger.warning('Skipping this date')
+
+    return filepath
+
+
+def read_fits_to_map(filepath, plot_show=False, ln=False):
+    """
+    Read a fits file into a sunpy Map object and optionally plot it (in logarithmic scale if ln=True)
+    return
+    mymap: sunpy object
+    """
+
+    mymap = sunpy.map.Map(filepath)
+
+    if plot_show:
+        plt.figure(figsize=(12, 12))
+        if ln:
+            data = np.sign(mymap.data) * np.log1p(np.abs(mymap.data))
+        else:
+            data = mymap.data
+        plt.imshow(data, cmap='gray')
+
+    return mymap
+
+
+def region_coord_list(datestr, sunspots_df, limit_deg=45):
+    """
+    Function for working with sunspot_1996_2017.pkl dataframe,
+    return list of tuples: (datestr, NOAA number, location)
+    used in cropping
+
+    args:
+    datestr: string for date in the format used in dataframe '2001-04-30'
+    sunspots_df: dataframe from file sunspot_1996_2017.pkl
+
+    return: list of tuples
+    """
+
+    date_df = sunspots_df.loc[datestr]
+    date_df.index = date_df.index.droplevel()
+    rc_list = []
+    for index, row in date_df.iterrows():
+        try:
+            restriction_degree = (abs(float(row.location[1:3])) <= limit_deg) and (abs(float(row.location[4:])) <= limit_deg)
+            if restriction_degree:
+                rc_list.append((pd.to_datetime(datestr, format='%Y-%m-%d'), index, row.location))
+        except ValueError as e:
+            if TG_LOGGER:
+                time.sleep(SLEEP)
+            logger.warning('Error reading location {} in degrees for date {}: {}'.format(row.location, datestr, e))
+        except Exception as e:
+            if TG_LOGGER:
+                time.sleep(SLEEP)
+            logger.error('Error reading location {} in degrees for date {}: {}'.format(row.location, datestr, e))
+
+    return rc_list
+
+
+def return_pixel_from_map(mag_map, record, limit_deg=45):
+    '''
+    convert lon lat coordinate to coordinate in pixel in sun map and return it
+    '''
+
+    pattern = re.compile(r"[NS]\d{2}[EW]\d{2}")
+    assert bool(pattern.match(record)), 'Pattern should be in the same format as N20E18'
+    assert (abs(float(record[1:3])) <= limit_deg) and (abs(float(record[4:])) <= limit_deg), 'Consider only regions between -{}, +{} degrees'.format(limit_deg, limit_deg)
+    if record[0] == 'N':
+        lat = float(record[1:3])
+    else:
+        lat = -float(record[1:3])
+    if record[3] == 'W':
+        lon = float(record[4:])
+    else:
+        lon = -float(record[4:])
+
+    hpc_coord = sunpy.wcs.convert_hg_hpc(lon, lat, b0_deg=mag_map.meta['crlt_obs'])
+    coord = SkyCoord(hpc_coord[0] * u.arcsec, hpc_coord[1] * u.arcsec, frame=mag_map.coordinate_frame)
+    # pixel_pos = mag_map.world_to_pixel(coord)
+    pixel_pos = mag_map.world_to_pixel(coord) * u.pixel
+    # pixel_pos = pixel_pos.to_value()
+
+    return pixel_pos
+
+
+def crop_regions(mag_map, rc_list, type_mag, delta=100, plot_rec=False, plot_crop=False, limit_deg=45, ln=False):
+    '''
+    Crop region by size delta and save it to disk,
+    if plot_rec, plot rectangle of regions on disk,
+    if plot_crop, plot only crop regions
+    '''
+    if ln:
+        data = np.sign(mag_map.data) * np.log1p(np.abs(mag_map.data))
+    else:
+        data = mag_map.data
+
+    if type_mag == 'MDI':
+        delta = 100
+    if type_mag == 'HMI':
+        delta = 200
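+    # delta is half the crop height in pixels (the crop is 2.5*delta wide and 2*delta tall);
+    # HMI full-disk images have a finer plate scale than MDI, so a larger pixel window is used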
+
+    if plot_rec:
+        fig, ax = plt.subplots(1, figsize=(12, 12))
+        ax.matshow(data)
+        plt.gray()
+        ax.set_title('{} magnetogram at '.format(type_mag) + rc_list[0][0].strftime('%Y-%m-%d %H:%M'))
+
+        for record in rc_list:
+            try:
+                pxs = return_pixel_from_map(mag_map, record[2], limit_deg).to_value()
+            except Exception as e:
+                logger.error('Error getting pixel coordinates from the map: {}. Skipping this region'.format(e))
+                continue
+            rect = patches.Rectangle((pxs[0] - 1.25 * delta, pxs[1] - delta), 2.5 * delta, 2 * delta, linewidth=3, edgecolor='r', facecolor='none')
+            ax.add_patch(rect)
+            ax.annotate('{}.AR'.format(type_mag) + str(record[1]), xy=(pxs[0], pxs[1]), xytext=(pxs[0], pxs[1] - 50), color='yellow', fontsize='xx-large')
+
+        plt.show()
+
+    submaps = []
+    for record in rc_list:
+
+        filename = '{}.{}.AR{}.fits'.format(type_mag, record[0].strftime('%Y-%m-%d_%H%M%S'), record[1])
+        filepath = os.path.join('{}dataset/fragments'.format(type_mag), filename)
+        try:
+            pxs = return_pixel_from_map(mag_map, record[2], limit_deg)
+        except Exception as e:
+            logger.error('Error getting pixel coordinates from the map: {}. Skipping this region'.format(e))
+            continue
+        bot_l = [pxs[0] - delta * 1.25 * u.pixel, pxs[1] - delta * u.pixel]
+        top_r = [pxs[0] + delta * 1.25 * u.pixel, pxs[1] + delta * u.pixel]
+
+        submap = mag_map.submap(u.Quantity(bot_l), u.Quantity(top_r))
+
+        if plot_crop:
+            submap.peek()
+
+        try:
+            submap.save(filepath)
+        except Exception as e:
+            if TG_LOGGER:
+                time.sleep(SLEEP)
+            logger.info('Could not save fits {} because: {}. Skipping it'.format(filename, e))
+
+        submaps.append(submap)
+
+    return submaps
+
+
+def date_compare(date):
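+    # True for dates in the SOHO/MDI era (before DATE_DELIMIT), False for the SDO/HMI era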
+    return date < datetime.datetime.fromtimestamp(time.mktime(time.strptime(DATE_DELIMIT, '%Y-%m-%d')))
+
+
+if __name__ == '__main__':
+
+    check_dataset_directory()
+    message_of_start()
+
+    try:
+        sunspots = pickle.load(urllib.request.urlopen('https://raw.githubusercontent.com/iknyazeva/FitsProcessing/master/sunspot_1996_2017.pkl'))
+        logger.info('Load sunspot dataframe is successful!')
+    except Exception as e:
+        logger.error('Cannot load the sunspot dataframe, halting parsing! Got exception: {}'.format(e))
+        raise
+
+    requests_urls = []
+    if START_DATE:
+        try:
+            start_moment = sunspots[(sunspots.index.get_level_values(0) > START_DATE)].index.get_level_values(0)[0]
+        except IndexError as e:
+            logger.info('Index out of bounds; the table has probably ended: {}'.format(e))
+            start_moment = START_DATE
+        except Exception as e:
+            logger.error('Error while getting start_moment for the first iteration: {}'.format(e))
+            start_moment = START_DATE
+    else:
+        start_moment = sunspots.index.get_level_values(0)[0]
+    logger.debug(start_moment)
+    count_of_days_left = len(sunspots[(sunspots.index.get_level_values(0) >= start_moment)].groupby(level=0))
+    logger.debug(count_of_days_left)
+
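+    # main download loop: walk through the sunspot table, requesting magnetograms in batches
+    # (MDI before DATE_DELIMIT, HMI after) and appending each export URL to requests_urls.csv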
+    with tg_tqdm(sunspots[(sunspots.index.get_level_values(0) > start_moment)].groupby(level=0),
+                 token=tg_bot_token, chat_id=sun_group_id, desc='MAIN PROGRESS DOWNLOAD') as tgm:
+        number_batch = 1
+        while count_of_days_left > 0:
+            tgm.set_postfix(batch=number_batch)
+            if date_compare(start_moment):
+                request_url,\
+                    period_of_days,\
+                    first_date_batch,\
+                    last_date_batch,\
+                    len_batch = request_batch_mfits_by_date(start_moment, period_of_days=PERIOD,
+                                                            email=EMAIL, type_mag='MDI', verbose=VERBOSE)
+            else:
+                request_url,\
+                    period_of_days,\
+                    first_date_batch,\
+                    last_date_batch,\
+                    len_batch = request_batch_mfits_by_date(start_moment, period_of_days=PERIOD,
+                                                            email=EMAIL, type_mag='HMI', verbose=VERBOSE)
+
+            logger.debug('Returned period of days {}'.format(period_of_days))
+            # requests_urls.append(request_url)
+            try:
+                start_moment = sunspots[(sunspots.index.get_level_values(0) > last_date_batch)].index.get_level_values(0)[0]
+                count_of_days_left = len(sunspots[(sunspots.index.get_level_values(0) >= start_moment)].groupby(level=0))
+            except IndexError as e:
+                logger.info('Index out of bounds; the table has probably ended: {}'.format(e))
+                count_of_days_left = 0
+            except Exception as e:
+                logger.error('Error while getting start_moment for the next iteration: {}'.format(e))
+            number_batch += 1
+
+            with open('requests_urls.csv', 'a', newline='') as file:
+                csv.writer(file).writerow([request_url])
+
+            tgm.update(len_batch)
+
+    # with open('requests_urls.csv', 'w') as file:
+    #     csv.writer(file, delimiter='\n').writerow(requests_urls)
+
+    message_of_start_cropping()
+
+    if CROP_DATE:
+        crop_df = sunspots[(sunspots.index.get_level_values(0) > CROP_DATE)]
+    else:
+        crop_df = sunspots
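+    # cropping loop: for every date, find the downloaded full-disk magnetogram on disk
+    # and cut out the active regions listed in the sunspot table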
+    with tg_tqdm(range(1), tg_bot_token, sun_group_id,
+                 total=len(crop_df.groupby(level=0)), desc='CROPPING PROGRESS') as tgt:
+
+        def is_progress(acc, total, progress=PROGRESS, tqdm_instance=tgt):
+            if (acc % progress == 0):
+                logger.debug('In if acc = {}'.format(acc))
+                time.sleep(SLEEP)
+                tqdm_instance.update(progress)
+            elif (acc >= total):
+                logger.debug('In if acc = {}'.format(acc))
+                time.sleep(SLEEP)
+                tqdm_instance.update(total % progress)
+
+            return True
+
+        acc = 0
+        total = len(crop_df.groupby(level=0))
+        logger.debug(total)
+        for date, df in crop_df.groupby(level=0):
+
+            rc_list = region_coord_list(str(date), df, limit_deg=45)
+
+            if not rc_list:
+                acc += 1
+                time.sleep(SLEEP)
+                is_progress(acc, total)
+                logger.debug('rc_list is empty - {}, acc = {}'.format(rc_list, acc))
+                continue
+
+            if date_compare(date):
+                filename = 'mdi.fd_m_96m_lev182.' + date.strftime('%Y%m%d_%H%M%S_TAI') + '*.fits'
+                path = 'MDIdataset/'
+                try:
+                    filepath = glob.glob(path + filename)[0]
+                    if TG_LOGGER:
+                        time.sleep(SLEEP)
+                    logger.debug('filepath: {}'.format(filepath))
+                except IndexError as e:
+                    logger.info('File for date {} does not exist'.format(str(date)))
+                    acc += 1
+                    is_progress(acc, total)
+                    continue
+                except Exception as e:
+                    logger.error('Error while globbing for the file: {}'.format(e))
+                    acc += 1
+                    is_progress(acc, total)
+                    continue
+                type_mag = 'MDI'
+
+            else:
+                filename = 'hmi.m_720s.' + date.strftime('%Y%m%d_%H%M%S_TAI') + '*.fits'
+                path = 'HMIdataset/'
+                try:
+                    filepath = glob.glob(path + filename)[0]
+                    if TG_LOGGER:
+                        time.sleep(SLEEP)
+                    logger.debug('filepath: {}'.format(filepath))
+                except IndexError as e:
+                    if TG_LOGGER:
+                        time.sleep(SLEEP)
+                    logger.info('File for date {} does not exist'.format(str(date)))
+                    acc += 1
+                    is_progress(acc, total)
+                    continue
+                except Exception as e:
+                    if TG_LOGGER:
+                        time.sleep(SLEEP)
+                    logger.error('Error while globbing for the file: {}'.format(e))
+                    acc += 1
+                    is_progress(acc, total)
+                    continue
+                type_mag = 'HMI'
+
+            try:
+                sun_map = read_fits_to_map(filepath, plot_show=False)
+                crop_regions(sun_map, rc_list, plot_rec=False, plot_crop=False, type_mag=type_mag)
+            except ValueError as e:
+                if TG_LOGGER:
+                    time.sleep(SLEEP)
+                logger.info('Got exception while reading: {}'.format(e))
+                logger.info('Continuing with the next date; skipping this magnetogram.')
+                # acc += 1
+                # continue
+            except Exception as e:
+                if TG_LOGGER:
+                    time.sleep(SLEEP)
+                logger.error('Got exception while reading: {}'.format(e))
+                logger.warning('Continuing with the next date; skipping this magnetogram.')
+                # acc += 1
+                # continue
+
+            # tgt.update()
+            acc += 1
+            logger.debug('acc = {}'.format(acc))
+            is_progress(acc, total)
+
+    if FILE_DELETE:
+        clean_folder('MDIdataset')
+        clean_folder('HMIdataset')

+ 167 - 0
training.py

@@ -0,0 +1,167 @@
+# import telepot
+import pickle
+import urllib
+
+# from astropy.io import fits
+# from PIL import Image
+import matplotlib.pyplot as plt
+import numpy as np
+
+import torch
+import torch.nn as nn
+# import torch.optim as optim
+import torch.utils.data as data_utils
+# import torchvision
+
+# from torchvision import datasets, models, transforms
+# from torch.optim import lr_scheduler
+# from torchvision.models.resnet import BasicBlock
+# from box_convolution import BoxConv2d
+# import pretrainedmodels as pm
+
+from kekas import Keker, DataOwner  # , DataKek
+# from kekas.transformations import Transformer, to_torch, normalize
+from kekas.metrics import accuracy  # , accuracy_score
+# from kekas.modules import Flatten, AdaptiveConcatPool2d
+# from kekas.callbacks import Callback, Callbacks, DebuggerCallback
+
+from adabound import AdaBound
+
+from sklearn.utils import class_weight
+# from tqdm import tqdm_notebook
+# from tg_tqdm import tg_tqdm
+
+import warnings
+# import cv2
+
+from utils import SunRegionDataset, Net, step_fn
+
+plt.ion()
+warnings.filterwarnings("ignore")
+
+tg_token = 'TOKEN'
+tg_chat_id = 1234
+ik_chat_id = 1234
+sun_group_id = -1234
+
+# define some things
+url_pkl = 'https://raw.githubusercontent.com/iknyazeva/FitsProcessing/master/sunspot_1996_2017.pkl'
+dataset_folder = 'ALLrescaled/'
+path_to_save = ''
+
+logdir = "logs"
+lrlogdir = "lrlogs"
+checkdir = 'check'
+
+with urllib.request.urlopen(url_pkl) as pkl:
+    sunspots = pickle.load(pkl)
+
+print(sunspots.tail(5))
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('\ndevice:', device)
+
+
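+# presumably a precomputed maximum |pixel| value for the dataset; passing it to
+# SunRegionDataset as max skips the expensive find_max_dataset() scan over every fits file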
+max_value_pixel = 3977.35
+batch_size = 64
+adam_lr = 0.003
+sgd_lr = 0.01
+sgd_wd = 0.000
+adam_wd = 0.0000
+step_size = 30
+num_epochs = 100
+
+
+regions_dataset = SunRegionDataset(path_to_df_pkl=url_pkl, path_to_fits_folder=dataset_folder, height=100, width=100,
+                                   only_first_class=True, logarithm=False, max=max_value_pixel)
+
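+# hold out 10% of the samples for validation and 10% for test; the remaining 80% is used for training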
+train_dataset, val_dataset, test_dataset = regions_dataset.split_dataset(0.1, 0.1)
+
+# with open('train_dataset.pkl', 'wb') as train:
+#     pickle.dump(train_dataset, train)
+# with open('val_dataset.pkl', 'wb') as val:
+#     pickle.dump(val_dataset, val)
+# with open('test_dataset.pkl', 'wb') as test:
+#     pickle.dump(test_dataset, test)
+
+# with open('train_dataset.pkl', 'rb') as train:
+#     train_dataset = pickle.load(train)
+# with open('val_dataset.pkl', 'rb') as val:
+#     val_dataset = pickle.load(val)
+# with open('test_dataset.pkl', 'rb') as test:
+#     test_dataset = pickle.load(test)
+
+
+train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+val_loader = data_utils.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+
+dataowner = DataOwner(train_loader, val_loader, test_loader)
+
+# get weights for classes
+label_wts = class_weight.compute_class_weight(
+    class_weight='balanced', classes=np.unique([class_[0] for class_ in sunspots['class']]), y=[class_[0] for class_ in sunspots['class']])
+
+label_wts = torch.Tensor(label_wts).to(device)
+
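+# weighted cross-entropy compensates for the imbalance between the first-letter sunspot classes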
+w_criterion = nn.CrossEntropyLoss(weight=label_wts)
+criterion = nn.CrossEntropyLoss()
+model = Net()
+
+# we use kekas framework for learning (https://github.com/belskikh/kekas/)
+keker = Keker(model=model,
+              dataowner=dataowner,
+              criterion=w_criterion,
+              step_fn=step_fn,
+              target_key="label",
+              metrics={"acc": accuracy},
+              # opt=torch.optim.Adam,
+              # opt=torch.optim.SGD,
+              # opt_params={"weight_decay": 1e-5}
+              # opt_params={"momentum": 0.99}
+              opt=AdaBound,
+              opt_params={'final_lr': 0.01,
+                          'weight_decay': 5e-4}
+              )
+
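+# keep the pretrained backbone frozen during the first cycle so that training focuses on the custom head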
+keker.freeze(model_attr='net')
+
+keker.kek_one_cycle(max_lr=1e-6,
+                    cycle_len=90,
+                    momentum_range=(0.95, 0.85),
+                    div_factor=10,
+                    increase_fraction=0.3,
+                    logdir=logdir,
+                    cp_saver_params={
+                        "savedir": checkdir,
+                        "metric": "acc",
+                        "n_best": 3,
+                        "prefix": "check",
+                        "mode": "max"
+                    }
+                    )
+
+
+keker.load(checkdir + '/' + 'check.best.h5')
+
+# TO FINE-TUNE ALL PARAMETERS OF THE NET, UNCOMMENT THE BLOCK BELOW
+
+# keker.unfreeze(model_attr='net')
+
+# keker.kek_one_cycle(max_lr=1e-6,
+#                     cycle_len=90,
+#                     momentum_range=(0.95, 0.85),
+#                     div_factor=10,
+#                     increase_fraction=0.3,
+#                     logdir=logdir,
+#                     cp_saver_params={
+#                         "savedir": checkdir,
+#                         "metric": "acc",
+#                         "n_best": 3,
+#                         "prefix": "check",
+#                         "mode": "max"
+#                     }
+#                     )
+
+# keker.load(checkdir + '/' + 'check.best.h5')
+
+keker.predict(savepath="predicts")

+ 226 - 0
utils.py

@@ -0,0 +1,226 @@
+import os
+import albumentations as alb
+# import telepot
+import pickle
+import urllib
+
+from astropy.io import fits
+import matplotlib.pyplot as plt
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.utils.data as data_utils
+import torchvision.transforms as transforms
+# from box_convolution import BoxConv2d
+import pretrainedmodels as pm
+
+from kekas.modules import Flatten, AdaptiveConcatPool2d
+
+from skimage.transform import resize  # , rescale
+
+from tqdm import tqdm_notebook
+# from tg_tqdm import tg_tqdm
+# import cv2
+
+
+class SunRegionDataset(data_utils.Dataset):
+    def __init__(self, path_to_df_pkl, path_to_fits_folder, height, width,
+                 only_first_class=False, transformations=None, logarithm=True, max=None):
+        """
+        Args:
+            path_to_df_pkl (string): path or url to a pkl file containing a pandas dataframe with labels
+            path_to_fits_folder (string): path to the folder with fits files
+            height (int): image height
+            width (int): image width
+            only_first_class (bool): build the dataset with only the single letter representing the first layer of the McIntosh classes
+            transformations: pytorch transforms for transformations and tensor conversion
+        """
+        if path_to_df_pkl.startswith('http'):
+            with urllib.request.urlopen(path_to_df_pkl) as pkl:
+                self.sunspots = pickle.load(pkl)
+        else:
+            with open(path_to_df_pkl, 'rb') as pkl:
+                self.sunspots = pickle.load(pkl)
+        self.classes = np.asarray(self.sunspots.iloc[:, 2].unique())
+        self.height = height
+        self.width = width
+        self.folder_path, self.dirs, self.files = next(os.walk(path_to_fits_folder))
+        self.len = len(self.files)
+        self.ind = list(range(self.len))
+        self.transformations = transformations
+        self.alb_transforms = alb.Compose([
+                                         alb.RandomRotate90(p=0.1),
+                                         alb.Rotate(75, p=0.1),
+                                         # alb.Resize(224, 224, p=1),  #default 0.1
+                                         # alb.RandomCrop(200, 200, p=0.1),
+                                         alb.HorizontalFlip(),
+                                         # alb.Transpose(),
+                                         alb.VerticalFlip(),
+                                         # alb.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
+                                         ], p=0.7)  # default 0.7
+        self.to_tensor = transforms.ToTensor()
+        self.only_first_class = only_first_class
+        self.logarithm = logarithm
+        self.first_classes = set([class_[0] for class_ in self.sunspots['class'].unique()])
+        self.second_classes = set([class_[1] for class_ in self.sunspots['class'].unique()])
+        self.third_classes = set([class_[2] for class_ in self.sunspots['class'].unique()])
+        if max is None:
+            self.max = self.find_max_dataset()
+        else:
+            self.max = max
+
+    def __getitem__(self, index):
+        file_path = os.path.join(self.folder_path, self.files[index])
+        with fits.open(file_path) as fits_file:
+            data = fits_file[0].data
+
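+        # default pipeline: optional signed log scaling, division by the dataset maximum,
+        # resize to (height, width), albumentations augmentation, then a 3-channel float tensor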
+        if self.transformations is None:
+            if self.logarithm:
+                data = self.log_normalize(data)
+            data = self.normalize_data(data)
+#             data = data.reshape(1, data.shape[0],data.shape[1]).repeat(3, axis=0)
+            data = resize(data, (self.height, self.width), anti_aliasing=True)
+            data = self.aug()(image=data)['image']  # augmentation
+            data = self.to_tensor(data).float()  # uncomment for float
+            data = data.repeat(3, 1, 1)
+        else:
+            data = self.transformations(data)
+
+        mc_class = self.get_attr_region(self.files[index], self.sunspots, self.only_first_class)
+
+        for ind, letter in enumerate(sorted(self.first_classes)):
+            if letter == mc_class:
+                num_class = ind
+
+#         return (data, num_class, mc_class)
+        return {"image": data, "label": num_class, "letter_label": mc_class}
+
+    def __len__(self):
+        return self.len
+
+    def show_region(self, index):
+        '''Plot region by index from dataset
+        index: int, index of sample from dataset
+        '''
+        date, region = self.files[index].split('.')[1:3]
+        file_path = os.path.join(self.folder_path, self.files[index])
+        with fits.open(file_path) as fits_file:
+            data = fits_file[0].data
+        class_, size, location, number_ss = self.get_attr_region(self.files[index],
+                                                                 self.sunspots,
+                                                                 only_first_class=False,
+                                                                 only_class=False)
+        ax = plt.axes()
+        ax.set_title(
+            'Region {} on date {} with class {} on location {} with size {} and number_of_ss {}'
+            .format(region, date, class_, location, size, number_ss))
+        ax.imshow(data)
+        # ax.annotate((24,12))
+
+    def get_attr_region(self, filename, df, only_first_class=False, only_class=True):
+        '''Get labels for regions
+        '''
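+        # fragment filenames are expected to look like 'MDI.2001-04-30_000000.AR9431.fits':
+        # take the date (dropping the _HHMMSS suffix) and the NOAA region number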
+        date, region = filename.split('.')[1:3]
+        reg_attr = df.loc[date[:-7], int(region[2:])]
+        if (not only_class) and only_first_class:
+            class_, \
+                size, \
+                location, \
+                number_ss = reg_attr[['class', 'size', 'location', 'number_of_ss']]
+            return class_[0], size, location, number_ss
+        elif (not only_class) and (not only_first_class):
+            return reg_attr[['class', 'size', 'location', 'number_of_ss']]
+        elif only_first_class:
+            return reg_attr['class'][0]
+        else:
+            return reg_attr['class']
+
+    def log_normalize(self, data):
+        return np.sign(data) * np.log1p(np.abs(data))
+
+    def normalize_data(self, data):
+        return data / self.max
+
+    def find_max_dataset(self):
+        '''Find max value of pixels over all dataset
+        '''
+        m = []
+        print('find max all over dataset')
+        for file in tqdm_notebook(self.files):
+            with fits.open(self.folder_path + file) as ff:
+                m.append(np.nanmax(np.abs(ff[0].data)))
+        return np.max(m)
+
+    def aug(self):
+        return self.alb_transforms
+
+    def split_dataset(self, val_size=None, test_size=None):
+        '''Split the dataset into optional train, val, and test datasets
+        test_size (optional): float from 0 to 1.
+        val_size (optional): float from 0 to 1.
+
+        Returns datasets in order (train, valid, test)
+
+        '''
+        len_all = self.len
+        test_split_size = int(np.floor(test_size * len_all)) if test_size else 0
+        val_split_size = int(np.floor(val_size * len_all)) if val_size else 0
+        train_split_size = len_all - test_split_size - val_split_size
+
+        return data_utils.random_split(self, [train_split_size, val_split_size, test_split_size])
+
+
+class Net(nn.Module):
+    def __init__(
+            self,
+            num_classes: int = 7,
+            p: float = 0.2,
+            pooling_size: int = 2,
+            last_conv_size: int = 1664,
+            arch: str = "densenet169",
+            pretrained: str = "imagenet") -> None:
+        """A model to finetune.
+
+        Args:
+            num_classes: the number of target classes, the size of the last layer's output
+            p: dropout probability
+            pooling_size: the size of the result feature map after adaptive pooling layer
+            last_conv_size: size of the flatten last backbone conv layer
+            arch: the name of the architecture form pretrainedmodels
+            pretrained: the mode for pretrained model from pretrainedmodels
+        """
+        super().__init__()
+        net = pm.__dict__[arch](pretrained=pretrained)
+        modules = list(net.children())[:-1]  # delete last layer
+        # add custom head
+        # head input size: AdaptiveConcatPool2d doubles the channel count, so the flattened
+        # feature vector has 2 * last_conv_size * pooling_size**2 elements (13312 by default)
+        head_size = 2 * last_conv_size * pooling_size ** 2
+        modules += [nn.Sequential(
+            # AdaptiveConcatPool2d is a concat of AdaptiveMaxPooling and AdaptiveAveragePooling
+            AdaptiveConcatPool2d(size=pooling_size),
+            Flatten(),
+            nn.BatchNorm1d(head_size),
+            nn.Dropout(p),
+            nn.Linear(head_size, num_classes)
+        )]
+        self.net = nn.Sequential(*modules)
+
+    def forward(self, x):
+        logits = self.net(x)
+        return logits
+
+
+def step_fn(model: torch.nn.Module,
+            batch: torch.Tensor) -> torch.Tensor:
+    """Determine what your model will do with your data.
+
+    Args:
+        model: the pytorch module to pass input in
+        batch: the batch of data from the DataLoader
+
+    Returns:
+        The models forward pass results
+    """
+    inp = batch["image"]  # here we get an "image" from our dataset
+    return model(inp)
+