"""
====================================================================================================
9.13 Reverse Stress Test
====================================================================================================
Utility functions can be found next to this file. Here, we only define codpy-related functions.
"""

#########################################################################
# Necessary Imports
# ------------------------
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 

from codpy.kernel import Sampler, Kernel
from codpy.plot_utils import multi_plot, compare_plot_lists
from codpy.data_conversion import get_float, get_matrix

# Locate the directory containing this script. ``__file__`` is not defined in
# every execution context (e.g. some interactive runners), in which case the
# current working directory is used instead.
try:
    _script_path = os.path.abspath(__file__)
except NameError:
    CURRENT_DIR = os.getcwd()
else:
    CURRENT_DIR = os.path.dirname(_script_path)

# Data folder sitting next to this script.
data_path = os.path.join(CURRENT_DIR, "data")

# Put the parent directory first on the import path so that the ``utils``
# package imported below can be resolved.
PARENT_DIR = os.path.abspath(os.path.join(CURRENT_DIR, os.pardir))
sys.path.insert(0, PARENT_DIR)

from utils.ch9.plot_utils import compare_plot_lists_ax
from utils.ch9.heston import predict_prices 
 
def reverse_pnls(kwargs = None):
    """Run the reverse stress test and return the enriched parameter dict.

    The P&L of each scenario is the pricer output minus ``today_price`` (the
    pricer evaluated on the last row of ``kwargs['data']``). A ``Kernel`` is
    fitted from P&Ls to the scenarios that produced them and evaluated on the
    P&Ls of ``kwargs['TestData']``, once with the synthetic scenarios
    ("GEN") and once with the historical data ("HIST"). The reconstructed
    scenarios are then re-priced so the recovered P&Ls can be compared with
    the requested ones.

    Parameters
    ----------
    kwargs : dict, optional
        Pricing context. When ``None``, ``predict_prices()`` is called to
        build it. The keys read here are ``'getter'``, ``'pricer'``,
        ``'data'``, ``'TestData'`` and ``'SyntheticData'``. The dict is
        modified in place.

    Returns
    -------
    dict
        The same ``kwargs`` object, with these keys set: ``'set_fun'``,
        ``'pnls'``, ``'reverse_pnls (GEN)'``, ``'reverse_pnls (HIST)'``,
        ``'exact_reverse_pnls'``, ``'check pnls (HIST)'``,
        ``'check pnls (GEN)'`` and ``'graphic'`` (a plotting callable to be
        invoked as ``kwargs['graphic'](**kwargs)``).
    """
    if kwargs is None:
        kwargs = predict_prices()

    def plot_helper(xfx, **plot_kwargs):
        """Draw one panel: ``xfx`` is ``[list of xs, list of fxs, options]``."""
        # The keyword dict is named ``plot_kwargs`` so that it does not shadow
        # the enclosing ``kwargs`` that ``graphic`` reads by closure.
        xs, fxs, options = xfx[0], xfx[1], xfx[2]
        compare_plot_lists({'listxs': xs, 'listfxs': fxs, **plot_kwargs, **options})

    def graphic(**out):
        """Plot the GEN reverse P&Ls and the GEN re-pricing error.

        Results are read from the enclosing ``kwargs``; ``out`` supplies
        ``'symbols'`` (its length gives the number of curves) and is
        forwarded to ``multi_plot``.
        """
        dim = len(out['symbols'])
        f_names = ['Reverse PnLs (GEN)','errors RST']
        reverse_pnls_gen = kwargs['reverse_pnls (GEN)']
        pnls = kwargs['pnls']
        pnls_var = kwargs['check pnls (GEN)']
        pnl_column = pnls.values[:,0]

        plot_datas = [
            [
                # One curve per reconstructed column, all against the same P&Ls.
                [pnls] * dim,
                [reverse_pnls_gen.values[:,n] for n in range(dim)],
                {'labelx' : "pnls", 'labely' : "reverse_pnls (GEN)",'listlabels':list(reverse_pnls_gen.columns)}
            ],
            [
                [pnl_column],
                # Re-pricing error relative to the largest P&L.
                # NOTE(review): the axis is labelled "errors (BP)" but the
                # scale factor is 1000, not 10000 -- confirm the intended unit.
                [(pnls_var.values[:,0]-pnl_column)*1000/pnl_column.max()],
                {'labelx' : "pnls", 'labely' : "errors (BP)",'listlabels':["errors"]}
            ]
        ]
        multi_plot(plot_datas,plot_helper,f_names = f_names,loc = 'upper left',prop={'size': 4}, mp_ncols=3, mp_nrows = 1, **out)

    # NOTE: the pricer receives ``**kwargs`` on every call, so the order of the
    # ``kwargs`` updates below relative to the pricer calls is deliberate and
    # must be preserved.

    # Reference price: the pricer on the last row of the historical data,
    # using the getter's ``set_spot`` setter.
    kwargs["set_fun"] = kwargs['getter'].set_spot
    test_y = kwargs['TestData']
    today_price = kwargs['pricer'](kwargs['data'].values[[-1],:],**kwargs)

    # Synthetic scenarios restricted to the dates present in the test data.
    synthetic = kwargs['SyntheticData']
    dates = get_float(test_y.index.unique())
    synthetic = synthetic.loc[dates]

    # From here on the pricer is driven through ``set_time_spot``.
    kwargs["set_fun"] = kwargs['getter'].set_time_spot
    x = get_matrix(kwargs['pricer'](synthetic,**kwargs)) - today_price
    z = get_matrix(kwargs['pricer'](test_y,**kwargs)) - today_price
    kwargs['pnls'] = pd.DataFrame(z, columns = ['pnls'], index = test_y.index)

    # We map the pnl to underlyings & we extrapolate on TestData
    reverse_pnls_gen = Kernel(x=x, fx=synthetic)(z=z)
    reverse_pnls_gen = pd.DataFrame(reverse_pnls_gen, columns = synthetic.columns)
    reverse_pnls_gen.index = test_y.index

    # Same mapping, learned from the historical values instead.
    historical = kwargs['data']
    x = get_matrix(kwargs['pricer'](historical,**kwargs)) - today_price
    reverse_pnls_hist = Kernel(x=x, fx=historical)(z=z)
    reverse_pnls_hist = pd.DataFrame(reverse_pnls_hist, columns = historical.columns)
    reverse_pnls_hist.index = test_y.index

    kwargs['reverse_pnls (HIST)']= reverse_pnls_hist
    kwargs['reverse_pnls (GEN)']= reverse_pnls_gen
    kwargs['exact_reverse_pnls']= historical

    # Re-price the reconstructed scenarios to check that the requested P&Ls
    # are recovered.
    kwargs['check pnls (HIST)'] = pd.DataFrame(get_matrix(kwargs['pricer'](reverse_pnls_hist,**kwargs)), columns = ['pnls'], index = test_y.index)- today_price
    kwargs['check pnls (GEN)'] = pd.DataFrame(get_matrix(kwargs['pricer'](reverse_pnls_gen,**kwargs)), columns = ['pnls'], index = test_y.index)- today_price
    kwargs['graphic']= graphic

    return kwargs

# Run the reverse stress test, draw the figure through the plotting callback
# it stored under 'graphic', then display the result.
kwargs = reverse_pnls()
render = kwargs['graphic']
render(**kwargs)
plt.show()
