''' Tests for fortran sequential files '''

import tempfile
import shutil
import os
from os import path
from glob import iglob
import threading
import re

from numpy.testing import assert_equal, assert_allclose
import numpy as np
import pytest

from scipy.io import (FortranFile,
                      FortranEOFError,
                      FortranFormattingError)


DATA_PATH = path.join(path.dirname(__file__), 'data')


@pytest.fixture
def io_lock():
    return threading.Lock()


def test_fortranfiles_read(io_lock):
    for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
        m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
        if not m:
            raise RuntimeError(f"Couldn't match {filename} filename to regex")

        dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))

        dtype = m.group(1).replace('s', '<')

        with io_lock:
            f = FortranFile(filename, 'r', '<u4')
            data = f.read_record(dtype=dtype).reshape(dims, order='F')
            f.close()

        expected = np.arange(np.prod(dims)).reshape(dims).astype(dtype)
        assert_equal(data, expected)


def test_fortranfiles_mixed_record(io_lock):
    filename = path.join(DATA_PATH, "fortran-mixed.dat")
    with io_lock:
        with FortranFile(filename, 'r', '<u4') as f:
            record = f.read_record('<i4,<f4,<i8,2<f8')

    assert_equal(record['f0'][0], 1)
    assert_allclose(record['f1'][0], 2.3)
    assert_equal(record['f2'][0], 4)
    assert_allclose(record['f3'][0], [5.6, 7.8])


def test_fortranfiles_write():
    for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
        m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
        if not m:
            raise RuntimeError(f"Couldn't match {filename} filename to regex")
        dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))

        dtype = m.group(1).replace('s', '<')
        data = np.arange(np.prod(dims)).reshape(dims).astype(dtype)

        tmpdir = tempfile.mkdtemp()
        try:
            testFile = path.join(str(threading.get_native_id()),
                                 tmpdir,path.basename(filename))
            f = FortranFile(testFile, 'w','<u4')
            f.write_record(data.T)
            f.close()
            originalfile = open(filename, 'rb')
            newfile = open(testFile, 'rb')
            assert_equal(originalfile.read(), newfile.read(),
                         err_msg=filename)
            originalfile.close()
            newfile.close()
        finally:
            shutil.rmtree(tmpdir)


def test_fortranfile_read_mixed_record(io_lock):
    # The data file fortran-3x3d-2i.dat contains the program that
    # produced it at the end.
    #
    # double precision :: a(3,3)
    # integer :: b(2)
    # ...
    # open(1, file='fortran-3x3d-2i.dat', form='unformatted')
    # write(1) a, b
    # close(1)
    #

    filename = path.join(DATA_PATH, "fortran-3x3d-2i.dat")
    with io_lock:
        with FortranFile(filename, 'r', '<u4') as f:
            record = f.read_record('(3,3)<f8', '2<i4')

    ax = np.arange(3*3).reshape(3, 3).astype(np.float64)
    bx = np.array([-1, -2], dtype=np.int32)

    assert_equal(record[0], ax.T)
    assert_equal(record[1], bx.T)


def test_fortranfile_write_mixed_record(tmpdir):
    tf = path.join(str(tmpdir), str(threading.get_native_id()), 'test.dat')
    os.makedirs(path.dirname(tf), exist_ok=True)

    r1 = (('f4', 'f4', 'i4'), (np.float32(2), np.float32(3), np.int32(100)))
    r2 = (('4f4', '(3,3)f4', '8i4'),
          (np.random.randint(255, size=[4]).astype(np.float32),
           np.random.randint(255, size=[3, 3]).astype(np.float32),
           np.random.randint(255, size=[8]).astype(np.int32)))
    records = [r1, r2]

    for dtype, a in records:
        with FortranFile(tf, 'w') as f:
            f.write_record(*a)

        with FortranFile(tf, 'r') as f:
            b = f.read_record(*dtype)

        assert_equal(len(a), len(b))

        for aa, bb in zip(a, b):
            assert_equal(bb, aa)

def read_unformatted_double(m, n, k, filename):
    """
    Read a Fortran-style unformatted binary file written with a single write() call,
    assuming it wraps the data with 4-byte record markers.

    Returns:
        np.ndarray of shape (m, n, k) with dtype float64

    Reference:
        Fortran implementation:
        https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L1-L9
    """
    with open(filename.strip(), 'rb') as f:
        f.read(4)  # Skip initial 4-byte record marker
        data = np.fromfile(f, dtype=np.float64, count=m * n * k)
        f.read(4)  # Skip trailing 4-byte record marker

    if data.size != m * n * k:
        raise ValueError(f"Expected {m*n*k} elements, got {data.size}")

    return data.reshape((m, n, k), order='F')  # Fortran column-major order

def read_unformatted_mixed(m, n, k, filename):
    """
    Read a Fortran unformatted binary file that contains a mix of:
    - a double precision array a(m, n)
    - an integer array b(k)

    Assumes a single write(10) a, b was used and file is wrapped
    with Fortran record markers.

    Returns:
        a: np.ndarray of shape (m, n) with dtype float64
        b: np.ndarray of shape (k,) with dtype int32

    Reference:
        Fortran implementation:
        https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L21-L30
    """
    with open(filename.strip(), 'rb') as f:
        f.read(4)  # Skip initial 4-byte record marker

        # Read a(m,n): total m*n float64 values
        a_flat = np.fromfile(f, dtype=np.float64, count=m * n)

        # Read b(k): total k int32 values (assuming Fortran default integer*4)
        b = np.fromfile(f, dtype=np.int32, count=k)

        f.read(4)  # Skip trailing 4-byte record marker

    # Reshape a to (m,n) Fortran-style
    a = a_flat.reshape((m, n), order='F')

    return a, b

def read_unformatted_int(m, n, k, filename):
    """
    Read a Fortran unformatted binary file
    containing a 3D integer array (m, n, k).
    Assumes the array is written with a single
    write(10) a and wrapped with record markers.

    Returns:
        np.ndarray: 3D array of shape (m, n, k) with dtype int32

    Reference:
        Fortran implementation:
        https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L11-L19
    """
    with open(filename.strip(), 'rb') as f:
        f.read(4)  # Skip Fortran record marker at start

        # Read m*n*k integers (Fortran default = 4 bytes per integer)
        data = np.fromfile(f, dtype=np.int32, count=m * n * k)

        f.read(4)  # Skip Fortran record marker at end

    if data.size != m * n * k:
        raise ValueError(f"Expected {m*n*k} elements, got {data.size}")

    return data.reshape((m, n, k), order='F')  # Fortran-style column-major order

def test_fortran_roundtrip(tmpdir, io_lock):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         'test.dat')
    os.makedirs(path.dirname(filename), exist_ok=True)

    rng = np.random.RandomState(1)

    # double precision
    m, n, k = 5, 3, 2
    a = rng.randn(m, n, k)
    with FortranFile(filename, 'w') as f:
        f.write_record(a.T)
    with io_lock:
        a2 = read_unformatted_double(m, n, k, filename)

    with FortranFile(filename, 'r') as f:
        a3 = f.read_record('(2,3,5)f8').T
    assert_equal(a2, a)
    assert_equal(a3, a)

    # integer
    m, n, k = 5, 3, 2
    a = rng.randn(m, n, k).astype(np.int32)
    with FortranFile(filename, 'w') as f:
        f.write_record(a.T)
    with io_lock:
        a2 = read_unformatted_int(m, n, k, filename)
    with FortranFile(filename, 'r') as f:
        a3 = f.read_record('(2,3,5)i4').T
    assert_equal(a2, a)
    assert_equal(a3, a)

    # mixed
    m, n, k = 5, 3, 2
    a = rng.randn(m, n)
    b = rng.randn(k).astype(np.intc)
    with FortranFile(filename, 'w') as f:
        f.write_record(a.T, b.T)
    with io_lock:
        a2, b2 = read_unformatted_mixed(m, n, k, filename)
    with FortranFile(filename, 'r') as f:
        a3, b3 = f.read_record('(3,5)f8', '2i4')
        a3 = a3.T
    assert_equal(a2, a)
    assert_equal(a3, a)
    assert_equal(b2, b)
    assert_equal(b3, b)


def test_fortran_eof_ok(tmpdir):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         "scratch")
    os.makedirs(path.dirname(filename), exist_ok=True)
    rng = np.random.RandomState(1)
    with FortranFile(filename, 'w') as f:
        f.write_record(rng.randn(5))
        f.write_record(rng.randn(3))
    with FortranFile(filename, 'r') as f:
        assert len(f.read_reals()) == 5
        assert len(f.read_reals()) == 3
        with pytest.raises(FortranEOFError):
            f.read_reals()


def test_fortran_eof_broken_size(tmpdir):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         "scratch")
    os.makedirs(path.dirname(filename), exist_ok=True)
    rng = np.random.RandomState(1)
    with FortranFile(filename, 'w') as f:
        f.write_record(rng.randn(5))
        f.write_record(rng.randn(3))
    with open(filename, "ab") as f:
        f.write(b"\xff")
    with FortranFile(filename, 'r') as f:
        assert len(f.read_reals()) == 5
        assert len(f.read_reals()) == 3
        with pytest.raises(FortranFormattingError):
            f.read_reals()


def test_fortran_bogus_size(tmpdir):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         "scratch")
    os.makedirs(path.dirname(filename), exist_ok=True)
    rng = np.random.RandomState(1)
    with FortranFile(filename, 'w') as f:
        f.write_record(rng.randn(5))
        f.write_record(rng.randn(3))
    with open(filename, "w+b") as f:
        f.write(b"\xff\xff")
    with FortranFile(filename, 'r') as f:
        with pytest.raises(FortranFormattingError):
            f.read_reals()


def test_fortran_eof_broken_record(tmpdir):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         "scratch")
    os.makedirs(path.dirname(filename), exist_ok=True)
    rng = np.random.RandomState(1)
    with FortranFile(filename, 'w') as f:
        f.write_record(rng.randn(5))
        f.write_record(rng.randn(3))
    with open(filename, "ab") as f:
        f.truncate(path.getsize(filename)-20)
    with FortranFile(filename, 'r') as f:
        assert len(f.read_reals()) == 5
        with pytest.raises(FortranFormattingError):
            f.read_reals()


def test_fortran_eof_multidimensional(tmpdir):
    filename = path.join(str(tmpdir), str(threading.get_native_id()),
                         "scratch")
    os.makedirs(path.dirname(filename), exist_ok=True)
    n, m, q = 3, 5, 7
    dt = np.dtype([("field", np.float64, (n, m))])
    a = np.zeros(q, dtype=dt)
    with FortranFile(filename, 'w') as f:
        f.write_record(a[0])
        f.write_record(a)
        f.write_record(a)
    with open(filename, "ab") as f:
        f.truncate(path.getsize(filename)-20)
    with FortranFile(filename, 'r') as f:
        assert len(f.read_record(dtype=dt)) == 1
        assert len(f.read_record(dtype=dt)) == q
        with pytest.raises(FortranFormattingError):
            f.read_record(dtype=dt)
