from __future__ import annotations

import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import IntEnum

from .._compat import DATACLASS_KWARGS
from .._hazmat import AeadAes128Gcm, Buffer, RangeSet

# Bit masks applied to the first byte of a packet.
# `is_long_header` tests this bit to tell long headers from short ones.
PACKET_LONG_HEADER = 0x80
# `pull_quic_header` rejects packets with this bit cleared, except for
# version negotiation packets.
PACKET_FIXED_BIT = 0x40
# `get_spin_bit` tests this bit.
PACKET_SPIN_BIT = 0x20

# Largest connection ID length, in bytes, accepted by `pull_quic_header`
# in long header packets.
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
# Fixed AES-128-GCM keys and nonces used by `get_retry_integrity_tag`;
# the VERSION_2 pair is selected for QUIC version 2, the VERSION_1 pair
# for every other version.
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
# Size, in bytes, of the integrity tag which ends a RETRY packet.
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16


class QuicErrorCode(IntEnum):
    """
    Named integer values for QUIC error codes.
    """

    NO_ERROR = 0x0
    INTERNAL_ERROR = 0x1
    CONNECTION_REFUSED = 0x2
    FLOW_CONTROL_ERROR = 0x3
    STREAM_LIMIT_ERROR = 0x4
    STREAM_STATE_ERROR = 0x5
    FINAL_SIZE_ERROR = 0x6
    FRAME_ENCODING_ERROR = 0x7
    TRANSPORT_PARAMETER_ERROR = 0x8
    CONNECTION_ID_LIMIT_ERROR = 0x9
    PROTOCOL_VIOLATION = 0xA
    INVALID_TOKEN = 0xB
    APPLICATION_ERROR = 0xC
    CRYPTO_BUFFER_EXCEEDED = 0xD
    KEY_UPDATE_ERROR = 0xE
    AEAD_LIMIT_REACHED = 0xF
    # Note the gap: no member is defined for 0x10.
    VERSION_NEGOTIATION_ERROR = 0x11
    # NOTE(review): presumably the base of a range of TLS-alert-derived
    # codes rather than a single value — confirm against RFC 9000.
    CRYPTO_ERROR = 0x100


class QuicPacketType(IntEnum):
    """
    Packet types, independent of the protocol version.

    These values are internal identifiers: the 2-bit codes carried in long
    headers differ between QUIC versions and are translated through the
    `PACKET_LONG_TYPE_*` tables of this module.
    """

    INITIAL = 0
    ZERO_RTT = 1
    HANDSHAKE = 2
    RETRY = 3
    VERSION_NEGOTIATION = 4
    ONE_RTT = 5


# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL

# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
#
# The ENCODE table maps a `QuicPacketType` to the 2-bit long packet type
# code; the DECODE table is its inverse (code -> `QuicPacketType`).
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}

# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
#
# Same layout as the version 1 tables, but with different codes for the
# same packet types.
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}


class QuicProtocolVersion(IntEnum):
    """
    Protocol version numbers, as carried in the 32-bit version field of
    long header packets.
    """

    # A version field of zero marks a version negotiation packet.
    NEGOTIATION = 0
    VERSION_1 = 0x00000001
    VERSION_2 = 0x6B3343CF


@dataclass(**DATACLASS_KWARGS)
class QuicHeader:
    """
    The fields extracted from the header of a QUIC packet.
    """

    version: int | None
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID. Only present in long header packets."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: list[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."


def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Compute the integrity tag which terminates a RETRY packet.

    The tag is produced by AES-128-GCM over a pseudo packet made of the
    length-prefixed original destination connection ID followed by the
    RETRY packet without its tag. The key and nonce are fixed constants
    selected by `version`.
    """
    odcid_length = len(original_destination_cid)

    # Assemble the Retry pseudo packet.
    pseudo_packet = Buffer(capacity=1 + odcid_length + len(packet_without_tag))
    pseudo_packet.push_uint8(odcid_length)
    pseudo_packet.push_bytes(original_destination_cid)
    pseudo_packet.push_bytes(packet_without_tag)
    assert pseudo_packet.eof()

    # Version 2 has its own secrets, any other version uses the version 1 ones.
    if version == QuicProtocolVersion.VERSION_2:
        key, nonce = RETRY_AEAD_KEY_VERSION_2, RETRY_AEAD_NONCE_VERSION_2
    else:
        key, nonce = RETRY_AEAD_KEY_VERSION_1, RETRY_AEAD_NONCE_VERSION_1

    # Run AES-128-GCM.
    # NOTE(review): presumably an empty plaintext with the pseudo packet as
    # associated data — confirm against the `AeadAes128Gcm` API.
    cipher = AeadAes128Gcm(key, b"null!12bytes")
    tag = cipher.encrypt_with_nonce(nonce, b"", pseudo_packet.data)
    assert len(tag) == RETRY_INTEGRITY_TAG_SIZE
    return tag


def get_spin_bit(first_byte: int) -> bool:
    """
    Tell whether the spin bit is set in the first byte of a packet.
    """
    return bool(first_byte & PACKET_SPIN_BIT)


def is_long_header(first_byte: int) -> bool:
    """
    Tell whether the first byte of a packet announces a long header.
    """
    return bool(first_byte & PACKET_LONG_HEADER)


def pretty_protocol_version(version: int) -> str:
    """
    Format a protocol version as zero-padded hexadecimal followed by its name.

    Versions which are not members of `QuicProtocolVersion` are labelled
    as `UNKNOWN`.
    """
    try:
        known_version = QuicProtocolVersion(version)
    except ValueError:
        label = "UNKNOWN"
    else:
        label = known_version.name
    return f"0x{version:08x} ({label})"


def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader:
    """
    Parse the header of a QUIC packet, starting at the current position of `buf`.

    :param buf: The buffer to read from. Its capacity is treated as the end
        of the received data.
    :param host_cid_length: The number of bytes to read as destination
        connection ID in short header packets, which carry no length for it.
        Only used for short header packets.
    :return: The parsed header. `packet_length` is measured from the position
        of `buf` on entry to the end of the packet.
    :raises ValueError: If a connection ID is too long, the fixed bit is
        zero, or the packet is shorter than its header announces.
    """
    packet_start = buf.tell()

    version: int | None

    integrity_tag = b""
    supported_versions = []
    token = b""

    first_byte = buf.pull_uint8()

    if is_long_header(first_byte):
        # Long Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
        version = buf.pull_uint32()

        destination_cid_length = buf.pull_uint8()
        if destination_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(
                f"Destination CID is too long ({destination_cid_length} bytes)"
            )
        destination_cid = buf.pull_bytes(destination_cid_length)

        source_cid_length = buf.pull_uint8()
        if source_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(f"Source CID is too long ({source_cid_length} bytes)")
        source_cid = buf.pull_bytes(source_cid_length)

        if version == QuicProtocolVersion.NEGOTIATION:
            # Version Negotiation Packet.
            # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
            #
            # The remainder of the buffer is a list of 32-bit versions, and
            # the fixed bit is deliberately not checked for this packet type.
            packet_type = QuicPacketType.VERSION_NEGOTIATION
            while not buf.eof():
                supported_versions.append(buf.pull_uint32())
            packet_end = buf.tell()
        else:
            if not (first_byte & PACKET_FIXED_BIT):
                raise ValueError("Packet fixed bit is zero")

            # The 2-bit packet type code is version-dependent; any version
            # other than version 2 is decoded with the version 1 table.
            if version == QuicProtocolVersion.VERSION_2:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
                    (first_byte & 0x30) >> 4
                ]
            else:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
                    (first_byte & 0x30) >> 4
                ]

            if packet_type == QuicPacketType.INITIAL:
                token_length = buf.pull_uint_var()
                token = buf.pull_bytes(token_length)
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.ZERO_RTT:
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.HANDSHAKE:
                rest_length = buf.pull_uint_var()
            else:
                # RETRY packets carry no length field: everything up to the
                # trailing integrity tag is the token.
                token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                if token_length < 0:
                    # Reject explicitly instead of handing a negative length
                    # to the buffer, so that a truncated packet is reported
                    # like every other malformed header.
                    raise ValueError(
                        "Retry packet is too short to contain an integrity tag"
                    )
                token = buf.pull_bytes(token_length)
                integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
                rest_length = 0

            # Check remainder length.
            packet_end = buf.tell() + rest_length

            if packet_end > buf.capacity:
                raise ValueError("Packet payload is truncated")

    else:
        # Short Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
        if not (first_byte & PACKET_FIXED_BIT):
            raise ValueError("Packet fixed bit is zero")

        version = None
        packet_type = QuicPacketType.ONE_RTT
        destination_cid = buf.pull_bytes(host_cid_length)
        source_cid = b""
        # A short header packet has no length field, it extends to the end
        # of the buffer.
        packet_end = buf.capacity

    return QuicHeader(
        version=version,
        packet_type=packet_type,
        packet_length=packet_end - packet_start,
        destination_cid=destination_cid,
        source_cid=source_cid,
        token=token,
        integrity_tag=integrity_tag,
        supported_versions=supported_versions,
    )


def encode_long_header_first_byte(
    version: int, packet_type: QuicPacketType, bits: int
) -> int:
    """
    Build the first byte of a long header packet.

    The result combines the long header bit, the fixed bit, the
    version-dependent packet type code (shifted into bits 4-5) and the
    caller-supplied `bits`.
    """
    type_codes = (
        PACKET_LONG_TYPE_ENCODE_VERSION_2
        if version == QuicProtocolVersion.VERSION_2
        else PACKET_LONG_TYPE_ENCODE_VERSION_1
    )
    type_bits = type_codes[packet_type] << 4
    return PACKET_LONG_HEADER | PACKET_FIXED_BIT | type_bits | bits


def encode_quic_retry(
    version: int,
    source_cid: bytes,
    destination_cid: bytes,
    original_destination_cid: bytes,
    retry_token: bytes,
    unused: int = 0,
) -> bytes:
    """
    Serialize a complete RETRY packet, including its integrity tag.

    `unused` is OR-ed into the first byte of the packet.
    """
    # 7 = first byte + 32-bit version + one length byte per connection ID.
    packet_size = (
        7
        + len(destination_cid)
        + len(source_cid)
        + len(retry_token)
        + RETRY_INTEGRITY_TAG_SIZE
    )
    packet = Buffer(capacity=packet_size)

    first_byte = encode_long_header_first_byte(version, QuicPacketType.RETRY, unused)
    packet.push_uint8(first_byte)
    packet.push_uint32(version)
    for cid in (destination_cid, source_cid):
        packet.push_uint8(len(cid))
        packet.push_bytes(cid)
    packet.push_bytes(retry_token)

    # The tag covers everything written so far.
    tag = get_retry_integrity_tag(
        packet.data, original_destination_cid, version=version
    )
    packet.push_bytes(tag)
    assert packet.eof()
    return packet.data


def encode_quic_version_negotiation(
    source_cid: bytes, destination_cid: bytes, supported_versions: list[int]
) -> bytes:
    """
    Serialize a VERSION_NEGOTIATION packet listing `supported_versions`.

    Apart from the long header bit, the bits of the first byte are random.
    """
    # 7 = first byte + 32-bit version + one length byte per connection ID.
    packet_size = (
        7 + len(destination_cid) + len(source_cid) + 4 * len(supported_versions)
    )
    packet = Buffer(capacity=packet_size)

    random_byte = os.urandom(1)[0]
    packet.push_uint8(random_byte | PACKET_LONG_HEADER)
    packet.push_uint32(QuicProtocolVersion.NEGOTIATION)
    for cid in (destination_cid, source_cid):
        packet.push_uint8(len(cid))
        packet.push_bytes(cid)
    for supported_version in supported_versions:
        packet.push_uint32(supported_version)
    return packet.data


# TLS EXTENSION


@dataclass(**DATACLASS_KWARGS)
class QuicPreferredAddress:
    """
    The value of the `preferred_address` transport parameter.
    """

    # (host, port) pairs; `None` is serialized as all-zero bytes and
    # all-zero host bytes are parsed back as `None`.
    ipv4_address: tuple[str, int] | None
    ipv6_address: tuple[str, int] | None
    connection_id: bytes
    # Read as a fixed 16 bytes by `pull_quic_preferred_address`.
    stateless_reset_token: bytes


@dataclass(**DATACLASS_KWARGS)
class QuicVersionInformation:
    """
    The value of the `version_information` transport parameter.
    """

    # Both fields are serialized as 32-bit integers, `chosen_version` first.
    chosen_version: int
    available_versions: list[int]


@dataclass()
class QuicTransportParameters:
    """
    The set of QUIC transport parameters.

    Field names match the attribute names listed in `PARAMS`. A value of
    `None` means the parameter is absent and it is not serialized by
    `push_quic_transport_parameters`.
    """

    original_destination_connection_id: bytes | None = None
    max_idle_timeout: int | None = None
    stateless_reset_token: bytes | None = None
    max_udp_payload_size: int | None = None
    initial_max_data: int | None = None
    initial_max_stream_data_bidi_local: int | None = None
    initial_max_stream_data_bidi_remote: int | None = None
    initial_max_stream_data_uni: int | None = None
    initial_max_streams_bidi: int | None = None
    initial_max_streams_uni: int | None = None
    ack_delay_exponent: int | None = None
    max_ack_delay: int | None = None
    # Unlike the other fields this one defaults to `False`, which is also
    # skipped when serializing; parsing sets it to `True` when present.
    disable_active_migration: bool | None = False
    preferred_address: QuicPreferredAddress | None = None
    active_connection_id_limit: int | None = None
    initial_source_connection_id: bytes | None = None
    retry_source_connection_id: bytes | None = None
    version_information: QuicVersionInformation | None = None
    max_datagram_frame_size: int | None = None
    quantum_readiness: bytes | None = None


# Maps a transport parameter identifier to the name of the matching
# `QuicTransportParameters` attribute and the type which selects how its
# value is parsed and serialized: `int` is a variable-length integer,
# `bytes` is the raw value, `bool` is a zero-length flag, and the two
# dataclasses have dedicated pull / push helpers.
PARAMS = {
    0x00: ("original_destination_connection_id", bytes),
    0x01: ("max_idle_timeout", int),
    0x02: ("stateless_reset_token", bytes),
    0x03: ("max_udp_payload_size", int),
    0x04: ("initial_max_data", int),
    0x05: ("initial_max_stream_data_bidi_local", int),
    0x06: ("initial_max_stream_data_bidi_remote", int),
    0x07: ("initial_max_stream_data_uni", int),
    0x08: ("initial_max_streams_bidi", int),
    0x09: ("initial_max_streams_uni", int),
    0x0A: ("ack_delay_exponent", int),
    0x0B: ("max_ack_delay", int),
    0x0C: ("disable_active_migration", bool),
    0x0D: ("preferred_address", QuicPreferredAddress),
    0x0E: ("active_connection_id_limit", int),
    0x0F: ("initial_source_connection_id", bytes),
    0x10: ("retry_source_connection_id", bytes),
    # https://datatracker.ietf.org/doc/html/rfc9368#section-3
    0x11: ("version_information", QuicVersionInformation),
    # extensions
    0x0020: ("max_datagram_frame_size", int),
    0x0C37: ("quantum_readiness", bytes),
}


def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
    """
    Parse the value of a `preferred_address` transport parameter.

    An address whose host bytes are all zero is reported as `None`; its
    port is still consumed from the buffer.
    """
    raw_ipv4 = buf.pull_bytes(4)
    port_ipv4 = buf.pull_uint16()
    if raw_ipv4 == bytes(4):
        ipv4 = None
    else:
        ipv4 = (str(ipaddress.IPv4Address(raw_ipv4)), port_ipv4)

    raw_ipv6 = buf.pull_bytes(16)
    port_ipv6 = buf.pull_uint16()
    if raw_ipv6 == bytes(16):
        ipv6 = None
    else:
        ipv6 = (str(ipaddress.IPv6Address(raw_ipv6)), port_ipv6)

    # The connection ID is prefixed by its length on one byte.
    cid = buf.pull_bytes(buf.pull_uint8())
    reset_token = buf.pull_bytes(16)

    return QuicPreferredAddress(
        ipv4_address=ipv4,
        ipv6_address=ipv6,
        connection_id=cid,
        stateless_reset_token=reset_token,
    )


def push_quic_preferred_address(
    buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
    """
    Serialize the value of a `preferred_address` transport parameter.

    A missing IPv4 or IPv6 address is written as zero bytes covering both
    the host and the port.
    """
    ipv4 = preferred_address.ipv4_address
    if ipv4 is None:
        # 4 bytes of host + 2 bytes of port.
        buf.push_bytes(bytes(6))
    else:
        buf.push_bytes(ipaddress.IPv4Address(ipv4[0]).packed)
        buf.push_uint16(ipv4[1])

    ipv6 = preferred_address.ipv6_address
    if ipv6 is None:
        # 16 bytes of host + 2 bytes of port.
        buf.push_bytes(bytes(18))
    else:
        buf.push_bytes(ipaddress.IPv6Address(ipv6[0]).packed)
        buf.push_uint16(ipv6[1])

    cid = preferred_address.connection_id
    buf.push_uint8(len(cid))
    buf.push_bytes(cid)
    buf.push_bytes(preferred_address.stateless_reset_token)


def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
    """
    Parse the value of a `version_information` transport parameter.

    `length` is the size of the value in bytes: the first 32-bit word is
    the chosen version and every following whole word is an available
    version.

    A `ValueError` is raised if any of the versions is zero.
    """
    chosen = buf.pull_uint32()
    available = [buf.pull_uint32() for _ in range(length // 4 - 1)]

    # If an endpoint receives a Chosen Version equal to zero, or any Available Version
    # equal to zero, it MUST treat it as a parsing failure.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if chosen == 0 or 0 in available:
        raise ValueError("Version Information must not contain version 0")

    return QuicVersionInformation(
        chosen_version=chosen,
        available_versions=available,
    )


def push_quic_version_information(
    buf: Buffer, version_information: QuicVersionInformation
) -> None:
    """
    Serialize the value of a `version_information` transport parameter.

    The chosen version is written first, followed by each available
    version, all as 32-bit integers.
    """
    buf.push_uint32(version_information.chosen_version)
    for available_version in version_information.available_versions:
        buf.push_uint32(available_version)


def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
    """
    Parse transport parameters until the end of `buf` is reached.

    Each parameter is an identifier and a length, both variable-length
    integers, followed by the value. Parameters which are not listed in
    `PARAMS` are skipped.

    A `ValueError` is raised if parsing a value does not consume exactly
    the announced number of bytes.
    """
    params = QuicTransportParameters()
    while not buf.eof():
        param_id = buf.pull_uint_var()
        param_len = buf.pull_uint_var()
        expected_end = buf.tell() + param_len

        known = PARAMS.get(param_id)
        if known is None:
            # skip unknown parameter
            buf.pull_bytes(param_len)
        else:
            # parse known parameter
            name, kind = known
            if kind is int:
                value = buf.pull_uint_var()
            elif kind is bytes:
                value = buf.pull_bytes(param_len)
            elif kind is QuicPreferredAddress:
                value = pull_quic_preferred_address(buf)
            elif kind is QuicVersionInformation:
                value = pull_quic_version_information(buf, param_len)
            else:
                # Flags carry no value, their presence alone means `True`.
                value = True
            setattr(params, name, value)

        if buf.tell() != expected_end:
            raise ValueError("Transport parameter length does not match")

    return params


def push_quic_transport_parameters(
    buf: Buffer, params: QuicTransportParameters
) -> None:
    """
    Serialize transport parameters, in the order in which `PARAMS` lists them.

    Parameters whose value is `None` or `False` are omitted. Each remaining
    parameter is written as its identifier and the length of its value, both
    variable-length integers, followed by the value itself.
    """
    for param_id, (name, kind) in PARAMS.items():
        value = getattr(params, name)
        if value is None or value is False:
            continue

        # Encode the value separately, as its length must be written first.
        scratch = Buffer(capacity=65536)
        if kind is int:
            scratch.push_uint_var(value)
        elif kind is bytes:
            scratch.push_bytes(value)
        elif kind is QuicPreferredAddress:
            push_quic_preferred_address(scratch, value)
        elif kind is QuicVersionInformation:
            push_quic_version_information(scratch, value)
        # Any other type is a flag: it has an empty value.

        buf.push_uint_var(param_id)
        buf.push_uint_var(scratch.tell())
        buf.push_bytes(scratch.data)


# FRAMES


class QuicFrameType(IntEnum):
    """
    Named integer values for QUIC frame types.
    """

    PADDING = 0x00
    PING = 0x01
    ACK = 0x02
    ACK_ECN = 0x03
    RESET_STREAM = 0x04
    STOP_SENDING = 0x05
    CRYPTO = 0x06
    NEW_TOKEN = 0x07
    # NOTE(review): no member is defined for 0x09-0x0F; presumably stream
    # frames use this value as a base with flags in the low bits — confirm
    # against the frame handling code.
    STREAM_BASE = 0x08
    MAX_DATA = 0x10
    MAX_STREAM_DATA = 0x11
    MAX_STREAMS_BIDI = 0x12
    MAX_STREAMS_UNI = 0x13
    DATA_BLOCKED = 0x14
    STREAM_DATA_BLOCKED = 0x15
    STREAMS_BLOCKED_BIDI = 0x16
    STREAMS_BLOCKED_UNI = 0x17
    NEW_CONNECTION_ID = 0x18
    RETIRE_CONNECTION_ID = 0x19
    PATH_CHALLENGE = 0x1A
    PATH_RESPONSE = 0x1B
    TRANSPORT_CLOSE = 0x1C
    APPLICATION_CLOSE = 0x1D
    HANDSHAKE_DONE = 0x1E
    DATAGRAM = 0x30
    DATAGRAM_WITH_LENGTH = 0x31


# Immutable groups of frame types. They are only defined here; how they
# are consumed is not visible in this module.

# Both ACK variants, PADDING and both CLOSE variants.
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.PADDING,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)
# Same as the group above, except that PADDING is not a member.
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

# Both PATH frames, PADDING and NEW_CONNECTION_ID.
PROBING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.PATH_CHALLENGE,
        QuicFrameType.PATH_RESPONSE,
        QuicFrameType.PADDING,
        QuicFrameType.NEW_CONNECTION_ID,
    ]
)


@dataclass(**DATACLASS_KWARGS)
class QuicResetStreamFrame:
    """
    The fields of a RESET_STREAM frame.
    """

    # All three fields are required, none has a default value.
    error_code: int
    final_size: int
    stream_id: int


@dataclass(**DATACLASS_KWARGS)
class QuicStopSendingFrame:
    """
    The fields of a STOP_SENDING frame.
    """

    # Both fields are required, none has a default value.
    error_code: int
    stream_id: int


@dataclass(**DATACLASS_KWARGS)
class QuicStreamFrame:
    """
    The fields of a STREAM frame.

    By default a frame is empty, is not final and starts at offset zero.
    """

    data: bytes = b""
    # NOTE(review): presumably marks the last frame of a stream — confirm
    # against the stream handling code.
    fin: bool = False
    offset: int = 0


def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]:
    """
    Parse the body of an ACK frame.

    The ranges are listed from the highest packet numbers down to the
    lowest: each one is described by its length, and all but the first
    are preceded by the gap which separates them from the previous range.

    Returns the acknowledged packet number ranges, with exclusive upper
    bounds, and the raw delay value.
    """
    acked = RangeSet()

    largest = buf.pull_uint_var()  # largest acknowledged
    delay = buf.pull_uint_var()
    extra_range_count = buf.pull_uint_var()

    span = buf.pull_uint_var()  # first ack range
    smallest = largest - span
    acked.add(smallest, largest + 1)

    for _ in range(extra_range_count):
        gap = buf.pull_uint_var()
        # The next range ends `gap + 2` below the start of the previous one.
        largest = smallest - (gap + 2)
        span = buf.pull_uint_var()
        smallest = largest - span
        acked.add(smallest, largest + 1)

    return acked, delay


def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
    """
    Serialize the body of an ACK frame.

    Ranges are written starting from the last one in `rangeset` and working
    backwards; each is read as a `(start, stop)` pair whose `stop` is
    exclusive.

    Returns the number of ranges which were written.
    """
    range_count = len(rangeset)
    last_index = range_count - 1

    newest = rangeset[last_index]
    buf.push_uint_var(newest[1] - 1)  # largest acknowledged
    buf.push_uint_var(delay)
    buf.push_uint_var(last_index)  # number of additional ranges
    buf.push_uint_var(newest[1] - 1 - newest[0])  # first ack range

    previous_start = newest[0]
    for position in range(last_index - 1, -1, -1):
        current = rangeset[position]
        buf.push_uint_var(previous_start - current[1] - 1)  # gap
        buf.push_uint_var(current[1] - current[0] - 1)  # range length
        previous_start = current[0]

    return range_count
