import numpy as np
import math
from numba import jit
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
import pywt
from skimage.filters import gaussian
from skimage.transform import rescale

# warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
# warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

def check_dims(image):
    """
    Checks image dimensions and adds a leading axis (axis 0) in case it is 2D
    :param image: 2D or 3D image to add an axis to if needed
    :return: 3D image; a 2D input of shape (a, b) is returned with shape (1, a, b)
    :raises ValueError: if the image has fewer than 2 or more than 3 dimensions
    """
    n_dims = len(image.shape)
    if n_dims < 2 or n_dims > 3:
        raise ValueError(
            f"Expected a 2D or 3D image, got an array with {n_dims} dimension(s) "
            f"(shape={image.shape})"
        )

    if n_dims == 2:
        image = np.expand_dims(image, axis=0)

    return image


# @jit
def min_max_normalization(image, low=None, high=None):
    """
    Percentile-based Min/Max normalization of an image or stack, following the formula:
    I_new = (I - P_low) / (P_high - P_low)
    where P_low / P_high are the `low` / `high` percentiles over all values of the input.
    Values outside [P_low, P_high] are not clipped, so the output may leave [0, 1].
    :param image: stack or image to be normalized with low and high percentile
    :param low: lower percentile for normalization (1.0 when None)
    :param high: higher percentile for normalization (99.0 when None)
    :return: normalized float image with the same shape as the input
    """
    p_low = 1.0 if low is None else low
    p_high = 99.0 if high is None else high

    min_i = np.percentile(image, p_low)
    max_i = np.percentile(image, p_high)

    # tiny epsilon avoids a division by zero when both percentiles coincide (e.g. a flat image)
    return (image - min_i) / (max_i - min_i + 1e-20)


# @jit
def absolute_laplacian(image, do_gauss):
    """
    Computes the Absolute Laplacian score of every slice along axis 0 of a stack
    by delegating each slice to absolute_laplacian_slice.
    :param image: stack where to compute the absolute laplacians, one slice per index of axis 0
    :param do_gauss: if True, each slice is smoothed with a Gaussian filter (sigma=1) before scoring
    :return: float32 array of shape (image.shape[0],) with one absolute laplacian value per slice
    """
    n_slices = image.shape[0]
    scores = np.zeros((n_slices), dtype='float32')

    for index in range(n_slices):
        plane = image[index].copy()
        plane = gaussian(plane, sigma=1) if do_gauss else plane
        scores[index] = absolute_laplacian_slice(plane)

    return scores


def absolute_laplacian_scaled(image, scale_factor):
    """
    Computes the Absolute Laplacian score of every slice along axis 0 of a stack, after
    rescaling each slice by scale_factor (a factor of exactly 1 skips the rescaling step).
    :param image: stack where to compute the absolute laplacians, one slice per index of axis 0
    :param scale_factor: rescaling factor applied to each slice, must satisfy 0 < scale_factor <= 1
    :return: float32 array of shape (image.shape[0],) with one absolute laplacian value per slice
    :raises AssertionError: if scale_factor is outside (0, 1]
    """
    # use a string message: print(...) as the assert message evaluates to None,
    # which printed to stdout and raised an uninformative AssertionError(None)
    assert 0 < scale_factor <= 1, "Scale factor needs to be larger than 0 and smaller or equal to 1"

    abs_laplacian_array = np.zeros((image.shape[0]), dtype='float32')

    for z in range(image.shape[0]):
        image_slice = image[z].copy()
        if scale_factor != 1:
            image_slice = rescale(image_slice, scale_factor)
        abs_laplacian_array[z] = absolute_laplacian_slice(image_slice)

    return abs_laplacian_array


@jit
def absolute_laplacian_slice(image):
    """
    Computes the mean Absolute Laplacian of a single 2D image: the average over interior
    pixels (first/last row and column excluded) of
    |2*I[y, x] - I[y, x-1] - I[y, x+1]| + |2*I[y, x] - I[y-1, x] - I[y+1, x]|
    :param image: 2D nd array of a single image where to compute the absolute laplacian
    :return: absolute laplacian value as np.float32; NaN when the image has fewer than
        3 rows or 3 columns (no interior pixels, so the mean is undefined)
    """
    n_rows = image.shape[0]
    n_cols = image.shape[1]
    if n_rows < 3 or n_cols < 3:
        # without this guard the pixel count is <= 0 (ZeroDivisionError or a meaningless value)
        return np.float32(np.nan)

    laplacian = 0.0
    for y in range(1, n_rows - 1):  # skip first and last pixel in y
        for x in range(1, n_cols - 1):  # skip first and last pixel in x
            # work in float so the result does not depend on the integer type of the input
            center = 2.0 * float(image[y, x])
            laplacian += abs(center - float(image[y, x - 1]) - float(image[y, x + 1])) + \
                abs(center - float(image[y - 1, x]) - float(image[y + 1, x]))

    n_px = (n_rows - 2) * (n_cols - 2)
    return np.float32(laplacian / n_px)


@jit
def quant_first_derivative(image):
    """
    Computes a 1st-derivative magnitude for each interior pixel of a 3D stack,
    |I[z, y, x] - I[z, y, x-1]| + |I[z, y, x] - I[z, y-1, x]|,
    and returns the results as an nd array with shape (z, y-2, x-2).
    :param image: 3D stack, indexed as image[z, y, x], where to compute the first derivatives
    :return: float64 nd array of shape (z, y-2, x-2) with the 1st derivative at each interior pixel
    """
    # float64 output and float arithmetic: for unsigned integer inputs, numba's integer
    # promotion would make the differences wrap around instead of going negative
    image_derivative = np.zeros(image.shape, dtype=np.float64)
    for z in range(image.shape[0]):
        for y in range(1, image.shape[1] - 1):
            for x in range(1, image.shape[2] - 1):
                value = float(image[z, y, x])
                image_derivative[z, y, x] = abs(value - float(image[z, y, x - 1])) + \
                    abs(value - float(image[z, y - 1, x]))

    # next step reduces size of output to exclude zeros on borders which are not derivatives
    return image_derivative[:, 1:(image.shape[1] - 1), 1:(image.shape[2] - 1)]


def quant_second_derivative(image, do_gauss):
    """
    Computes the 2nd derivative of each interior pixel of every slice along axis 0 of a stack
    by delegating each slice to quant_second_derivative_slice.
    :param image: stack where to compute the second derivatives, one 2D slice per index of axis 0
    :param do_gauss: if True, each slice is smoothed with a Gaussian filter (sigma=1) first
    :return: float32 nd array of shape (z, y-2, x-2) with the 2nd derivative at each interior pixel
    """
    n_z, n_y, n_x = image.shape[0], image.shape[1], image.shape[2]
    result = np.zeros((n_z, n_y - 2, n_x - 2), dtype='float32')

    for index in range(n_z):
        plane = image[index].copy()
        if do_gauss:
            plane = gaussian(plane, sigma=1)
        result[index] = quant_second_derivative_slice(plane)

    return result


@jit
def quant_second_derivative_slice(image):
    """
    Computes the 2nd derivative of each interior pixel of a single 2D image as the mean of the
    absolute second differences along x and y:
    (|2*I[y, x] - I[y, x-1] - I[y, x+1]| + |2*I[y, x] - I[y-1, x] - I[y+1, x]|) / 2
    :param image: 2D image where to compute the second derivatives
    :return: float64 nd array of shape (y-2, x-2) with the 2nd derivative at each interior pixel
    """
    # float output: the "/ 2" yields half values, which an integer output array
    # (as np.zeros_like gives for integer images) would silently truncate
    image_derivative = np.zeros(image.shape, dtype=np.float64)
    for y in range(1, image.shape[0] - 1):
        for x in range(1, image.shape[1] - 1):
            center = 2.0 * float(image[y, x])
            image_derivative[y, x] = (
                abs(center - float(image[y, x - 1]) - float(image[y, x + 1]))
                + abs(center - float(image[y - 1, x]) - float(image[y + 1, x]))
            ) / 2

    # next step reduces size of output to exclude zeros on borders which are not derivatives
    return image_derivative[1:(image.shape[0] - 1), 1:(image.shape[1] - 1)]


# @jit
def total_variation(image, do_gauss):
    """
    Computes the Total Variation score of every slice along axis 0 of a stack
    by delegating each slice to total_variation_slice.
    :param image: nd array of a 3D image (one 2D slice per index of axis 0)
    :param do_gauss: if True, each slice is smoothed with a Gaussian filter (sigma=1) before scoring
    :return: float32 array of shape (image.shape[0],) with one Total Variation score per slice
    """
    scores = np.zeros((image.shape[0]), dtype='float32')

    for index, plane in enumerate(image):
        plane = plane.copy()
        if do_gauss:
            plane = gaussian(plane, sigma=1)
        scores[index] = total_variation_slice(plane)

    return scores


@jit
def total_variation_slice(image):
    """
    Computes the mean Total Variation of a single 2D image: the average over interior pixels
    (first/last row and column excluded) of the central-difference gradient magnitude
    sqrt((I[y, x+1] - I[y, x-1])^2 + (I[y+1, x] - I[y-1, x])^2)
    :param image: 2D nd array (one slice of a 3D image)
    :return: Total Variation score as np.float32; NaN when the image has fewer than
        3 rows or 3 columns (no interior pixels, so the mean is undefined)
    """
    n_rows = image.shape[0]
    n_cols = image.shape[1]
    if n_rows < 3 or n_cols < 3:
        return np.float32(np.nan)

    tv = 0.0
    for y in range(1, n_rows - 1):
        for x in range(1, n_cols - 1):
            # differences in float: for unsigned integer pixels, numba's integer promotion
            # would otherwise wrap negative differences around to huge positive values
            dx = float(image[y, x + 1]) - float(image[y, x - 1])
            dy = float(image[y + 1, x]) - float(image[y - 1, x])
            tv += math.sqrt(dx * dx + dy * dy)

    n_px_slice = (n_rows - 2) * (n_cols - 2)
    return np.float32(tv / n_px_slice)


def wavelet_contrast_index(image, p_low, p_high):
    """
    Computes the wavelet contrast index of every slice along axis 0 of a stack
    by delegating each slice to wavelet_contrast_index_slice.

    Following the quantification for contrast using wavelet decomposition/analysis seen in
    Albright et al. 2023 PNAS, defined as contrast_index = log(w_{95} / w_{50}).
    :param image: stack to measure contrast in, one 2D slice per index of axis 0
    :param p_low: lowest percentile to normalize for contrast index
    :param p_high: top percentile to measure contrast, should not be sensitive to noise in image
    :return: float32 array of shape (image.shape[0],) with one contrast index per slice
    """
    n_slices = image.shape[0]
    indexes = np.zeros((n_slices), dtype='float32')

    for index in range(n_slices):
        indexes[index] = wavelet_contrast_index_slice(image[index].copy(), p_low, p_high)

    return indexes


def wavelet_contrast_index_slice(image, p_low, p_high, wavelet='haar', level=4):
    """
    Following the quantification for contrast using wavelet decomposition/analysis seen in
    Albright et al. 2023 PNAS, defined as contrast_index = log(w_{95} / w_{50}), where w_p is
    the p-th percentile of the absolute wavelet coefficients.
    :param image: 2D image to measure contrast in
    :param p_low: lowest percentile to normalize for contrast index
    :param p_high: top percentile to measure contrast, should not be sensitive to noise in image
    :param wavelet: wavelet name passed to pywt.wavedec2 (default 'haar')
    :param level: decomposition level passed to pywt.wavedec2 (default 4)
    :return: contrast index (Python float)
    """
    wavelet_coefficients = pywt.wavedec2(image, wavelet, level=level)
    # wavedec2 lists coefficients coarsest-first, so [:-1] drops the finest-scale detail
    # coefficients (presumably to limit sensitivity to pixel noise -- TODO confirm intent)
    coefficient_array, _ = pywt.coeffs_to_array(wavelet_coefficients[:-1])
    coefficients_absolute = np.absolute(coefficient_array)
    coeff_p_low = np.percentile(coefficients_absolute, p_low)
    coeff_p_high = np.percentile(coefficients_absolute, p_high)
    return math.log(coeff_p_high / coeff_p_low)


def percentile_contrast_index(image, p_low, p_high):
    """
    Computes a percentile-based contrast index for every slice along axis 0 of a stack:
    contrast_index = log(P_high / P_low), where P_low and P_high are the p_low-th and
    p_high-th percentiles of the slice values.
    :param image: stack to measure contrast in, one slice per index of axis 0
    :param p_low: lower percentile, in [0, 100] and smaller than p_high
    :param p_high: higher percentile, in [0, 100]
    :return: float32 array of shape (image.shape[0],) with one contrast index per slice;
        NaN where the index is undefined (P_low == 0, or P_high / P_low not positive)
    :raises AssertionError: if the percentiles are out of range or not ordered
    """
    assert p_low < p_high, "Low percentile needs to be smaller than High percentile"
    assert 0 <= p_low <= 100.0, "Lower percentile needs to be between [0-100]"
    assert 0 <= p_high <= 100.0, "Higher percentile needs to be between [0-100]"

    contrast_indexes = np.zeros((image.shape[0]), dtype='float32')
    for z in range(image.shape[0]):
        image_slice = image[z]
        perc_low = np.percentile(image_slice, p_low)
        perc_high = np.percentile(image_slice, p_high)
        # numpy float division by zero does not raise (it gives inf/nan with a warning),
        # so the undefined cases are detected explicitly instead of via a bare except
        ratio = perc_high / perc_low if perc_low != 0 else np.nan
        # log is only defined for a positive ratio; NaN also fails this test
        contrast_indexes[z] = math.log(ratio) if ratio > 0 else np.nan

    return contrast_indexes

def PSNR(gt, pred, range_=4095.0):
    """
    Computes the Peak Signal-to-Noise Ratio of every slice along axis 0 of a prediction
    against the ground truth: PSNR = 20 * log10(range_ / sqrt(MSE)).
    Based on code from PPN2V repo for PSNR quantification
    :param gt: ground-truth stack, one slice per index of axis 0
    :param pred: predicted stack with the same number of z planes as gt
    :param range_: data range used as the peak signal value (default 4095.0)
    :return: float32 array of shape (gt.shape[0],) with one PSNR value (dB) per slice;
        inf for slices where gt and pred are identical
    :raises AssertionError: if gt and pred do not have the same number of z planes
    """
    assert gt.shape[0] == pred.shape[0], "Stacks don't have the same number of z planes"

    psnr_indices = np.zeros((gt.shape[0]), dtype='float32')
    for z in range(gt.shape[0]):
        # subtract in float64: with unsigned integer stacks (e.g. uint16) both the
        # difference and its square would otherwise wrap around
        diff = np.asarray(gt[z], dtype=np.float64) - np.asarray(pred[z], dtype=np.float64)
        mse = np.mean(diff ** 2)
        # identical slices: PSNR is infinite (avoids a divide-by-zero warning)
        psnr_indices[z] = np.inf if mse == 0 else 20 * np.log10(range_ / np.sqrt(mse))

    return psnr_indices

def MSE(gt, pred):
    """
    Computes the Mean Squared Error between ground truth and prediction for every slice
    along axis 0: MSE = mean((gt - pred)^2)
    :param gt: ground-truth stack, one slice per index of axis 0
    :param pred: predicted stack with the same shape as gt
    :return: float32 array of shape (gt.shape[0],) with one MSE value per slice
    :raises AssertionError: if gt and pred do not have the same shape
    """
    assert gt.shape == pred.shape, "Stacks don't have the same dimensions"

    mse_values = np.zeros((gt.shape[0]), dtype='float32')
    for z in range(gt.shape[0]):
        # subtract in float64: with unsigned integer stacks both the difference
        # and its square would otherwise wrap around
        diff = np.asarray(gt[z], dtype=np.float64) - np.asarray(pred[z], dtype=np.float64)
        mse_values[z] = np.mean(diff ** 2)

    return mse_values

def MAE(gt, pred):
    """
    Computes the Mean Absolute Error between ground truth and prediction for every slice
    along axis 0: MAE = mean(|gt - pred|)
    :param gt: ground-truth stack, one slice per index of axis 0
    :param pred: predicted stack with the same shape as gt
    :return: float32 array of shape (gt.shape[0],) with one MAE value per slice
    :raises AssertionError: if gt and pred do not have the same shape
    """
    assert gt.shape == pred.shape, "Stacks don't have the same dimensions"

    mae_values = np.zeros((gt.shape[0]), dtype='float32')
    for z in range(gt.shape[0]):
        # float64 subtraction avoids wrap-around for unsigned integer stacks;
        # np.abs makes this the mean *absolute* error rather than the signed mean error
        diff = np.asarray(gt[z], dtype=np.float64) - np.asarray(pred[z], dtype=np.float64)
        mae_values[z] = np.mean(np.abs(diff))

    return mae_values