"""Explores the COVID school closures dataset downloaded from
https://docs.google.com/spreadsheets/d/1ndHgP53atJ5J-EtxgWcpSfYG8LdzHpUsnb6mWybErYg
in line with the COMP0034 'How to' guide for data preparation and exploration.

Continues from data_prep.py

Usage:
    ./data_explore.py

Author:
    Sarah Sanders - 23.05.2022
"""
import pandas as pd
import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt


def set_pandas_display_options(dataframe):
    """Configure pandas console output to fit a DataFrame of this size.

    The maximum number of displayed rows and columns is set to the number
    of rows and columns in ``dataframe``, and the display width is fixed
    at 150 characters.

    Parameters
    ----------
    dataframe : pandas DataFrame
        DataFrame whose dimensions determine the row and column limits.
    """
    n_rows = len(dataframe.index)
    n_cols = len(dataframe.columns)
    # Applied in the same order as before: rows, columns, then width.
    display_settings = {
        "display.max_rows": n_rows,
        "display.max_columns": n_cols,
        "display.width": 150,
    }
    for option_name, option_value in display_settings.items():
        pd.set_option(option_name, option_value)


def show_boxplot(dataframe, column):
    """Draw a boxplot of a single column and display it.

    Parameters
    ----------
    dataframe : pandas DataFrame
        Data to plot.
    column : str
        Name of the column to plot.
    """
    dataframe.boxplot(column=[column])
    plt.show()


def split_outliers(dataframe, column, threshold):
    """Split rows into those above a threshold and all remaining rows.

    Parameters
    ----------
    dataframe : pandas DataFrame
        Data to split.
    column : str
        Name of the column compared against ``threshold``.
    threshold : number
        Rows whose value in ``column`` is strictly greater than this are
        treated as outliers.

    Returns
    -------
    tuple of (pandas DataFrame, pandas DataFrame)
        The outlier rows, and every other row. Rows with a missing value in
        ``column`` compare as False, so they are kept in the second frame.
    """
    is_outlier = dataframe[column] > threshold
    return dataframe.loc[is_outlier], dataframe.loc[~is_outlier]


def closures_between(dataframe, start, end):
    """Select the closure date and country of closures within a date range.

    Parameters
    ----------
    dataframe : pandas DataFrame
        Data containing 'Closure date' and 'Country' columns.
    start, end : str or datetime-like
        Inclusive bounds passed to ``Series.between``.

    Returns
    -------
    pandas DataFrame
        The 'Closure date' and 'Country' columns of the rows whose closure
        date lies in [start, end].
    """
    subset = dataframe[['Closure date', 'Country']]
    return subset[subset['Closure date'].between(start, end)]


def count_per_day(df_dates):
    """Count rows per calendar day of 'Closure date'.

    Parameters
    ----------
    df_dates : pandas DataFrame
        Data with a datetime 'Closure date' column.

    Returns
    -------
    pandas DataFrame
        One row per day between the earliest and latest closure date, with
        the non-null count of each remaining column (0 for days with no
        closures).
    """
    return df_dates.groupby(pd.Grouper(key='Closure date', freq='D')).count()


def main():
    """Run the exploratory analysis on the cleaned school closures data."""
    cases_column = 'Cases at closure'
    outlier_threshold = 70000

    df_cgd = pd.read_excel('cgd_cleaned.xlsx')
    set_pandas_display_options(df_cgd)

    # Print summary stats for the data
    print("\nSummary stats:\n==============\n")
    print(df_cgd.describe())

    # Create a boxplot looking for outliers
    show_boxplot(df_cgd, cases_column)

    # Print the outlier(s), then re-create the boxplot without them so the
    # remaining data points can be seen more clearly. The same threshold is
    # used for both steps rather than dropping a hard-coded row label.
    outliers, df_cgd_ex_outliers = split_outliers(df_cgd, cases_column, outlier_threshold)
    print(outliers)
    show_boxplot(df_cgd_ex_outliers, cases_column)

    # Create a histogram of the values for the number of weeks schools were closed
    df_cgd['Weeks closed'].hist()
    plt.show()

    # Create a bar chart for the number of countries closing on each day in March 2020
    df_date = closures_between(df_cgd, '2020-03-01', '2020-03-31')
    print(df_date)
    count_per_day(df_date).plot.barh()
    plt.show()


if __name__ == "__main__":
    main()