# -*- coding: utf-8 -*-

# This file is part of the desktop management solution opsi http://www.opsi.org
# Copyright (c) 2023-2024 uib GmbH <info@uib.de>
# This code is owned by the uib GmbH, Mainz, Germany (uib.de). All rights reserved.
# License: AGPL-3.0

from __future__ import annotations

from enum import Enum, StrEnum
from ipaddress import IPv4Address
from random import randint
from typing import Any, Generator
from uuid import UUID

import scapy.layers.dhcp  # type: ignore[import]
from opsicommon.logging import get_logger
from pydantic_extra_types.mac_address import MacAddress
from scapy.layers.dhcp import BOOTP, DHCP  # type: ignore[import]
from scapy.layers.dhcp6 import DUID_EN, DUID_LL, DUID_LLT, DUID_UUID  # type: ignore[import]
from scapy.layers.l2 import Ether  # type: ignore[import]
from scapy.packet import Packet  # type: ignore[import]

logger = get_logger("opsiagent.plugin.boot_server.dhcp")

# Make scapy handle these DHCP options by name. No scapy field type is attached (``None``),
# so the option values are handled as raw bytes and decoded/encoded by this module.
scapy.layers.dhcp.DHCPOptions.update({175: "etherboot", 119: "domain_search"})
scapy.layers.dhcp.DHCPRevOptions.update({"etherboot": (175, None), "domain_search": (119, None)})


class DHCPOperationType(Enum):
	"""BOOTP operation code, stored in the ``op`` header field of a DHCPv4 packet."""

	BOOT_REQUEST = 1  # client -> server
	BOOT_REPLY = 2  # server -> client


class DHCPMessageType(Enum):
	"""
	Value of the DHCP message type option (option 53, ``message-type`` in scapy).

	Values 10-17 belong to the lease query extensions; 18 is the TLS message type.
	"""

	DISCOVER = 1
	OFFER = 2
	REQUEST = 3
	DECLINE = 4
	ACK = 5
	NAK = 6
	RELEASE = 7
	INFORM = 8
	FORCE_RENEW = 9
	LEASE_QUERY = 10
	LEASE_UNASSIGNED = 11
	LEASE_UNKNOWN = 12
	LEASE_ACTIVE = 13
	BULK_LEASE_QUERY = 14
	LEASE_QUERY_DONE = 15
	ACTIVE_LEASE_QUERY = 16
	LEASE_QUERY_STATUS = 17
	TLS = 18


class DUIDType(Enum):
	"""DHCP Unique Identifier type, encoded in the first two bytes (big endian) of a DUID."""

	LLT = 1  # link-layer address plus time
	EN = 2  # enterprise number plus vendor-assigned identifier
	LL = 3  # link-layer address
	UUID = 4  # UUID


class PXEProtocol(StrEnum):
	"""Protocol a network boot client uses to fetch its boot file."""

	TFTP = "TFTP"
	HTTP = "HTTP"


class BIOSType(StrEnum):
	"""Firmware type of a network boot client."""

	BIOS = "BIOS"  # legacy BIOS (or any non-EFI firmware)
	UEFI = "UEFI"


class Architecture(StrEnum):
	"""Processor architecture of a network boot client."""

	X86 = "x86"  # 32 bit x86
	X64 = "x64"  # 64 bit x86
	ARM = "arm"  # 32 bit ARM
	ARM64 = "arm64"  # 64 bit ARM


# https://www.iana.org/assignments/dhcpv6-parameters/dhcpv6-parameters.xhtml#processor-architecture
class DHCPClientSystemArchitecture(Enum):
	"""
	Client system architecture type.

	Sent by network boot clients in DHCP option 93 (``pxe_client_architecture``) or in
	the ``Arch`` part of the vendor class identifier (DHCP option 60).
	"""

	X86_BIOS = 0x0000
	NEC_PC98 = 0x0001
	ITANIUM = 0x0002
	DEC_ALPHA = 0x0003
	ARC_X86 = 0x0004
	INTEL_LEAN_CLIENT = 0x0005
	X86_UEFI = 0x0006
	X64_UEFI = 0x0007
	EFI_XSCALE = 0x0008
	EBC = 0x0009  # EBC = EFI BC = EFI Byte Code
	ARM_32_BIT_UEFI = 0x000A
	ARM_64_BIT_UEFI = 0x000B
	POWERPC_OPEN_FIRMWARE = 0x000C
	POWERPC_EPAPR = 0x000D
	POWER_OPAL_V3 = 0x000E
	X86_UEFI_BOOT_FROM_HTTP = 0x000F
	X64_UEFI_BOOT_FROM_HTTP = 0x0010
	EBC_BOOT_FROM_HTTP = 0x0011
	ARM_UEFI_32_BOOT_FROM_HTTP = 0x0012
	ARM_UEFI_64_BOOT_FROM_HTTP = 0x0013
	PC_AT_BIOS_BOOT_FROM_HTTP = 0x0014
	ARM_32_UBOOT = 0x0015
	ARM_64_UBOOT = 0x0016
	ARM_UBOOT_32_BOOT_FROM_HTTP = 0x0017
	ARM_UBOOT_64_BOOT_FROM_HTTP = 0x0018
	RISC_V_32_BIT_UEFI = 0x0019
	RISC_V_32_BIT_UEFI_BOOT_FROM_HTTP = 0x001A
	RISC_V_64_BIT_UEFI = 0x001B
	RISC_V_64_BIT_UEFI_BOOT_FROM_HTTP = 0x001C
	RISC_V_128_BIT_UEFI = 0x001D
	RISC_V_128_BIT_UEFI_BOOT_FROM_HTTP = 0x001E
	S390_BASIC = 0x001F
	S390_EXTENDED = 0x0020
	MIPS_32_BIT_UEFI = 0x0021
	MIPS_64_BIT_UEFI = 0x0022
	SUNWAY_32_BIT_UEFI = 0x0023
	SUNWAY_64_BIT_UEFI = 0x0024
	LOONGARCH_32_BIT_UEFI = 0x0025
	LOONGARCH_32_BIT_UEFI_BOOT_FROM_HTTP = 0x0026
	LOONGARCH_64_BIT_UEFI = 0x0027
	LOONGARCH_64_BIT_UEFI_BOOT_FROM_HTTP = 0x0028
	ARM_RPIBOOT = 0x0029

	@classmethod
	def from_vendor_class_id(cls, vendor_class_id: str) -> DHCPClientSystemArchitecture:
		"""
		Get the client architecture from a vendor class identifier.

		:param vendor_class_id: Identifier like ``PXEClient:Arch:00007:UNDI:003016``.
		:returns: The member for the decimal ``Arch`` number.
		:raises ValueError: If the identifier does not have the expected format or the
			architecture number is unknown.
		"""
		# https://datatracker.ietf.org/doc/html/rfc5970
		# https://datatracker.ietf.org/doc/html/rfc4578#section-2.1
		# https://datatracker.ietf.org/doc/html/rfc4578#section-2.2
		# https://www.iana.org/assignments/dhcpv6-parameters/dhcpv6-parameters.xhtml#processor-architecture
		# <PXEClient|HTTPClient>:Arch:xxxxx:UNDI:yyyzzz
		# xxxxx = client architecture 0 - 65535
		# yyy = UNDI Major version 0 - 255
		# zzz = UNDI Minor version 0 - 255
		parts = vendor_class_id.split(":")
		# Explicit checks instead of ``assert``, which is stripped when running with ``-O``.
		if len(parts) < 3 or parts[0] not in ("PXEClient", "HTTPClient") or parts[1] != "Arch":
			raise ValueError(f"Invalid vendor class ID {vendor_class_id!r}")
		# int() raises ValueError for a non-numeric field, cls() for an unknown number
		return cls(int(parts[2]))

	@property
	def architecture(self) -> Architecture:
		"""
		Processor architecture of the client.

		:raises ValueError: If the type cannot be mapped to an ``Architecture``.
		"""
		# NOTE(review): these explicit mappings deviate from the IANA names (e.g. ITANIUM -> x86,
		# EFI_XSCALE/EBC -> x64). They look like workarounds for firmware reporting non-standard
		# values (RFC 4578 originally listed 0x0009 as "EFI x86-64") - confirm before changing.
		if self.value in (0x0002, 0x0006):
			return Architecture.X86
		if self.value in (0x0007, 0x0008, 0x0009, 0x0011):
			return Architecture.X64
		if "X86" in self.name:
			return Architecture.X86
		if "X64" in self.name:
			return Architecture.X64
		if "ARM" in self.name:
			if "32" in self.name:
				return Architecture.ARM
			if "64" in self.name:
				return Architecture.ARM64
		raise ValueError(f"Unhandled architecture {self.name}")

	@property
	def bios_type(self) -> BIOSType:
		"""Firmware type of the client: UEFI for the explicitly listed values and all EFI types, BIOS otherwise."""
		if self.value in (0x0002, 0x0006, 0x0007, 0x0008, 0x0009, 0x0011):
			return BIOSType.UEFI
		if "EFI" in self.name:
			return BIOSType.UEFI
		return BIOSType.BIOS

	@property
	def protocol(self) -> PXEProtocol:
		"""Boot file transfer protocol: HTTP for the "boot from HTTP" types, TFTP otherwise."""
		if "HTTP" in self.name:
			return PXEProtocol.HTTP
		return PXEProtocol.TFTP

	def _architecture_name(self) -> str:
		"""Return the architecture value for display, ``"unknown"`` if it cannot be mapped."""
		try:
			return self.architecture.value
		except ValueError:
			# __str__ / __repr__ must not raise, they are used for logging and debugging
			return "unknown"

	def __str__(self) -> str:
		return f"{self.value:05} ({self._architecture_name()} {self.bios_type.value} {self.protocol.value})"

	def __repr__(self) -> str:
		return (
			f"{self.__class__.__name__}(value={self.value:05}, architecture={self._architecture_name()}, "
			f"bios_type={self.bios_type}, protocol={self.protocol})"
		)


class ClientId:
	"""
	Base class for DHCPv4 client identifiers (``client_id`` option).

	The first byte of the option value is a type byte that selects the concrete subclass.
	"""

	@classmethod
	def from_bytes(cls, data: bytes) -> ClientId:
		"""
		Parse a raw client identifier option value.

		:param data: Option value, starting with the type byte.
		:returns: A ``MacClientId`` (type 1) or a ``DUIDClientId`` subclass (type 255).
		:raises ValueError: If ``data`` is empty or the type byte is not supported.
		"""
		if not data:
			# Without this check an empty option would surface as an IndexError
			raise ValueError("Empty client identifier")
		option_hardware_type = data[0]
		if option_hardware_type == 1:
			return MacClientId.from_bytes(data)
		if option_hardware_type == 255:
			return DUIDClientId.from_bytes(data)
		raise ValueError(f"Unhandled hardware type {option_hardware_type}")

	def to_bytes(self) -> bytes:
		"""Encode the identifier as option value; implemented by subclasses."""
		raise NotImplementedError()


class MacClientId(ClientId):
	"""Client identifier of type 1 that carries a MAC address."""

	option_hardware_type = 1

	def __init__(self, mac_address: MacAddress | str) -> None:
		# MacAddress instances are kept as-is, anything else goes through the MacAddress validator.
		self.mac_address = (
			mac_address
			if isinstance(mac_address, MacAddress)
			else MacAddress._validate(mac_address, type(mac_address))
		)

	def __repr__(self) -> str:
		return f"MacClientId({self.mac_address})"

	def __eq__(self, other: object) -> bool:
		if not isinstance(other, MacClientId):
			return False
		return self.mac_address == other.mac_address

	@classmethod
	def from_bytes(cls, data: bytes) -> MacClientId:
		"""Parse an option value (type byte followed by the address); only the first 6 address bytes are used."""
		hex_digits = data[1:17].hex()
		octets = [hex_digits[pos : pos + 2] for pos in range(0, 12, 2)]
		return cls(MacAddress(":".join(octets)))

	def to_bytes(self) -> bytes:
		"""Encode as option value: type byte followed by the raw MAC address bytes."""
		type_byte = self.option_hardware_type.to_bytes(1, "big")
		address_bytes = bytes.fromhex(self.mac_address.replace(":", ""))
		return type_byte + address_bytes


class DUIDClientId(ClientId):
	"""
	Client identifier of type 255.

	The client ID consists of two components.
	A DHCP Unique Identifier (DUID) and an Identity Association Identifier (IAID).
	The DUID identifies the client system and the IAID identifies the interface on this system.
	"""

	option_hardware_type = 255

	def __init__(self, iaid: int, duid: DUID_EN | DUID_LL | DUID_LLT | DUID_UUID) -> None:
		self.iaid = iaid
		self.duid = duid

	@classmethod
	def from_bytes(cls, data: bytes) -> DUIDClientId:
		"""
		Parse an option value into the subclass matching the DUID type.

		Layout: type byte (1) | IAID (4, big endian) | DUID (starting with its 2 byte type).

		:raises ValueError: If the DUID type is unknown.
		"""
		iaid = int.from_bytes(data[1:5], "big")
		duid_data = data[5:]
		duid_type = DUIDType(int.from_bytes(duid_data[:2], "big"))
		handlers = {
			DUIDType.EN: (ENClientId, DUID_EN),
			DUIDType.LL: (LLClientId, DUID_LL),
			DUIDType.LLT: (LLTClientId, DUID_LLT),
			DUIDType.UUID: (UUIDClientId, DUID_UUID),
		}
		handler = handlers.get(duid_type)
		if handler is None:
			raise ValueError(f"Unhandled DUID type {duid_type}")
		client_id_class, duid_class = handler
		return client_id_class(iaid=iaid, duid=duid_class(duid_data))

	def to_bytes(self) -> bytes:
		"""Encode as option value: type byte, 4 byte IAID and the serialized DUID."""
		return b"".join(
			(
				self.option_hardware_type.to_bytes(1, "big"),
				self.iaid.to_bytes(4, "big"),
				bytes(self.duid),
			)
		)

	def __eq__(self, other: object) -> bool:
		if not isinstance(other, DUIDClientId):
			return False
		return (
			self.option_hardware_type == other.option_hardware_type
			and self.iaid == other.iaid
			and self.duid == other.duid
		)


class ENClientId(DUIDClientId):
	"""Client identifier whose DUID is based on an enterprise number (DUID-EN)."""

	duid_type = DUIDType.EN
	duid: DUID_EN

	@property
	def enterprise_number(self) -> int:
		"""Enterprise number from the DUID (scapy field ``enterprisenum``)."""
		return self.duid.enterprisenum

	@property
	def identifier(self) -> bytes:
		"""Vendor-assigned identifier from the DUID (scapy field ``id``)."""
		return self.duid.id

	def __repr__(self) -> str:
		return f"ENClientId(iaid={self.iaid:04x}, enterprise_number={self.enterprise_number}, identifier={self.identifier.hex()})"


class LLClientId(DUIDClientId):
	"""Client identifier whose DUID is based on a link-layer address (DUID-LL)."""

	duid_type = DUIDType.LL
	duid: DUID_LL

	@property
	def hardware_type(self) -> int:
		"""Hardware type from the DUID (scapy field ``hwtype``)."""
		return self.duid.hwtype

	@property
	def linklayer_address(self) -> bytes:
		"""Link-layer address from the DUID (scapy field ``lladdr``)."""
		return self.duid.lladdr

	def __repr__(self) -> str:
		lladdr = self.linklayer_address
		# NOTE(review): scapy may return the address as a formatted string rather than bytes;
		# handle both so the repr never fails.
		lladdr_str = lladdr.hex(":") if isinstance(lladdr, bytes) else str(lladdr)
		return (
			f"{self.__class__.__name__}(iaid={self.iaid:04x}, hardware_type={self.hardware_type}, "
			f"linklayer_address={lladdr_str})"
		)


class LLTClientId(LLClientId):
	"""Client identifier whose DUID is based on a link-layer address plus time (DUID-LLT)."""

	duid_type = DUIDType.LLT
	duid: DUID_LLT  # type: ignore[assignment]

	@property
	def time(self) -> int:
		"""Time value from the DUID (scapy field ``timeval``)."""
		return self.duid.timeval

	def __repr__(self) -> str:
		lladdr = self.linklayer_address
		# NOTE(review): scapy may return the address as a formatted string rather than bytes;
		# handle both so the repr never fails.
		lladdr_str = lladdr.hex(":") if isinstance(lladdr, bytes) else str(lladdr)
		return (
			f"{self.__class__.__name__}(iaid={self.iaid:04x}, hardware_type={self.hardware_type}, "
			f"time={self.time}, linklayer_address={lladdr_str})"
		)


class UUIDClientId(DUIDClientId):
	"""Client identifier whose DUID is a UUID (DUID-UUID)."""

	duid_type = DUIDType.UUID
	duid: DUID_UUID

	@property
	def uuid(self) -> UUID:
		"""UUID from the DUID (scapy field ``uuid``)."""
		return self.duid.uuid

	def __repr__(self) -> str:
		return f"{self.__class__.__name__}(iaid={self.iaid:04x}, uuid={self.uuid})"


def read_search_list_encoding(data: bytes, offset: int = 0, read_ref: bool = False) -> Generator[str, None, None]:
	"""
	Decode a DHCP domain search list (option 119).

	Names are sequences of length-prefixed labels terminated by a zero byte. A byte with the
	two high bits set starts a 2 byte compression pointer whose low 14 bits are the offset of
	a previously encoded name suffix.

	:param data: Encoded option value.
	:param offset: Position in ``data`` to start decoding at.
	:param read_ref: Decode only a single name (used when following a compression pointer).
	:returns: Generator of the decoded domain names.
	:raises ValueError: On a truncated or looping compression pointer.
	"""
	yield from _read_search_list_encoding(data, offset, read_ref, frozenset())


def _read_search_list_encoding(
	data: bytes, offset: int, read_ref: bool, followed: frozenset[int]
) -> Generator[str, None, None]:
	"""Decode names like ``read_search_list_encoding``; ``followed`` holds the pointer targets of the current chain."""
	parts: list[str] = []
	while offset < len(data):
		length = data[offset]
		offset += 1
		if length == 0:
			yield ".".join(parts)
			if read_ref:
				return
			parts = []
			continue
		if length & 0xC0 == 0xC0:
			if offset >= len(data):
				raise ValueError("Truncated compression pointer in domain search list")
			ref_offset = ((length & 0x3F) << 8) | data[offset]
			offset += 1
			if ref_offset in followed:
				raise ValueError(f"Compression pointer loop at offset {ref_offset} in domain search list")
			parts.extend(_read_search_list_encoding(data, ref_offset, True, followed | {ref_offset}))
			yield ".".join(parts)
			# A pointer always terminates the current name
			if read_ref:
				return
			parts = []
			continue
		parts.append(data[offset : offset + length].decode("utf-8"))
		offset += length


def write_search_list_encoding(search_list: list[str]) -> bytes:
	"""
	Encode domain names as a DHCP domain search list (option 119).

	Each name is written as length-prefixed labels terminated by a zero byte. A name suffix
	that was already written is replaced by a 2 byte compression pointer to its earlier
	occurrence.

	:param search_list: Domain names, optionally with a trailing dot. Empty names are skipped.
	:returns: The encoded option value.
	:raises ValueError: If a name contains an empty label or a label longer than 63 bytes.
	"""
	data = bytearray()
	# Offset in ``data`` of every name suffix written so far
	refs: dict[str, int] = {}
	for search in search_list:
		name = search.rstrip(".")
		if not name:
			continue
		parts = name.split(".")
		for idx, part in enumerate(parts):
			domain = ".".join(parts[idx:])
			ref = refs.get(domain)
			if ref is not None:
				# Pointer: two high bits set, 14 bit offset; it terminates the name.
				data += (0xC000 | ref).to_bytes(2, "big")
				break
			# The length prefix counts bytes, not characters
			label = part.encode("utf-8")
			if not label or len(label) > 63:
				raise ValueError(f"Invalid label {part!r} in domain {search!r}")
			# Pointers can only address the first 0x3FFF bytes
			if len(data) <= 0x3FFF:
				refs[domain] = len(data)
			data += len(label).to_bytes(1, "big") + label
		else:
			# No pointer was used, terminate the name explicitly
			data += b"\0"
	return bytes(data)


class DHCPv4Packet:
	"""
	DHCPv4 (BOOTP) packet: fixed header fields plus DHCP options.

	Options are stored in ``options`` keyed by scapy option name. Known options are converted
	to richer Python types by ``from_packet`` and back by ``to_packet``; all other option
	values are passed through unchanged.
	"""

	def __init__(
		self,
		operation_type: DHCPOperationType,
		client_mac_address: MacAddress,
		hops: int | None = None,
		transaction_id: int | None = None,
		seconds_elapsed: int | None = None,
		client_ip_address: IPv4Address | None = None,
		your_ip_address: IPv4Address | None = None,
		next_server_ip_address: IPv4Address | None = None,
		gateway_ip_address: IPv4Address | None = None,
		server_name: str | None = None,
		boot_file: str | None = None,
		options: dict[str, Any] | None = None,
	) -> None:
		"""
		:param operation_type: BOOTP operation code (``op``).
		:param client_mac_address: Client hardware address (``chaddr``).
		:param hops: ``hops`` header field, defaults to 0.
		:param transaction_id: ``xid`` header field; a random 32 bit value is used if omitted.
		:param seconds_elapsed: ``secs`` header field, defaults to 0.
		:param client_ip_address: ``ciaddr``, defaults to 0.0.0.0.
		:param your_ip_address: ``yiaddr``, defaults to 0.0.0.0.
		:param next_server_ip_address: ``siaddr``, defaults to 0.0.0.0.
		:param gateway_ip_address: ``giaddr``, defaults to 0.0.0.0.
		:param server_name: ``sname`` header field; empty values are stored as None.
		:param boot_file: ``file`` header field; empty values are stored as None.
		:param options: DHCP options keyed by scapy option name.
		"""
		self.operation_type = operation_type
		self.hops = hops or 0
		self.client_mac_address = client_mac_address
		# Compare with None: 0 is a valid transaction ID and must survive a parse / rebuild
		# round trip. The xid header field is 32 bits wide.
		self.transaction_id = randint(0, 0xFFFFFFFF) if transaction_id is None else transaction_id
		self.seconds_elapsed = seconds_elapsed or 0
		self.client_ip_address = client_ip_address or IPv4Address("0.0.0.0")
		self.your_ip_address = your_ip_address or IPv4Address("0.0.0.0")
		self.next_server_ip_address = next_server_ip_address or IPv4Address("0.0.0.0")
		self.gateway_ip_address = gateway_ip_address or IPv4Address("0.0.0.0")
		self.server_name = server_name or None
		self.boot_file = boot_file or None
		self.options = options or {}

	def dump(self) -> str:
		"""Return scapy's textual dump of the packet built by ``to_packet``."""
		return str(self.to_packet().show(dump=True))

	@classmethod
	def from_bytes(cls, data: bytes, ether: bool = False) -> DHCPv4Packet:
		"""
		Parse a raw packet.

		:param data: Packet bytes, starting with the BOOTP header (or the Ethernet header if ``ether``).
		:param ether: ``data`` is a full Ethernet frame.
		"""
		if ether:
			packet: BOOTP = Ether(data)[BOOTP]
		else:
			packet = BOOTP(data)
		return cls.from_packet(packet)

	@classmethod
	def from_packet(cls, packet: BOOTP) -> DHCPv4Packet:
		"""Create an instance from a scapy ``BOOTP`` packet that contains a ``DHCP`` layer."""
		options: dict[str, Any] = {}
		for option in packet[DHCP].options:
			# Skip plain string markers such as "end" / "pad"
			if not isinstance(option, tuple):
				continue
			label, *values = option
			value = values[0]
			if label == "message-type":
				value = DHCPMessageType(value)
			elif label == "client_id":
				value = ClientId.from_bytes(value)
			elif label in ("server_id", "requested_addr", "subnet_mask", "router"):
				value = IPv4Address(value)
			elif label == "pxe_client_architecture":
				value = DHCPClientSystemArchitecture(int.from_bytes(value, "big"))
			elif label == "pxe_client_network_interface":
				# <Type> <Major> <Minor>
				ver = [str(value[i]) for i in range(1, len(value))]
				if len(ver) == 1:
					ver.append("0")
				ver[1] = ver[1].rjust(2, "0")
				value = ".".join(ver)
			elif label in ("hostname", "vendor_class_id", "user_class", "domain", "tftp_server_name", "boot-file-name"):
				value = value.decode("utf-8")
			elif label == "pxe_client_machine_identifier":
				# Skip the leading type byte
				value = UUID(bytes_le=value[1:])
			elif label in ("NetBIOS_server", "name_server"):
				value = [IPv4Address(v) for v in values]
			elif label == "domain_search":
				value = list(read_search_list_encoding(value))
			options[label] = value

		return cls(
			operation_type=DHCPOperationType(packet.op),
			hops=packet.hops,
			transaction_id=packet.xid,
			seconds_elapsed=packet.secs,
			client_ip_address=IPv4Address(packet.ciaddr),
			your_ip_address=IPv4Address(packet.yiaddr),
			next_server_ip_address=IPv4Address(packet.siaddr),
			gateway_ip_address=IPv4Address(packet.giaddr),
			client_mac_address=MacAddress(":".join(packet.chaddr.hex()[i : i + 2] for i in range(0, 12, 2))),
			# sname and file are null terminated strings in fixed size fields
			server_name=packet.sname.split(b"\0", 1)[0].decode("utf-8") or None,
			boot_file=packet.file.split(b"\0", 1)[0].decode("utf-8") or None,
			options=options,
		)

	def is_pxe_request(self) -> bool:
		"""Return True for a DISCOVER or REQUEST boot request from a PXE or HTTP boot client."""
		# DHCP DISCOVER (port 67 / DHCP) or DHCP REQUEST (port 4011 / PXE)
		return (
			self.operation_type == DHCPOperationType.BOOT_REQUEST
			and self.options.get("message-type") in (DHCPMessageType.DISCOVER, DHCPMessageType.REQUEST)
			and str(self.options.get("vendor_class_id", "")).startswith(("PXEClient", "HTTPClient"))
		)

	def pxe_client_arch(self) -> DHCPClientSystemArchitecture:
		"""
		Return the client system architecture of a PXE / HTTP boot client.

		:raises ValueError: If neither option 93 nor a vendor class ID with architecture is present.
		"""
		# We prefer the option pxe_client_architecture over the vendor_class_id.
		# Especially in the case of a PXE packet on port 4011 (DHCP server sends option 60 = "PXEClient")
		# it is expected that the vendor class ID is "PXEClient" or "HTTPClient", without further Arch information.
		pxe_client_architecture = self.options.get("pxe_client_architecture")
		if pxe_client_architecture is not None:
			return pxe_client_architecture
		vendor_class_id = str(self.options.get("vendor_class_id", ""))
		if vendor_class_id.startswith(("PXEClient:Arch:", "HTTPClient:Arch:")):
			return DHCPClientSystemArchitecture.from_vendor_class_id(vendor_class_id)
		raise ValueError("No PXE client architecture found")

	def to_packet(self) -> Packet:
		"""
		Build a scapy ``BOOTP / DHCP`` packet.

		Option values are converted to the representation scapy expects; ``self.options`` is not
		modified, so the method can be called repeatedly (``dump`` calls it as well).
		"""
		mac_address = bytes.fromhex(self.client_mac_address.replace(":", ""))
		options: list[tuple | str] = []
		for label, value in self.options.items():
			if label == "message-type":
				value = value.value
			elif label == "client_id":
				value = value.to_bytes()
			elif label in ("server_id", "requested_addr", "subnet_mask", "router"):
				value = value.exploded
			elif label == "pxe_client_architecture":
				value = value.value.to_bytes(2, "big")
			elif label == "pxe_client_network_interface":
				# Type byte 1 followed by the major and minor version
				value = b"\x01" + bytes([int(v) for v in value.split(".")])
			elif label in ("hostname", "vendor_class_id", "user_class", "domain"):
				value = value.encode("utf-8")
			elif label == "pxe_client_machine_identifier":
				# Type byte 0 followed by the UUID in little endian byte order
				value = b"\x00" + value.bytes_le
			elif label in ("NetBIOS_server", "name_server"):
				value = [v.exploded for v in value]
			elif label == "domain_search":
				value = write_search_list_encoding(value)
			elif label == "param_req_list":
				# The list of requested codes is one option value, do not spread it below
				value = [value]
			# Build a new tuple instead of inserting the label into ``value``: for pass-through
			# options ``value`` may be the very list stored in ``self.options``.
			values = value if isinstance(value, list) else [value]
			options.append((label, *values))

		options.append("end")

		return BOOTP(
			op=self.operation_type.value,
			hops=self.hops,
			xid=self.transaction_id,
			secs=self.seconds_elapsed,
			ciaddr=self.client_ip_address.exploded,
			yiaddr=self.your_ip_address.exploded,
			siaddr=self.next_server_ip_address.exploded,
			giaddr=self.gateway_ip_address.exploded,
			chaddr=mac_address,
			sname=self.server_name or "",
			file=self.boot_file or "",
		) / DHCP(options=options)
