Usage

Classify Light curves

Use the following example code:

from astrorapid.classify import Classify

# Each light curve should be a tuple in this form. Look at the example code for an example of the input format.
light_curve_info1 = (mjd, flux, fluxerr, passband, photflag, ra, dec, objid, redshift, mwebv)
light_curve_list = [light_curve_info1,]
contextual_info_list = [{'hosttype': value},]  # Only use this parameter if you have trained your own classifer with specific meta data. Otherwise set to None.

# Classify Light curves
classification = Classify(known_redshift=True)
predictions = classification.get_predictions(light_curve_list, contextual_info_list)
print(predictions)

# Plot light curve and classification vs time of the light curves at the specified indexes
classification.plot_light_curves_and_classifications(indexes_to_plot=(0,1,4,6))
classification.plot_classification_animation(indexes_to_plot=(0,1,4,6))

Train your own classifier with your own data

You’ll simply need to run the function astrorapid.custom_classifier.create_custom_classifier() to get started with training your own classifier. An example is shown below.

from astrorapid.custom_classifier import create_custom_classifier

script_dir = os.path.dirname(os.path.abspath(__file__))

create_custom_classifier(get_data_func=astrorapid.get_training_data.get_data_from_snana_fits,
                         data_dir=os.path.join(script_dir, '..', 'data/ZTF_20190512'),
                         class_nums=(1, 2, 12, 14, 3, 13, 41, 43, 51, 60, 61, 62, 63, 64, 70),
                         class_name_map={1: 'SNIa-norm', 2: 'SNII', 12: 'SNII', 14: 'SNII', 3: 'SNIbc', 13: 'SNIbc', 41: 'SNIa-91bg', 43: 'SNIa-x', 51: 'Kilonova', 60: 'SLSN-I', 61: 'PISN', 62: 'ILOT', 63: 'CART', 64: 'TDE', 70: 'AGN'},
                         reread_data=False,
                         contextual_info=('redshift', 'some_contextual_info1'),
                         passbands=('g', 'r'),
                         retrain_network=False,
                         train_epochs=100,
                         zcut=0.5,
                         bcut=True,
                         ignore_classes=(61, 62, 64, 70),
                         nprocesses=None,
                         nchunks=10000,
                         otherchange='',
                         training_set_dir='data/training_set_files',
                         save_dir='data/saved_light_curves',
                         fig_dir='data/training_set_files/Figures'),
                         plot=True
                         )

You’ll need to write your own function get_data_func to read your data and use the astrorapid preprocessing tools. Use the skeleton function here astrorapid.get_custom_data.get_custom_data(), or as rewritten below.

def get_custom_data(class_num, data_dir, save_dir, passbands, known_redshift, nprocesses, redo):
    """
    Get data from custom data files.
    You will need to write this function with the following skeleton function:

    Parameters
    ----------
    class_num : int
        Class number. E.g. SNIa is 1. See helpers.py for lookup table.
        E.g. class_num = 1
    data_dir : str
        Directory where data is stored
        E.g. data_dir='data/ZTF_20190512/'
    save_dir : str
        Directory to save processed data
        E.g. save_dir='data/saved_light_curves/'
    passbands : tuple
        Passbands to use.
        E.g. passbands=('g', 'r')
    known_redshift : bool
        Whether to correct the light curves for cosmological time dilation using redshift.
    nprocesses : int or None
        Number of processes to use
    redo : bool
        Whether to redo reading the data and saving the processed data.


    Returns
    -------
    light_curves : dict of astropy.table.Table objects
        e.g light_curves['objid1'] =
            passband   time       flux     fluxErr   photflag
              str1   float32    float32    float32    int32
            -------- -------- ----------- ---------- --------
                   g -46.8942  -48.926975  42.277767        0
                   g -43.9352  -105.35379   72.97575        0
                   g -35.9161  -46.264206    99.9172        0
                   g -28.9377  -28.978344  42.417065        0
                   g -25.9787  109.886566   46.03949        0
                   g -15.0399    -80.2485   80.38155        0
                   g -12.0218    93.51743  113.21529        0
                   g  -6.9585   248.88364 108.606865        0
                   g  -4.0411   341.41498  47.765404        0
                   g      0.0    501.7441   45.37485     6144
                 ...      ...         ...        ...      ...
                   r  40.9147   194.32494  57.836903     4096
                   r  59.9162    67.59185   45.66463     4096
                   r  62.8976    80.85155  44.356197     4096
                   r  65.8974   28.174305   44.75049     4096
                   r  71.8966  -18.790287 108.049774     4096
                   r  74.9297  -3.1707647  125.15057     4096
                   r  77.9341 -11.0205965 125.784676     4096
                   r  80.8576   129.65466   69.99305     4096
                   r  88.8922  -14.259436  52.917866     4096
                   r 103.8734   27.178356 115.537704     4096

    """

    # If the data has already been run and processed load it. Otherwise read it and save it
    save_lc_filepath = os.path.join(save_dir, f"lc_classnum_{class_num}.pickle")
    if os.path.exists(save_lc_filepath) and not redo:
        with open(save_lc_filepath, "rb") as fp:  # Unpickling
            light_curves = pickle.load(fp)
    else:
        light_curves = {}
        # Read in data from data_dir and get the mjd, flux, fluxerr, passband, photflag as 1D numpy arrays for
        # each light curve. Get the ra, dec, objid, redshift, mwebv, model_num, peak_mjd as floats or strings.
        # Set whether you'd like to train a model with a known redshift or not. Set known_redshift as a boolean.

        # Enter your own data-reading code here that gets the mjds, fluxes, fluxerrs, passbands, photflags,
        # ras, decs, objids, redshifts, mwebvs, model_nums, peak_mjds for all the light curves from the data_dir

        # Once you have the required data information for each light curve, pass it into InputLightCurve with
        # something like the following code:
        for i, objid in enumerate(objids):
            inputlightcurve = InputLightCurve(mjds[i], fluxes[i], fluxerrs[i], passbands[i], photflags[i],
                                              ras[i], decs[i], objids[i], redshifts[i], mwebvs[i],
                                              known_redshift=known_redshift,
                                              training_set_parameters={'class_number': int(class_num),
                                                                       'peakmjd': peakmjds[i]},
                                              other_meta_data={'some_contextual_info1': value})
            light_curves[objid] = inputlightcurve.preprocess_light_curve()

        # If you think that reading the data is too slow, you may want to replace the for loop above with
        # multiprocessing. See the example function in get_training_data.py if you need help doing this.

        # Next, we save it:
        with open(save_lc_filepath, "wb") as fp:  # Pickling
            pickle.dump(light_curves, fp)

    return light_curves