from bm3d import bm3d, bm3d_rgb, rgb_to, BM3DProfile
from scipy.ndimage import generic_filter, convolve
import numpy as np
from skimage.util import view_as_windows

# --- Joint MMSE
def covariance_shrinkage(Cdd, alpha, max_cond, min_eig):
    """Shrink the eigenvalues of per-pixel 2x2 covariances and cap their conditioning.

    Parameters
    ----------
    Cdd : array_like, shape (nx, ny, 2, 2)
        Stack of 2x2 sample covariance matrices. NaN entries are replaced by 0
        for the eigen-decomposition and restored as NaN in the output.
    alpha : float
        Weight kept on the sample eigenvalues; ``1 - alpha`` goes to the
        median variance targets (median of ``Cdd[..., 0, 0]`` and of
        ``Cdd[..., 1, 1]`` over all pixels).
    max_cond : float
        Largest allowed ratio between the largest eigenvalue and any other.
    min_eig : float
        Lower bound used for the largest eigenvalue when deriving the floor.

    Returns
    -------
    array, shape (nx, ny, 2, 2)
        Reconstructed matrices ``U diag(λ) Uᵀ``, NaN wherever ``Cdd`` was NaN.
    """
    var_CO2 = np.ma.median(Cdd[:,:,0,0])
    var_NO2 = np.ma.median(Cdd[:,:,1,1])
    # eigh returns eigenvalues in ascending order: λ[..., 0] <= λ[..., 1].
    λ, U = np.linalg.eigh(np.nan_to_num(Cdd))
    # NOTE(review): eigenvalues are ordered by size, not by variable, so the
    # smallest one is shrunk toward the CO2 variance whatever its eigenvector
    # direction is; confirm this pairing is intended.
    λ[:,:,0] = alpha * λ[:,:,0] + (1-alpha) * var_CO2
    λ[:,:,1] = alpha * λ[:,:,1] + (1-alpha) * var_NO2
    # Shrinkage can reorder the eigenvalues (e.g. when var_CO2 > var_NO2), so
    # use the true maximum rather than assuming λ[..., 1] is the largest, and
    # floor every eigenvalue so that max/min <= max_cond actually holds.
    eig_max = np.maximum(λ.max(axis=-1), min_eig)
    eig_min_allowed = eig_max / max_cond
    λ = np.maximum(λ, eig_min_allowed[..., None])
    recon = (U * λ[..., None, :]) @ U.transpose((0,1,3,2))
    recon[np.isnan(Cdd)] = np.nan
    return recon

def ridge_regularization(Cdd, max_cond, γ, Γ):
    """Add a condition-number-dependent nugget to the diagonal of 2x2 covariances.

    The relative nugget ramps from ``γ`` (condition number 1) to ``Γ``
    (condition number >= ``max_cond``), quadratically in
    ``log10(cond) / log10(max_cond)``. It is scaled by the median variance
    of each variable. Afterwards every CO2 variance (``[..., 0, 0]``) below
    the median CO2 variance is raised to that median.

    Parameters
    ----------
    Cdd : ndarray or masked array, shape (nx, ny, 2, 2)
        Covariance stack. It is not modified; a regularized copy is returned.
        Masked entries are treated as 1.0 when estimating the condition number.
    max_cond : float
        Condition number at which the full nugget ``Γ`` is applied.
    γ, Γ : float
        Minimum and maximum relative nugget strength.

    Returns
    -------
    Same type as ``Cdd``, shape (nx, ny, 2, 2)
        Regularized copy of ``Cdd``.
    """
    # Work on a copy so the caller's array is left untouched.
    Cdd = Cdd.copy()
    # np.ma.filled works for both masked and plain arrays.
    cond_num = np.linalg.cond(np.ma.filled(Cdd, 1.0))  # shape (lon, lat)
    p = 2.0  # curvature of ramp
    cond_factor = np.clip(np.log10(cond_num) / np.log10(max_cond), 0, 1)
    nugget_strength = γ + cond_factor**p * (Γ - γ)

    # Apply proportional nugget to each variance term
    Cdd[:,:,0,0] += nugget_strength * np.ma.median(Cdd[:,:,0,0])
    Cdd[:,:,1,1] += nugget_strength * np.ma.median(Cdd[:,:,1,1])

    # Raise CO2 variances below the (post-nugget) median up to that median.
    variance_floor = np.ma.median( Cdd[:,:,0,0] )
    Cdd[:,:,0,0] = np.maximum(Cdd[:,:,0,0], variance_floor)
    return Cdd

def covariance(D, i, j):
    """Sample covariance of channels ``i`` and ``j`` of ``D`` along the last axis.

    ``D`` is a masked array indexed by channel on its first axis. In this
    file it holds deviations of the samples from their window centre. The sum
    of products over unmasked samples is divided by (unmasked count - 1).
    """
    products = D[i, ...] * D[j, ...]
    dof = products.count(axis=-1) - 1
    return np.ma.sum(products, axis=-1) / dof

def nanaverage(A, W, axis=-1):
    """Weighted average of ``A`` along ``axis`` that ignores NaN entries.

    NaNs contribute neither to the weighted sum nor to the sum of weights.
    ``W`` must broadcast against ``A``.
    """
    present = ~np.isnan(A)
    weighted_total = np.nansum(A * W, axis=axis)
    weight_total = (present * W).sum(axis=axis)
    return weighted_total / weight_total

def joint_MMSE(
    in_arr1, in_arr2, T, CNN=None, alpha=0.5,
    method='median', max_cond=1e7, min_eig=1e-13,
    γ=1e-3, Γ=1e0, weight_sigma=4):
    """Filter ``in_arr1`` (CO2) using its local covariance with ``in_arr2`` (NO2).

    Both fields are symmetrically padded by ``T`` pixels. Every T x T window
    then yields a 2x2 sample covariance ``Cdd`` of the (CO2, NO2) deviations
    from the window centre. ``Cdd`` is shrunk and regularized, and a noise
    covariance ``Cnn`` is built with only a CO2 term. Each window sample gets
    the CO2 component of ``Cnn Cdd^-1 (d - centre)`` as its noise estimate.
    These overlapping estimates are averaged per pixel with Gaussian weights
    and subtracted from CO2.

    Parameters
    ----------
    in_arr1, in_arr2 : ndarray, shape (nx, ny)
        CO2 and NO2 fields; NaN marks missing pixels.
    T : int
        Window size in pixels; must be odd.
    CNN : None, scalar or ndarray of shape (nx, ny), optional
        Extra CO2 noise-variance estimate, blended 50/50 with the median CO2
        variance. A map is reduced to its per-window NaN-median first.
    alpha : float
        Eigenvalue shrinkage weight (see ``covariance_shrinkage``).
    method : {'median', 'mean'}
        Statistic used as the window centre.
    max_cond, min_eig : float
        Conditioning limits for ``covariance_shrinkage`` / ``ridge_regularization``.
    γ, Γ : float
        Minimum / maximum relative nugget for ``ridge_regularization``.
    weight_sigma : float
        Standard deviation (pixels) of the Gaussian used to average the
        overlapping window estimates.

    Returns
    -------
    ndarray, shape (nx, ny)
        Filtered CO2. It is NaN where CO2 is NaN. Where NO2 is NaN and no
        estimate exists, it falls back to the raw CO2 value.

    Raises
    ------
    ValueError
        If ``T`` is even.
    KeyError
        If ``method`` is not ``'mean'`` or ``'median'``.
    """
    if T % 2 == 0:
        # The scatter-back loop below only lines up for centred (odd) windows;
        # an even T would otherwise fail there with a broadcasting error.
        raise ValueError(f"T must be odd, got {T}")

    # --- Pad data (s.t. our windows catch all the data)
    CO2 = np.pad(in_arr1, ((T, T), (T, T)), 'symmetric')
    NO2 = np.pad(in_arr2, ((T, T), (T, T)), 'symmetric')

    # --- Extract overlapping patches of size T x T
    CO2_tiles = view_as_windows(CO2               , (T,T))
    NO2_tiles = view_as_windows(NO2               , (T,T))
    # A sample is valid only where both fields are present
    mask_tile = view_as_windows(~np.isnan(CO2+NO2), (T,T))

    # --- Flatten each window: X has shape (2, n_win_x, n_win_y, T*T)
    mask_tile = mask_tile.reshape(*mask_tile.shape[:-2], -1)
    X = np.stack((CO2_tiles.reshape(*CO2_tiles.shape[:-2], -1),
                  NO2_tiles.reshape(*NO2_tiles.shape[:-2], -1)))

    # --- Generate masked array
    X = np.ma.array( X, mask=~np.stack([mask_tile] * 2, axis=0) )

    # --- Window centre; only the requested statistic is computed
    # (an unknown ``method`` raises KeyError).
    centre_stat = {"mean": np.nanmean, "median": np.nanmedian}[method]
    av_field = centre_stat(X, -1, keepdims=1)

    # --- Compute sample covariance matrix
    D = X - av_field
    Cdd = np.ma.zeros((*D.shape[1:3], 2, 2)) * np.nan
    Cdd[...,0,0] = covariance(D, 0, 0)
    Cdd[...,0,1] = covariance(D, 0, 1)
    Cdd[...,1,0] = covariance(D, 1, 0)
    Cdd[...,1,1] = covariance(D, 1, 1)
    Cdd = covariance_shrinkage(Cdd, alpha, max_cond, min_eig)
    Cdd = ridge_regularization(Cdd, max_cond, γ, Γ)

    # --- Noise covariance: only the CO2 variance term is non-zero
    Cnn = np.zeros_like(Cdd)
    Cnn[:,:,0,0] = np.ma.median(Cdd[...,0,0])
    if CNN is not None:
        if np.ndim(CNN) == 0:
            # Any scalar type (Python float, np.float32/64, 0-d array)
            cnn_est = CNN
        else:
            # Per-pixel map: pad/window like the data, take each window's median
            cnn_tiles = np.pad(CNN, ((T, T), (T, T)), 'symmetric')
            cnn_tiles = view_as_windows(cnn_tiles, (T,T))
            cnn_tiles = cnn_tiles.reshape(*mask_tile.shape[:-1], -1)
            cnn_est = np.nanmedian(cnn_tiles, -1)
        Cnn[:,:,0,0] = Cnn[:,:,0,0]*0.5 + cnn_est*0.5

    # --- Apply filter: w = Cdd^-1 Cnn[:, 0]; per-sample noise estimate = w . d
    wCddICnn = np.linalg.solve(Cdd.filled(np.nan),Cnn.filled(np.nan))
    wCddICnnEM = np.einsum('ijk,kijl->ijl',wCddICnn[...,0].squeeze(),D)

    # --- Scatter window estimates back to padded-image positions:
    # est_gather[..., i] holds the estimate from the window in which the pixel
    # sits at flattened offset i.
    est_gather = np.zeros((*CO2.shape, T**2)) * np.nan
    for i in range(T**2):
        y, x = np.mod(i,T)-T//2, i//T-T//2
        xs, xe = max((T//2)+x,0), min(CO2.shape[0]-(T//2)+x, CO2.shape[0])
        ys, ye = max((T//2)+y,0), min(CO2.shape[1]-(T//2)+y, CO2.shape[1])
        est_gather[xs:xe, ys:ye, i] = wCddICnnEM[:,:,i]

    # --- Gaussian weights over the window offsets
    y, x = np.meshgrid(np.arange(-T//2+1, T//2+1), np.arange(-T//2+1, T//2+1))
    weights_2dgauss = np.exp(-(x**2 + y**2) / (2 * weight_sigma**2))

    # --- Subtract the weighted noise estimate; keep NaN where CO2 is missing
    # and fall back to raw CO2 where NO2 is missing and no estimate exists.
    pred = CO2 - nanaverage(est_gather,W=weights_2dgauss.flatten())
    est = np.where( np.isfinite(CO2), pred, np.nan )
    est = np.where( np.isnan(NO2) & np.isnan(est), CO2, est )

    return est[T:-T,T:-T]

# --- BM3D
class Normalizer:
    """Min-max scaler that remembers where the fitted field was NaN.

    ``fit`` records the NaN-aware minimum and range plus the NaN mask of the
    data. ``transform`` fills NaNs and rescales with those statistics, so the
    fitted data lands in [0, 1]. ``inverse_transform`` undoes the scaling and
    puts NaN back at the fitted NaN positions.
    """

    def __init__(self):
        # Statistics populated by ``fit``.
        self.min = None
        self.range = None
        self.isnan = None
        # Not used by this class; kept so existing attribute access keeps working.
        self.mean = None
        self.std = None

    def fit(self, data):
        """Record the NaN-aware minimum, range and NaN mask of ``data``."""
        self.min = np.nanmin(data)
        self.range = np.nanmax(data) - self.min
        # Constant field: use a unit range to avoid division by zero.
        self.range = self.range if (self.range != 0) else 1
        self.isnan = np.isnan(data)

    def transform(self, data):
        """Fill NaNs with the mean of ``data``, then rescale with the fitted min/range.

        Any remaining non-finite values (e.g. from an all-NaN input) are
        replaced by ``np.nan_to_num``.
        """
        data = np.nan_to_num(data, nan=np.nanmean(data))
        normalized_data = (data - self.min) / self.range
        return np.nan_to_num(normalized_data)

    def inverse_transform(self, normalized_data):
        """Undo ``transform`` and set NaN where the fitted data was NaN.

        ``normalized_data`` must have the shape of the data passed to ``fit``.
        """
        original_data = normalized_data * self.range + self.min
        original_data[self.isnan] = np.nan
        return original_data

    def fit_transform(self, data):
        """Fit on ``data`` and return its normalized version."""
        self.fit(data)
        return self.transform(data)

class BM3DDenoiser:
    """Wrapper around ``bm3d`` for single-field or guided (two-channel) denoising.

    Parameters
    ----------
    OPP2D : ndarray of shape (2, 2), optional
        Invertible matrix that mixes the (noisy, guide) channels before joint
        BM3D. Defaults to ``[[1/4, 1/2], [0, 1/2]]``.
    """

    def __init__(self, OPP2D=None):
        self.colorspaces = {
            "RGB": np.identity(3),
            "OPP2D": np.array([
                        [1/4, 1/2],
                        [0  , 1/2],
                    ]) if OPP2D is None else OPP2D,
            "OPP3D": np.array([
                        [1/3, 1/3, 1/3],
                        [1/2, 0, -1/2],
                        [1/4, -1/2, 1/4]
                    ]),
            }

    def colorspace_transform(self, image_in, inverse=False):
        """Apply the active colorspace matrix (or its inverse) along the channel axis.

        ``image_in`` has shape [w, h, ch] and the active matrix ``A``
        (``self.colorspace``) has shape [ch', ch]. The result is
        ``out[..., i] = sum_c A[i, c] * image_in[..., c]``. With ``inverse``
        true, ``inv(A)`` is used instead, mapping back to [w, h, ch].
        """
        matrix = np.linalg.inv(self.colorspace) if inverse else self.colorspace
        return np.einsum("whc,ic->whi", image_in, matrix)

    def est_sigmas(self, sigma):
        """Scale ``sigma`` by the Euclidean norm of each colorspace row.

        For channels carrying equal, independent noise of standard deviation
        ``sigma``, this is the noise standard deviation of each transformed channel.
        """
        return np.sqrt((self.colorspace**2).sum(axis=1))*sigma

    def denoise(self, sigma, data_noisy, data_guide=None, return_both=False):
        """Denoise ``data_noisy`` with BM3D, optionally guided by ``data_guide``.

        Both fields are min-max normalized first (``Normalizer``). With a
        guide, the normalized fields are stacked as two channels, mixed with
        the ``"OPP2D"`` matrix and denoised jointly. The guide influences which
        patches BM3D considers similar. Without a guide, single-channel BM3D
        is used.

        Parameters
        ----------
        sigma : float (or sequence of floats for the guided case)
            Noise standard deviation on a 0-256 scale of the normalized data.
            It is divided by 256 before being passed to ``bm3d``. Because of
            the normalization, picking it takes some trial and error, but it
            can then be kept constant. In the guided case it is first scaled
            per transformed channel by ``est_sigmas``.
        data_noisy : ndarray, shape (nx, ny)
            Field to denoise; NaN marks missing pixels.
        data_guide : ndarray, shape (nx, ny), optional
            Guide field for joint denoising.
        return_both : bool
            Also return the denoised guide (requires ``data_guide``).

        Returns
        -------
        ndarray, or tuple of two ndarrays if ``return_both``
            Denoised field(s), NaN where the corresponding input was NaN.
            An all-NaN ``data_noisy`` is returned unchanged.

        Raises
        ------
        ValueError
            If ``return_both`` is true but ``data_guide`` is None.
        """
        # Validate before any work, so an invalid call neither runs BM3D nor
        # slips through the all-NaN shortcut below.
        if return_both and data_guide is None:
            raise ValueError("To return both, data_guide should be given")

        if np.isnan(data_noisy).all():
            if return_both:
                return data_noisy, data_guide
            return data_noisy

        self.noisynormalizer = Normalizer()
        self.noisy_N = self.noisynormalizer.fit_transform(data_noisy)

        if data_guide is not None:
            self.guidenormalizer = Normalizer()
            self.guide_N = self.guidenormalizer.fit_transform(data_guide)
            self.colorspace = self.colorspaces["OPP2D"]
            self.joint_BM3D(sigma)
        else:
            self.single_BM3D(sigma)

        # Fall back to the input wherever the denoised field is NaN
        self.noisy_denoised = np.where( np.isnan(self.noisy_denoised), data_noisy, self.noisy_denoised)

        if return_both:
            return self.noisy_denoised, self.guide_denoised
        return self.noisy_denoised

    def joint_BM3D(self, sigma_guided=3):
        """Denoise the normalized (noisy, guide) pair jointly in the active colorspace."""
        z = np.zeros( self.noisy_N.shape + (2,) )
        z[:,:,0] = self.noisy_N # Block matching is performed on this first channel.
        z[:,:,1] = self.guide_N
        z = self.colorspace_transform(z)
        sigmas = self.est_sigmas(sigma_guided)
        # Default bm3d profile here; the single-channel path uses 'vn_old'.
        y_est = bm3d(z, sigmas/256)
        y_est = self.colorspace_transform(y_est, inverse=True)
        self.noisy_denoised = self.noisynormalizer.inverse_transform(y_est[:,:,0])
        self.guide_denoised = self.guidenormalizer.inverse_transform(y_est[:,:,1])

    def single_BM3D(self, sigma_unguided=1):
        """Denoise the normalized noisy field alone with bm3d profile 'vn_old'."""
        # Fresh float64 buffer; NaNs (already removed by Normalizer) mapped to 0.
        z = np.empty( self.noisy_N.shape )
        z[:,:] = np.nan_to_num(self.noisy_N)
        y_est = bm3d(z, sigma_unguided/256, profile='vn_old')
        self.noisy_denoised = self.noisynormalizer.inverse_transform(y_est[:,:])

# --- 5x5 pix mean filter
def mean_filter_5x5(arr):
    """
    Smooth a 2-D array with a 5x5 box mean that skips NaN entries.

    Uses normalized convolution: the sum of valid neighbours is divided by
    their count, with zero padding outside the array. NaN pixels that have at
    least one valid neighbour are therefore filled in.

    Parameters:
        arr (numpy.ndarray): Input array of shape (nx, ny), possibly containing NaN.

    Returns:
        numpy.ndarray: Array of the same shape. NaN only where the whole 5x5
        neighbourhood was NaN.
    """
    finite = ~np.isnan(arr)
    window = np.ones((5, 5))

    def _box_sum(field):
        # Zero-padded 5x5 neighbourhood sum
        return convolve(field, window, mode='constant', cval=0.0)

    totals = _box_sum(np.where(finite, arr, 0))
    counts = _box_sum(finite.astype(float))

    with np.errstate(invalid='ignore', divide='ignore'):
        averaged = totals / counts

    # No valid neighbour at all -> undefined mean
    averaged[counts == 0] = np.nan
    return averaged
