
import numpy as np


def hagler(BC, ATN, delATN, filt_jump=30.0):
    """
    Smooth BC measurements by averaging over runs of consecutive points whose
    attenuation stays within delATN of the first point of the run.
    Implements Hagler's method.

    The record is first split into filter sections. A new section starts at
    index i + 1 whenever |ATN[i + 1] - ATN[i]| exceeds filt_jump or ATN[i + 1]
    is NaN. Runs never extend across a section boundary.

    Parameters:
    - BC: ndarray, mass concentration (not modified; a copy is returned)
    - ATN: ndarray, attenuation, same length as BC
    - delATN: float, threshold for averaging in ATN units
    - filt_jump: float, absolute ATN step between consecutive points that is
      treated as a filter change (default 30)

    Returns:
    - BC: mass concentration after averaging
    - count: number of points averaged in each segment
    """

    BC = BC.copy()
    count = np.ones(len(BC))

    # Index i is a break when the section containing i ends at i, i.e. the
    # next point sits on a new filter spot or has a NaN attenuation.
    fast_changes = np.where(np.abs(np.diff(ATN)) > filt_jump)[0]
    nan_breaks = np.where(np.isnan(ATN[1:]))[0]
    breaks = np.unique(np.concatenate([fast_changes, nan_breaks]))

    # Half-open [start, stop) bounds of every section. Together they cover
    # every index exactly once, including index 0 and the last point before
    # each filter change.
    starts = np.concatenate([[0], breaks + 1]).tolist()
    stops = np.concatenate([breaks + 1, [len(ATN)]]).tolist()

    for first, sec_stop in zip(starts, stops):
        while first < sec_stop:
            # Grow the run while the next point is still within delATN of the
            # run's first point. A NaN ATN[first] makes the comparison False,
            # so such a point stays on its own.
            last = first
            limit = ATN[first] + delATN
            while last + 1 < sec_stop and ATN[last + 1] <= limit:
                last += 1

            run = slice(first, last + 1)
            values = BC[run]
            # An all-NaN run has no mean; leave it as NaN instead of letting
            # nanmean emit a "Mean of empty slice" warning.
            if not np.isnan(values).all():
                BC[run] = np.nanmean(values)
            count[run] = last - first + 1

            first = last + 1

    return BC, count


def sipkens(BC, ATN, Dt, Qa, UR, p=4.91, sL=0., gam=24.2, filt_jump=30.0):
    """
    Smooth BC data based on reproducibility threshold.

    Within each filter section, consecutive points are accumulated into a run
    until twice the modelled relative error of the run's summed BC falls
    below UR / 100, the section ends, or a NaN is met. Every point of the run
    is then replaced by the run mean. A new filter section starts at index
    i + 1 whenever |ATN[i + 1] - ATN[i]| exceeds filt_jump or ATN[i + 1] is
    NaN.

    Parameters:
    - BC: ndarray, mass concentration (not modified; a copy is returned)
    - ATN: ndarray, attenuation, same length as BC
    - Dt: sampling frequency of the measurements (min)
    - Qa: aerosol flow rate (mL/min)
    - UR: expanded reproducibility in percent (e.g., 20 for 20%)
    - p, sL, gam: coefficients of the error model evaluated in err_fun
    - filt_jump: float, absolute ATN step between consecutive points that is
      treated as a filter change (default 30)

    Returns:
    - BC: mass concentration after averaging
    - count: number of points averaged in each segment (NaN where BC is NaN)
    - err: error model evaluated on the summed BC of each segment
      (NaN where BC is NaN)
    """

    BC = BC.copy()
    count = np.ones(len(BC))
    err = np.ones(len(BC))
    threshold = UR / 100

    def err_fun(sumM):
        # Error model as a function of the summed mass concentration.
        sumM = np.maximum(sumM, 1e-10)  # avoid divide by zero
        return np.sqrt(
            np.maximum(
                sL**2,
                sL**2 + (6 * p) / (Qa * sumM) +
                (6 * gam**2) / (Qa**2 * Dt * sumM**2)
            )
        )

    # Index i is a break when the section containing i ends at i, i.e. the
    # next point sits on a new filter spot or has a NaN attenuation.
    fast_changes = np.where(np.abs(np.diff(ATN)) > filt_jump)[0]
    nan_breaks = np.where(np.isnan(ATN[1:]))[0]
    breaks = np.unique(np.concatenate([fast_changes, nan_breaks]))

    # Half-open [start, stop) bounds of every section; together they cover
    # every index exactly once.
    starts = np.concatenate([[0], breaks + 1]).tolist()
    stops = np.concatenate([breaks + 1, [len(ATN)]]).tolist()

    for first, sec_stop in zip(starts, stops):
        while first < sec_stop:
            # Running sum of the run [first, stop). The run keeps growing
            # while its own error still misses the threshold. A NaN total
            # makes the comparison False, so a NaN point stays on its own,
            # and a NaN next point is never absorbed into the run.
            total = BC[first]
            stop = first + 1
            while (stop < sec_stop
                   and not np.isnan(BC[stop])
                   and 2 * err_fun(total) >= threshold):
                total = total + BC[stop]
                stop += 1

            n_pts = stop - first
            BC[first:stop] = total / n_pts
            count[first:stop] = n_pts
            # Report the error of exactly the points that were averaged.
            err[first:stop] = err_fun(total)

            first = stop

    count[np.isnan(BC)] = np.nan  # ensure NaNs are NaNs
    err[np.isnan(BC)] = np.nan  # ensure NaNs are NaNs

    return BC, count, err


def sipkens_kalman(BC, Dt, Qa, pn=3e-2, p=4.91, sL=0., gam=24.2):
    """
    Applies a Kalman-like smoothing filter to BC data based on physical uncertainty.

    The state is [BC, dBC/dt]. NaN entries of BC are skipped; the time step
    used for the next valid sample is widened by the number of skipped samples.

    Parameters:
    - BC: ndarray, mass concentration (not modified; a copy is returned)
    - Dt: sampling frequency of the measurements (min)
    - Qa: aerosol flow rate in mL/min
    - pn: process noise variance (higher corresponds to less smoothing)
    - p, sL, gam: coefficients of the error model evaluated in sig_fun

    Returns:
    - BC: filtered mass concentration (NaN where the input was NaN)
    - err: 2 * P[0, 0] / BC_filtered for each filtered sample, where P is the
      updated state covariance (NaN where the input was NaN)
    """

    BC = BC.copy()
    # Always float so that NaN can be stored regardless of BC's dtype.
    err = np.zeros(len(BC), dtype=float)

    # Function for evaluating error model.
    def sig_fun(BC):
        # Bottom out at noise floor for very small or negative BC
        min_val = np.sqrt((6 * gam**2) / (Qa**2 * Dt))
        term = sL**2 * BC**2 + (6 * p) * BC / Qa + (6 * gam**2) / (Qa**2 * Dt)
        return np.maximum(min_val, np.sqrt(np.maximum(0, term)))

    # Only loop through non-NaN values.
    idxs = np.where(~np.isnan(BC))[0]
    if len(idxs) == 0:
        # Empty or all-NaN input: nothing to filter.
        err[:] = np.nan
        return BC, err

    H = np.array([[1, 0]])  # measurement model

    x = np.array([BC[idxs[0]], 0.0])  # initial state: [BC, dBC/dt]
    P = np.eye(2) * 1.0  # initial uncertainty

    # --- Apply Kalman filter ---
    prev_idx = None
    for idx in idxs:
        z = BC[idx]  # assign measurement to z

        # Restate matrices that can change with each iteration.
        # If skipping time steps (e.g., because of NaNs), expand time step and uncertainties.
        if prev_idx is None:
            dt_tt = Dt  # for first iteration
        else:
            dt_tt = (idx - prev_idx) * Dt
        prev_idx = idx

        F = np.array([[1, dt_tt], [0, 0]])  # state transition (dampen slope to zero)
        Q = np.array([[pn, 0], [0, pn * dt_tt * 1000]])  # process noise covariance (tunable)

        # --- Predict step ---
        x_pred = F @ x
        P_pred = F @ P @ F.T + Q

        # --- Update step ---
        # Measurement noise (use predicted value to evaluate error model).
        # NOTE(review): sig_fun returns the square root of the error-model
        # term, whereas a standard Kalman update uses a variance for R --
        # confirm this is intentional before changing, since pn may have been
        # tuned against it.
        R = sig_fun(x_pred[0])
        S = H @ P_pred @ H.T + R
        K = P_pred @ H.T / S  # Kalman gain

        residual = z - H @ x_pred

        x = x_pred + K @ residual
        P = (np.eye(P.shape[0]) - K @ H) @ P_pred

        BC[idx] = x[0]  # Save filtered BC + uncertainty
        err[idx] = 2 * P[0, 0] / x[0]

    err[np.isnan(BC)] = np.nan  # ensure NaNs are NaNs

    return BC, err
