import numpy as np
import matplotlib.pyplot as plt

# Daily registers: one sample per day, 365-day years (no leap days).
daysyear = 365

# Number of years in the simulation
years = 30

totdays = years * daysyear
# Day index 0, 1, ..., totdays-1 (exactly one day apart).
# np.linspace(0, totdays, totdays) would step by totdays/(totdays-1) != 1,
# slowly de-phasing the annual cycle from the 365-sample rows that the
# climatology reshapes the series into.
t = np.arange(totdays, dtype=float)

# Base oscillation: annual cycle (amplitude 2.0) plus a 1-day-period term
# (amplitude 0.2).
# NOTE(review): sampled once per day, a 1-day-period sine is aliased to
# (numerically) zero, so the 0.2 term contributes essentially nothing; use a
# non-integer period if extra sub-annual variability is wanted.
freq = 1 / daysyear
oscillations = (2.0 * np.sin(freq * t * 2.0 * np.pi)
                + 0.2 * np.sin(t * 2.0 * np.pi))

# Trend: linear warming of 10 units over the whole record.
trend = 10 * t / totdays

# Combine trend and oscillations around a base level of 15.
series = oscillations + trend + 15

# Add extreme values (+/-5 spikes) at distinct random days; the fixed seed
# makes the locations and signs reproducible between runs.
np.random.seed(42)
n_extremes = 25
extreme_indices = np.random.choice(len(t), size=n_extremes, replace=False)
series[extreme_indices] += np.random.choice([5, -5], size=n_extremes)  # large spikes

def _mhw_climatology(x, days_per_year, pct=90):
    """Compute the daily climatology, percentile threshold and MHW intensity of *x*.

    *x* is reshaped to ``(-1, days_per_year)``, so each row is one year and
    each column one calendar day; ``len(x)`` must therefore be a multiple of
    ``days_per_year``. The statistics are taken per column (per calendar day)
    and tiled back to ``len(x)``.

    Returns a 6-tuple, in order:
        single_mean   -- per-calendar-day mean, shape ``(days_per_year,)``
        clim_mean     -- ``single_mean`` tiled to ``len(x)``
        single_pct    -- per-calendar-day ``pct``-th percentile
        clim_pct      -- ``single_pct`` tiled to ``len(x)`` (MHW threshold)
        mhw_intensity -- ``x - clim_mean``, NaN wherever ``x < clim_pct``
        not_mhw       -- ``np.nonzero(x < clim_pct)``, the discarded indices
    """
    by_year = x.reshape(-1, days_per_year)
    n_years = by_year.shape[0]

    single_mean = by_year.mean(axis=0)
    single_pct = np.percentile(by_year, pct, axis=0)
    clim_mean = np.tile(single_mean, n_years)
    clim_pct = np.tile(single_pct, n_years)

    # Intensity is the excess over the climatological mean, kept only on days
    # at or above the percentile threshold.
    # NOTE(review): every such day is kept individually; the usual MHW
    # definition also requires a minimum run of consecutive days — confirm
    # whether that is intended here.
    not_mhw = np.nonzero(x < clim_pct)
    mhw_intensity = x - clim_mean
    mhw_intensity[not_mhw] = np.nan
    return single_mean, clim_mean, single_pct, clim_pct, mhw_intensity, not_mhw


# Daily climatology (mean), 90 percentile (threshold for MHW) and MHW
# intensity of the original series.
(seriesingle_mean, series_mean, seriesingle_90, series_90,
 intensity, Not_MHW_indices) = _mhw_climatology(series, daysyear)

# Repeat for the detrended series.
# NOTE: this removes the known synthetic trend; real data would need an
# estimated trend instead.
detrended_series = series - trend
(detrended_seriesingle_mean, detrended_series_mean,
 detrended_seriesingle_90, detrended_series_90,
 detrended_intensity, detrended_Not_MHW_indices) = _mhw_climatology(
    detrended_series, daysyear)


# x-axis for every panel: years since the start of the series.
time = t * freq


def _plot_mhw_figure(x, y, clim_mean, clim_90, mhw_intensity, extremes, *,
                     bar_width, series_label, titles, suptitle, filename):
    """Draw and save the 2x2 marine-heatwave diagnostic figure for one series.

    Panels, in row-major order:
      (0, 0) ``y`` with the samples at ``extremes`` highlighted,
      (0, 1) ``y`` against its daily climatological mean ``clim_mean``,
      (1, 0) ``y`` against its daily 90th-percentile threshold ``clim_90``,
      (1, 1) ``y``, the highlighted extremes, and ``mhw_intensity`` as bars.

    Args:
        x: x-coordinates shared by every panel (years since start).
        y, clim_mean, clim_90, mhw_intensity: arrays aligned with ``x``;
            NaN entries of ``mhw_intensity`` draw no bar.
        extremes: integer indices into ``x``/``y`` to highlight.
        bar_width: bar width in ``x`` units (one day for a daily series).
        series_label: legend label for the ``y`` line.
        titles: four panel titles in row-major panel order.
        suptitle: figure title.
        filename: path passed to ``fig.savefig``.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    fig, axs = plt.subplots(2, 2, figsize=(15, 15))
    ax_raw, ax_mean, ax_p90, ax_mhw = axs.flat

    for ax, title in zip(axs.flat, titles):
        ax.set_title(title)
        ax.plot(x, y, label=series_label)

    ax_raw.scatter(x[extremes], y[extremes], color='red',
                   label='Extreme Values', zorder=5)
    ax_mean.plot(x, clim_mean, label='Daily mean value')
    ax_p90.plot(x, clim_90, label='90 percentile')
    ax_mhw.scatter(x[extremes], y[extremes], color='red',
                   label='Extreme Values', zorder=5)
    # bar()'s default width (0.8) is in x units, i.e. ~0.8 *years* here, which
    # smears each daily bar over ~290 days; draw one-day-wide bars instead.
    ax_mhw.bar(x, mhw_intensity, width=bar_width, label='MHW intensity')

    for ax in axs.flat:
        ax.legend(loc='upper left')
        ax.grid(True)
        ax.set(xlabel='Year', ylabel='Temperature (°C)')

    fig.suptitle(suptitle)
    # Save before show(): once show() returns, the figure window may have been
    # closed and a later savefig() can write an empty image.
    fig.savefig(filename)
    plt.show()
    return fig, axs


# Plot not detrended case
fig, axs = _plot_mhw_figure(
    time, series, series_mean, series_90, intensity, extreme_indices,
    bar_width=freq,
    series_label='Time Series',
    titles=('Synthetic time series',
            'Daily climatology',
            'Daily 90 percentile',
            'Detected MHW'),
    suptitle='Time series with trend, oscillations, and extreme values',
    filename='not_detrended.png',
)

# Plot detrended case
fig, axs = _plot_mhw_figure(
    time, detrended_series, detrended_series_mean, detrended_series_90,
    detrended_intensity, extreme_indices,
    bar_width=freq,
    series_label='Detrended Time Series',
    titles=('Detrended time series from previous synthetic one',
            'Detrended daily climatology',
            'Daily 90 percentile from detrended time series',
            'Detected MHW'),
    suptitle='Detrended time series with trend, oscillations, and extreme values',
    filename='detrended.png',
)
