import pytest

import numpy as np
from numpy.testing import assert_allclose

from scipy.integrate import quad_vec
from scipy._lib._array_api import make_xp_test_case

from multiprocessing.dummy import Pool


# Shared parametrization: run a test once per quadrature rule, including
# ``None`` (i.e. no rule named explicitly).
_QUADRATURE_RULES = [None, "gk15", "gk21", "trapezoid"]
quadrature_params = pytest.mark.parametrize('quadrature', _QUADRATURE_RULES)

def _lorenzian(x):
    """Lorentzian profile ``1 / (1 + x**2)``; its integral over the real line is pi."""
    denominator = 1 + x**2
    return 1 / denominator

def _func_with_args(x, a):
    """Vector integrand ``x * (x + a) * [0, 1, 2]`` taking one extra argument.

    With ``a = 2`` its integral over [0, 1] is ``[0, 4/3, 8/3]``.
    """
    scale = np.arange(3)
    quadratic = x * (x + a)
    return quadratic * scale


@make_xp_test_case(quad_vec)
class TestQuadVec:
    """Tests for `scipy.integrate.quad_vec`."""

    @quadrature_params
    def test_quad_vec_simple(self, quadrature):
        """Integrate the monomials x**0 .. x**9 over [0, 2] as one vector.

        Results are compared with the closed form 2**(n+1) / (n+1) using
        the 'max' and '2' norms, explicit breakpoints, and full_output.
        """
        n = np.arange(10)

        def f(x):
            return x ** n

        # The reference values do not depend on the tolerance; compute once.
        exact = 2**(n+1)/(n + 1)

        for epsabs in [0.1, 1e-3, 1e-6]:
            if quadrature == 'trapezoid' and epsabs < 1e-4:
                # slow: skip
                continue

            kwargs = dict(epsabs=epsabs, quadrature=quadrature)

            res, err = quad_vec(f, 0, 2, norm='max', **kwargs)
            assert_allclose(res, exact, rtol=0, atol=epsabs)

            # With norm='2' the check is on the Euclidean norm of the whole
            # error vector rather than on each component.
            res, err = quad_vec(f, 0, 2, norm='2', **kwargs)
            assert np.linalg.norm(res - exact) < epsabs

            res, err = quad_vec(f, 0, 2, norm='max', points=(0.5, 1.0), **kwargs)
            assert_allclose(res, exact, rtol=0, atol=epsabs)

            res, err, *rest = quad_vec(f, 0, 2, norm='max',
                                       epsrel=1e-8,
                                       full_output=True,
                                       limit=10000,
                                       **kwargs)
            assert_allclose(res, exact, rtol=0, atol=epsabs)

    @quadrature_params
    def test_quad_vec_simple_inf(self, quadrature):
        """Infinite and reversed limits, checked with 1 / (1 + x**2).

        Also checks an oscillatory integrand over the whole real line.
        """
        def f(x):
            return 1 / (1 + np.float64(x) ** 2)

        for epsabs in [0.1, 1e-3, 1e-6]:
            if quadrature == 'trapezoid' and epsabs < 1e-4:
                # slow: skip
                continue

            kwargs = dict(norm='max', epsabs=epsabs, quadrature=quadrature)

            # The tolerance is widened to the reported error estimate when
            # that estimate exceeds epsabs.
            res, err = quad_vec(f, 0, np.inf, **kwargs)
            assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, 0, -np.inf, **kwargs)
            assert_allclose(res, -np.pi/2, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, -np.inf, 0, **kwargs)
            assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, np.inf, 0, **kwargs)
            assert_allclose(res, -np.pi/2, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, -np.inf, np.inf, **kwargs)
            assert_allclose(res, np.pi, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, np.inf, -np.inf, **kwargs)
            assert_allclose(res, -np.pi, rtol=0, atol=max(epsabs, err))

            # Degenerate intervals with equal infinite endpoints integrate to 0.
            res, err = quad_vec(f, np.inf, np.inf, **kwargs)
            assert_allclose(res, 0, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, -np.inf, -np.inf, **kwargs)
            assert_allclose(res, 0, rtol=0, atol=max(epsabs, err))

            res, err = quad_vec(f, 0, np.inf, points=(1.0, 2.0), **kwargs)
            assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))

        def f(x):
            return np.sin(x + 2) / (1 + x ** 2)
        # sin(x + 2) = sin(2) cos(x) + cos(2) sin(x); the odd part integrates
        # to 0 and int cos(x) / (1 + x**2) dx = pi / e.
        exact = np.pi / np.e * np.sin(2)
        epsabs = 1e-5

        res, err, info = quad_vec(f, -np.inf, np.inf, limit=1000, norm='max',
                                  epsabs=epsabs, quadrature=quadrature,
                                  full_output=True)
        # NOTE(review): status 1 presumably means the target precision was
        # not reached within `limit`; the value is still checked against a
        # slightly widened reported error estimate.
        assert info.status == 1
        assert_allclose(res, exact, rtol=0, atol=max(epsabs, 1.5 * err))

    def test_quad_vec_args(self):
        """Extra positional arguments are forwarded to the integrand via `args`."""
        # Reuse the module-level integrand (also used by
        # test_quad_vec_pool_args) rather than an identical local copy.
        f = _func_with_args
        a = 2
        # int_0^1 x * (x + 2) dx = 4/3, scaled by [0, 1, 2].
        exact = np.array([0, 4/3, 8/3])

        res, err = quad_vec(f, 0, 1, args=(a,))
        assert_allclose(res, exact, rtol=0, atol=1e-4)

    @pytest.mark.fail_slow(10)
    def test_quad_vec_pool(self):
        """Parallel evaluation via an integer `workers` and via a map callable."""
        # Module-level integrand for the integer-`workers` path (presumably
        # because that path may need to pickle the callable -- confirm).
        f = _lorenzian
        res, err = quad_vec(f, -np.inf, np.inf, norm='max', epsabs=1e-4, workers=4)
        assert_allclose(res, np.pi, rtol=0, atol=1e-4)

        # multiprocessing.dummy.Pool is thread-backed, so a locally defined
        # function can be passed through its ``map`` without pickling.
        with Pool(10) as pool:
            def f(x):
                return 1 / (1 + x ** 2)
            res, _ = quad_vec(f, -np.inf, np.inf, norm='max', epsabs=1e-4,
                              workers=pool.map)
            assert_allclose(res, np.pi, rtol=0, atol=1e-4)

    @pytest.mark.fail_slow(10)
    @pytest.mark.parametrize('extra_args', [2, (2,)])
    @pytest.mark.parametrize(
        'workers',
        [1, pytest.param(10, marks=pytest.mark.parallel_threads_limit(4))]
    )
    def test_quad_vec_pool_args(self, extra_args, workers):
        """`args` (bare scalar or tuple) combined with parallel workers."""
        f = _func_with_args
        exact = np.array([0, 4/3, 8/3])

        res, err = quad_vec(f, 0, 1, args=extra_args, workers=workers)
        assert_allclose(res, exact, rtol=0, atol=1e-4)

        with Pool(workers) as pool:
            res, err = quad_vec(f, 0, 1, args=extra_args, workers=pool.map)
            assert_allclose(res, exact, rtol=0, atol=1e-4)

    @quadrature_params
    def test_num_eval(self, quadrature):
        """`info.neval` equals the number of integrand calls actually made."""
        count = [0]

        def f(x):
            count[0] += 1
            return x**5

        res = quad_vec(f, 0, 1, norm='max', full_output=True, quadrature=quadrature)
        assert res[2].neval == count[0]

    def test_info(self):
        """Fields and shapes of the full_output info object on success."""
        def f(x):
            return np.ones((3, 2, 1))

        res, err, info = quad_vec(f, 0, 1, norm='max', full_output=True)

        assert info.success is True
        assert info.status == 0
        assert info.message == 'Target precision reached.'
        assert info.neval > 0
        # One (start, end) row per interval; per-interval integrals keep the
        # integrand's output shape.
        assert info.intervals.shape[1] == 2
        assert info.integrals.shape == (info.intervals.shape[0], 3, 2, 1)
        assert info.errors.shape == (info.intervals.shape[0],)

    def test_nan_inf(self):
        """Non-finite integrand values are reported with status 3."""
        def f_nan(x):
            return np.nan

        def f_inf(x):
            return np.inf if x < 0.1 else 1/x

        res, err, info = quad_vec(f_nan, 0, 1, full_output=True)
        assert info.status == 3

        res, err, info = quad_vec(f_inf, 0, 1, full_output=True)
        assert info.status == 3

    @pytest.mark.parametrize('a,b', [(0, 1), (0, np.inf), (np.inf, 0),
                                     (-np.inf, np.inf), (np.inf, -np.inf)])
    def test_points(self, a, b):
        # Check that initial interval splitting is done according to
        # `points`, by checking that each consecutive set of 15 function
        # evaluations (one gk15 rule application) lies between adjacent
        # `points`.

        points = (0, 0.25, 0.5, 0.75, 1.0)
        points += tuple(-x for x in points)
        # Sorted once, for the breakpoint lookups below.
        breakpoints = sorted(points)

        quadrature_points = 15
        interval_sets = []
        count = 0

        def f(x):
            nonlocal count

            # Start a new group every `quadrature_points` evaluations.
            if count % quadrature_points == 0:
                interval_sets.append(set())

            count += 1
            interval_sets[-1].add(float(x))
            return 0.0

        # NOTE(review): limit=0 presumably prevents subdivision beyond the
        # initial split, so only the `points` intervals are evaluated.
        quad_vec(f, a, b, points=points, quadrature='gk15', limit=0)

        # Guard against a vacuous pass if the integrand were never called.
        assert interval_sets

        # Check that all point sets lie in a single `points` interval
        for p in interval_sets:
            j = np.searchsorted(breakpoints, tuple(p))
            assert np.all(j == j[0])
