import pytest
import numpy as np

from scipy._lib._array_api import (
    xp_assert_close, xp_assert_equal, xp_swapaxes, is_torch, make_xp_test_case,
    _xp_copy_to_numpy
)

from scipy.ndimage import convolve1d   # type: ignore[attr-defined]

from scipy.signal import savgol_coeffs, savgol_filter
from scipy.signal._savitzky_golay import _polyder

skip_xp_backends = pytest.mark.skip_xp_backends
xfail_xp_backends = pytest.mark.xfail_xp_backends

pytestmark = pytest.mark.skip_xp_backends(
    "dask.array", reason="linalg.lstsq is missing rcond"
)


def check_polyder(p, m, expected, xp):
    dp = _polyder(p, m, xp=xp)
    xp_assert_equal(dp, expected)


def test_polyder(xp):
    cases = [
        ([5], 0, [5]),
        ([5], 1, [0]),
        ([3, 2, 1], 0, [3, 2, 1]),
        ([3, 2, 1], 1, [6, 2]),
        ([3, 2, 1], 2, [6]),
        ([3, 2, 1], 3, [0]),
        ([[3, 2, 1], [5, 6, 7]], 0, [[3, 2, 1], [5, 6, 7]]),
        ([[3, 2, 1], [5, 6, 7]], 1, [[6, 2], [10, 6]]),
        ([[3, 2, 1], [5, 6, 7]], 2, [[6], [10]]),
        ([[3, 2, 1], [5, 6, 7]], 3, [[0], [0]]),
    ]
    for p, m, expected in cases:
        pp = xp.asarray(p)
        ee = xp.asarray(expected)
        check_polyder(
            pp.T if pp.ndim == 2 else pp,
            m,
            ee.T if ee.ndim ==2 else ee,
            xp
        )


#--------------------------------------------------------------------
# savgol_coeffs tests
#--------------------------------------------------------------------

def alt_sg_coeffs(window_length, polyorder, pos, xp):
    """This is an alternative implementation of the SG coefficients.

    It uses numpy.polyfit and numpy.polyval. The results should be
    equivalent to those of savgol_coeffs(), but this implementation
    is slower.

    window_length should be odd.

    """
    if pos is None:
        pos = window_length // 2
    t = np.arange(window_length)
    unit = (t == pos).astype(int)
    h = np.polyval(np.polyfit(t, unit, polyorder), t)
    return xp.asarray(h)


@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_trivial(xp):
    # Test a trivial case of savgol_coeffs: polyorder = window_length - 1
    h = savgol_coeffs(1, 0, xp=xp)
    xp_assert_close(h, xp.asarray([1.0], dtype=xp.float64))

    h = savgol_coeffs(3, 2, xp=xp)
    xp_assert_close(h, xp.asarray([0.0, 1, 0], dtype=xp.float64), atol=1e-10)

    h = savgol_coeffs(5, 4, xp=xp)
    xp_assert_close(h, xp.asarray([0.0, 0, 1, 0, 0], dtype=xp.float64), atol=1e-10)

    h = savgol_coeffs(5, 4, pos=1, xp=xp)
    xp_assert_close(h, xp.asarray([0.0, 0, 0, 1, 0], dtype=xp.float64), atol=1e-10)

    h = savgol_coeffs(5, 4, pos=1, use='dot', xp=xp)
    xp_assert_close(h, xp.asarray([0.0, 1, 0, 0, 0], dtype=xp.float64), atol=1e-10)


def compare_coeffs_to_alt(window_length, order, xp):
    # For the given window_length and order, compare the results
    # of savgol_coeffs and alt_sg_coeffs for pos from 0 to window_length - 1.
    # Also include pos=None.
    for pos in [None] + list(range(window_length)):
        h1 = savgol_coeffs(window_length, order, pos=pos, use='dot', xp=xp)
        h2 = alt_sg_coeffs(window_length, order, pos=pos, xp=xp)
        xp_assert_close(
            h1, h2, atol=2e-10,
            err_msg=f"window_length = {window_length}, order = {order}, pos = {pos}"
        )


@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_compare(xp):
    # Compare savgol_coeffs() to alt_sg_coeffs().
    for window_length in range(1, 8, 2):
        for order in range(window_length):
            compare_coeffs_to_alt(window_length, order, xp=xp)


@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_exact(xp):
    polyorder = 4
    window_length = 9
    halflen = window_length // 2

    x = xp.linspace(0, 21, 43)
    delta = x[1] - x[0]

    # The data is a cubic polynomial.  We'll use an order 4
    # SG filter, so the filtered values should equal the input data
    # (except within half window_length of the edges).
    y = 0.5 * x ** 3 - x
    h = savgol_coeffs(window_length, polyorder, xp=xp)
    y0 = xp.asarray(convolve1d(_xp_copy_to_numpy(y), _xp_copy_to_numpy(h)))
    xp_assert_close(y0[halflen:-halflen], y[halflen:-halflen])

    # Check the same input, but use deriv=1.  dy is the exact result.
    dy = 1.5 * x ** 2 - 1
    h = savgol_coeffs(window_length, polyorder, deriv=1, delta=delta, xp=xp)
    y1 = xp.asarray(convolve1d(_xp_copy_to_numpy(y), _xp_copy_to_numpy(h)))
    xp_assert_close(y1[halflen:-halflen], dy[halflen:-halflen])

    # Check the same input, but use deriv=2. d2y is the exact result.
    d2y = 3.0 * x
    h = savgol_coeffs(window_length, polyorder, deriv=2, delta=delta, xp=xp)
    y2 = xp.asarray(convolve1d(_xp_copy_to_numpy(y), _xp_copy_to_numpy(h)))
    xp_assert_close(y2[halflen:-halflen], d2y[halflen:-halflen])


@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_deriv(xp):
    # The data in `x` is a sampled parabola, so using savgol_coeffs with an
    # order 2 or higher polynomial should give exact results.
    i = xp.asarray([-2.0, 0.0, 2.0, 4.0, 6.0], dtype=xp.float64)
    x = i ** 2 / 4
    dx = i / 2
    d2x = xp.full_like(i, 0.5)
    for pos in range(x.shape[0]):
        coeffs0 = savgol_coeffs(5, 3, pos=pos, delta=2.0, use='dot', xp=xp)
        xp_assert_close(coeffs0 @ x, x[pos], atol=1e-10)
        coeffs1 = savgol_coeffs(5, 3, pos=pos, delta=2.0, use='dot', deriv=1, xp=xp)
        xp_assert_close(coeffs1 @ x, dx[pos], atol=1e-10)
        coeffs2 = savgol_coeffs(5, 3, pos=pos, delta=2.0, use='dot', deriv=2, xp=xp)
        xp_assert_close(coeffs2 @ x , d2x[pos], atol=1e-10)


@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_deriv_gt_polyorder(xp):
    """
    If deriv > polyorder, the coefficients should be all 0.
    This is a regression test for a bug where, e.g.,
        savgol_coeffs(5, polyorder=1, deriv=2)
    raised an error.
    """
    coeffs = savgol_coeffs(5, polyorder=1, deriv=2, xp=xp)
    xp_assert_equal(coeffs, xp.zeros(5, dtype=xp.float64))
    coeffs = savgol_coeffs(7, polyorder=4, deriv=6, xp=xp)
    xp_assert_equal(coeffs, xp.zeros(7, dtype=xp.float64))


@xfail_xp_backends(
    "torch", reason="torch loses precision (worse with f32 default dtype)"
)
@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_large(xp):
    # Test that for large values of window_length and polyorder the array of
    # coefficients returned is symmetric. The aim is to ensure that
    # no potential numeric overflow occurs.
    coeffs0 = savgol_coeffs(31, 9, xp=xp)
    xp_assert_close(
        coeffs0, xp.flip(coeffs0), atol=5e-4 if is_torch(xp) else 1.5e-6
    )

    coeffs1 = savgol_coeffs(31, 9, deriv=1, xp=xp)
    xp_assert_close(
        coeffs1, -xp.flip(coeffs1), atol=1.2e-2 if is_torch(xp) else 1.5e-6
    )

# --------------------------------------------------------------------
# savgol_coeffs tests for even window length
# --------------------------------------------------------------------

@make_xp_test_case(savgol_coeffs)
def test_sg_coeffs_even_window_length(xp):
    # Simple case - deriv=0, polyorder=0, 1
    window_lengths = [4, 6, 8, 10, 12, 14, 16]
    for length in window_lengths:
        h_p_d = savgol_coeffs(length, 0, 0, xp=xp)
        xp_assert_close(h_p_d, xp.ones_like(h_p_d) / length)

    # Verify with closed forms
    # deriv=1, polyorder=1, 2
    def h_p_d_closed_form_1(k, m):
        return 6*(k - 0.5)/((2*m + 1)*m*(2*m - 1))

    # deriv=2, polyorder=2
    def h_p_d_closed_form_2(k, m):
        numer = 15*(-4*m**2 + 1 + 12*(k - 0.5)**2)
        denom = 4*(2*m + 1)*(m + 1)*m*(m - 1)*(2*m - 1)
        return numer/denom

    for length in window_lengths:
        m = length//2
        expected_output = [h_p_d_closed_form_1(k, m)
                           for k in range(-m + 1, m + 1)][::-1]
        expected_output = xp.asarray(expected_output, dtype=xp.float64)
        actual_output = savgol_coeffs(length, 1, 1, xp=xp)
        xp_assert_close(actual_output, expected_output)
        actual_output = savgol_coeffs(length, 2, 1, xp=xp)
        xp_assert_close(actual_output, expected_output)

        expected_output = [h_p_d_closed_form_2(k, m)
                           for k in range(-m + 1, m + 1)][::-1]
        expected_output = xp.asarray(expected_output, dtype=xp.float64)
        actual_output = savgol_coeffs(length, 2, 2, xp=xp)
        xp_assert_close(actual_output, expected_output)
        actual_output = savgol_coeffs(length, 3, 2, xp=xp)
        xp_assert_close(actual_output, expected_output)

#--------------------------------------------------------------------
# savgol_filter tests
#--------------------------------------------------------------------

@make_xp_test_case(savgol_filter)
def test_sg_filter_trivial(xp):
    """ Test some trivial edge cases for savgol_filter()."""
    x = xp.asarray([1.0])
    y = savgol_filter(x, 1, 0)
    xp_assert_equal(y, xp.asarray([1.0]))

    # Input is a single value. With a window length of 3 and polyorder 1,
    # the value in y is from the straight-line fit of (-1,0), (0,3) and
    # (1, 0) at 0. This is just the average of the three values, hence 1.0.
    x = xp.asarray([3.0])
    y = savgol_filter(x, 3, 1, mode='constant')
    xp_assert_close(y, xp.asarray([1.0]), atol=1.5e-15)

    x = xp.asarray([3.0])
    y = savgol_filter(x, 3, 1, mode='nearest')
    xp_assert_close(y, xp.asarray([3.0]), atol=1.5e-15)

    x = xp.asarray([1.0] * 3)
    y = savgol_filter(x, 3, 1, mode='wrap')
    xp_assert_close(y, xp.asarray([1.0, 1.0, 1.0]), atol=1.5e-15)


@make_xp_test_case(savgol_filter)
def test_sg_filter_basic(xp):
    # Some basic test cases for savgol_filter().
    x = xp.asarray([1.0, 2.0, 1.0])
    y = savgol_filter(x, 3, 1, mode='constant')
    xp_assert_close(y, xp.asarray([1.0, 4.0 / 3, 1.0]))

    y = savgol_filter(x, 3, 1, mode='mirror')
    xp_assert_close(y, xp.asarray([5.0 / 3, 4.0 / 3, 5.0 / 3]))

    y = savgol_filter(x, 3, 1, mode='wrap')
    xp_assert_close(y, xp.asarray([4.0 / 3, 4.0 / 3, 4.0 / 3]))


@make_xp_test_case(savgol_filter)
def test_sg_filter_2d(xp):
    x = xp.asarray([[1.0, 2.0, 1.0],
                    [2.0, 4.0, 2.0]])
    expected = xp.asarray([[1.0, 4.0 / 3, 1.0],
                           [2.0, 8.0 / 3, 2.0]])
    y = savgol_filter(x, 3, 1, mode='constant')
    xp_assert_close(y, expected)

    y = savgol_filter(x.T, 3, 1, mode='constant', axis=0)
    xp_assert_close(y, expected.T)


@make_xp_test_case(savgol_filter)
def test_sg_filter_interp_edges(xp):
    # Another test with low degree polynomial data, for which we can easily
    # give the exact results. In this test, we use mode='interp', so
    # savgol_filter should match the exact solution for the entire data set,
    # including the edges.
    t = xp.linspace(-5, 5, 21, dtype=xp.float64)
    delta = t[1] - t[0]
    # Polynomial test data.
    x = xp.stack([t,
                  3 * t ** 2,
                  t ** 3 - t])
    dx = xp.stack([xp.ones_like(t),
                   6 * t,
                   3 * t ** 2 - 1.0])
    d2x = xp.stack([xp.zeros_like(t),
                    xp.full_like(t, 6),
                    6 * t])

    window_length = 7

    y = savgol_filter(x, window_length, 3, axis=-1, mode='interp')
    xp_assert_close(y, x, atol=1e-12)

    y1 = savgol_filter(x, window_length, 3, axis=-1, mode='interp',
                       deriv=1, delta=delta)
    xp_assert_close(y1, dx, atol=1e-12)

    y2 = savgol_filter(x, window_length, 3, axis=-1, mode='interp',
                       deriv=2, delta=delta)
    xp_assert_close(y2, d2x, atol=1e-12)

    # Transpose everything, and test again with axis=0.

    x = x.T
    dx = dx.T
    d2x = d2x.T

    y = savgol_filter(x, window_length, 3, axis=0, mode='interp')
    xp_assert_close(y, x, atol=1e-12)

    y1 = savgol_filter(x, window_length, 3, axis=0, mode='interp',
                       deriv=1, delta=delta)
    xp_assert_close(y1, dx, atol=1e-12)

    y2 = savgol_filter(x, window_length, 3, axis=0, mode='interp',
                       deriv=2, delta=delta)
    xp_assert_close(y2, d2x, atol=1e-12)


@make_xp_test_case(savgol_filter)
def test_sg_filter_interp_edges_3d(xp):
    # Test mode='interp' with a 3-D array.
    t = xp.linspace(-5, 5, 21)
    delta = t[1] - t[0]
    x1 = xp.stack([t, -t])
    x2 = xp.stack([t ** 2, 3 * t ** 2 + 5])
    x3 = xp.stack([t ** 3, 2 * t ** 3 + t ** 2 - 0.5 * t])
    dx1 = xp.stack([xp.ones_like(t), -xp.ones_like(t)])
    dx2 = xp.stack([2 * t, 6 * t])
    dx3 = xp.stack([3 * t ** 2, 6 * t ** 2 + 2 * t - 0.5])

    # z has shape (3, 2, 21)
    z = xp.stack([x1, x2, x3])
    dz = xp.stack([dx1, dx2, dx3])

    y = savgol_filter(z, 7, 3, axis=-1, mode='interp', delta=delta)
    xp_assert_close(y, z, atol=1e-10)

    dy = savgol_filter(z, 7, 3, axis=-1, mode='interp', deriv=1, delta=delta)
    xp_assert_close(dy, dz, atol=1e-10)

    # z has shape (3, 21, 2)
    z = xp.stack([x1.T, x2.T, x3.T])
    dz = xp.stack([dx1.T, dx2.T, dx3.T])

    y = savgol_filter(z, 7, 3, axis=1, mode='interp', delta=delta)
    xp_assert_close(y, z, atol=1e-10)

    dy = savgol_filter(z, 7, 3, axis=1, mode='interp', deriv=1, delta=delta)
    xp_assert_close(dy, dz, atol=1e-10)

    # z has shape (21, 3, 2)
    z = xp.asarray(xp_swapaxes(z, 0, 1, xp=xp), copy=True)
    dz = xp.asarray(xp_swapaxes(dz, 0, 1, xp=xp), copy=True)

    y = savgol_filter(z, 7, 3, axis=0, mode='interp', delta=delta)
    xp_assert_close(y, z, atol=1e-10)

    dy = savgol_filter(z, 7, 3, axis=0, mode='interp', deriv=1, delta=delta)
    xp_assert_close(dy, dz, atol=1e-10)


@make_xp_test_case(savgol_filter)
def test_sg_filter_valid_window_length_3d(xp):
    """Tests that the window_length check is using the correct axis."""

    x = xp.ones((10, 20, 30))

    savgol_filter(x, window_length=29, polyorder=3, mode='interp')

    with pytest.raises(ValueError, match='window_length must be less than'):
        # window_length is more than x.shape[-1].
        savgol_filter(x, window_length=31, polyorder=3, mode='interp')

    savgol_filter(x, window_length=9, polyorder=3, axis=0, mode='interp')

    with pytest.raises(ValueError, match='window_length must be less than'):
        # window_length is more than x.shape[0].
        savgol_filter(x, window_length=11, polyorder=3, axis=0, mode='interp')
