# -*- coding: utf-8 -*-

# This file is part of the desktop management solution opsi http://www.opsi.org
# Copyright (c) 2023-2024 uib GmbH <info@uib.de>
# This code is owned by the uib GmbH, Mainz, Germany (uib.de). All rights reserved.
# License: AGPL-3.0

from __future__ import annotations

import asyncio
from asyncio import AbstractEventLoop, DatagramProtocol, DatagramTransport
from dataclasses import dataclass
from ipaddress import IPv4Address
from socket import AF_INET, IPPROTO_UDP
from threading import Event, Thread
from types import TracebackType
from typing import TYPE_CHECKING, Any
from uuid import UUID

from confz import BaseConfig
from opsicommon.logging import get_logger
from opsicommon.logging.constants import TRACE
from pydantic import Field, field_validator

from opsiagent.config import get_server_role

from .dhcp import (
	Architecture,
	BIOSType,
	DHCPClientSystemArchitecture,
	DHCPMessageType,
	DHCPOperationType,
	DHCPv4Packet,
	MacAddress,
	PXEProtocol,
	UUIDClientId,
)

if TYPE_CHECKING:
	from opsiagent.opsiagent import OpsiAgent

	from . import BootServerPluginConfig

logger = get_logger("opsiagent.plugin.boot_server.dhcp_server")


@dataclass(kw_only=True, slots=True)
class BootConfig:
	"""Boot parameters delivered to a single PXE client in a ProxyDHCP reply.

	All fields are keyword-only and default to ``None``.
	"""

	# Address placed into the reply as next-server and TFTP server name.
	# None for HTTP boot clients, whose boot file name is already a full URL.
	pxe_boot_server: IPv4Address | None = None
	# Boot file name placed into the reply (a path for TFTP clients, an http:// URL for HTTP clients).
	pxe_boot_filename: str | None = None
	# NOTE(review): not read anywhere in this file; presumably consumed by other boot server code - confirm.
	grub_config: str | None = None
	# NOTE(review): not read anywhere in this file; presumably kernel parameters for the Linux bootimage - confirm.
	linux_bootimage_kernel_params: dict[str, str | None] | None = None


class ProxyDHCPServerConfig(BaseConfig):  # type: ignore[metaclass]
	"""Configuration of the ProxyDHCP server (port, broadcast targets, enabled flag)."""

	enabled: bool = Field(default=False)
	# The default DHCP server port is 67, the PXE boot server port is 4011.
	# If the ProxyDHCP server runs on the same host as the DHCP server,
	# use port 4011 for the ProxyDHCP server and set DHCP server option 60 to "PXEClient".
	port: int = Field(default=67, ge=1, le=65535)
	# Destinations for DHCPOFFER replies (DHCPACKs are sent unicast to the requesting client).
	broadcast_addresses: list[IPv4Address] = Field(
		default_factory=lambda: [IPv4Address("255.255.255.255")],
	)

	@field_validator("broadcast_addresses", mode="before")
	@classmethod
	def _validate_broadcast_addresses(cls, value: Any) -> list[IPv4Address]:
		"""Coerce the raw ``broadcast_addresses`` input into a list of ``IPv4Address``.

		Accepted input:
		- a list/tuple/set of address strings (or ``IPv4Address`` objects),
		- a single ``IPv4Address``,
		- a single string containing one or more comma-separated addresses.

		Raises:
			ValueError: for any other input type or an invalid address. pydantic only converts
				``ValueError`` (not ``TypeError``) into a validation error, hence the explicit type check.
		"""
		if isinstance(value, IPv4Address):
			value = [value]
		elif isinstance(value, str):
			# Iterating a plain string would yield single characters, never valid addresses.
			value = [part.strip() for part in value.split(",") if part.strip()]
		if not isinstance(value, (list, tuple, set, frozenset)):
			raise ValueError(f"Invalid broadcast addresses: {value!r}")
		return [IPv4Address(addr) for addr in value]


class ProxyDHCPServerProtocol(DatagramProtocol):
	"""Datagram protocol answering DHCP PXE requests on behalf of a ``ProxyDHCPServer``.

	Every received datagram is parsed as a DHCPv4 packet; packets that are not PXE requests are
	ignored. For PXE requests the host boot configuration is obtained from
	``ProxyDHCPServer.get_host_boot_config``. If there is one, a BOOTREPLY carrying the boot server
	and boot file name is sent:
	- a DHCPACK unicast to the sender for a DHCPREQUEST,
	- otherwise a DHCPOFFER to every configured broadcast address.
	"""

	def __init__(self, server: ProxyDHCPServer):
		self._server = server

	def connection_made(self, transport: DatagramTransport) -> None:  # type: ignore[override]
		"""Remember the transport; it is used to send the replies."""
		logger.debug("connection_made: transport=%r", transport)
		self._transport = transport

	def datagram_received(self, data: bytes, address: tuple[str | Any, int]) -> None:
		"""Handle one datagram; any error is logged here instead of propagating to the event loop."""
		logger.debug("Packet received from %r", address)
		try:
			self._process_datagram(data, address)
		except Exception as err:
			logger.error(err, exc_info=True)

	def _process_datagram(self, data: bytes, address: tuple[str | Any, int]) -> None:
		"""Parse ``data`` received from ``address`` and send a PXE boot reply if applicable."""
		logger.debug("Processing datagram from %r", address)
		request = DHCPv4Packet.from_bytes(data)

		if logger.isEnabledFor(TRACE):
			logger.trace("Received DHCP packet:\n%s", request.dump())

		if not request.is_pxe_request():
			logger.debug("Not a DHCP PXE request: %s", request)
			return

		logger.debug("DHCP PXE request received: %s", request)
		# Check before any service lookups; a missing option previously raised KeyError.
		message_type = request.options.get("message-type")
		if message_type is None:
			logger.debug("Ignoring DHCP PXE request without message type: %s", request)
			return

		pxe_client_arch = request.pxe_client_arch()

		client_mac_address = request.client_mac_address
		client_uuid: UUID | None = None
		client_id = request.options.get("client_id")
		if client_id and isinstance(client_id, UUIDClientId):
			client_uuid = client_id.uuid

		logger.info(
			"Processing DHCP PXE request: arch=%s, mac=%s, uuid=%s",
			pxe_client_arch,
			client_mac_address,
			client_uuid,
		)

		boot_conf = self._server.get_host_boot_config(
			pxe_client_arch=pxe_client_arch, mac_address=client_mac_address, system_uuid=client_uuid
		)
		if not boot_conf:
			logger.info("No boot configuration found for mac=%s, uuid=%s", client_mac_address, client_uuid)
			return

		logger.notice(
			"Boot configuration for mac=%s, uuid=%s: %r",
			client_mac_address,
			client_uuid,
			boot_conf,
		)

		if message_type == DHCPMessageType.REQUEST:
			response_message_type = DHCPMessageType.ACK
			# NOTE(review): a DHCPREQUEST broadcast by a client without an IP address arrives
			# from 0.0.0.0, so this reply then goes to 0.0.0.0 - confirm this is intended.
			response_addresses = [address[0]]
		else:
			response_message_type = DHCPMessageType.OFFER
			response_addresses = [a.exploded for a in self._server.config.broadcast_addresses]

		response = DHCPv4Packet(
			operation_type=DHCPOperationType.BOOT_REPLY,
			transaction_id=request.transaction_id,
			# iPXE 1.20.1: next server must not be empty
			next_server_ip_address=boot_conf.pxe_boot_server,
			# The boot file is sent for HTTP clients too; for them it is the full boot URL.
			boot_file=boot_conf.pxe_boot_filename,
			client_mac_address=client_mac_address,
			options={
				"message-type": response_message_type,
				"vendor_class_id": "HTTPClient" if pxe_client_arch.protocol == PXEProtocol.HTTP else "PXEClient",
				"tftp_server_name": boot_conf.pxe_boot_server.exploded if boot_conf.pxe_boot_server else "",
				"boot-file-name": boot_conf.pxe_boot_filename,
			},
		).to_packet()

		if logger.isEnabledFor(TRACE):
			logger.trace("Sending DHCP PXE response: %s", response.show(dump=True))

		# Serialize once; the same bytes go to every destination.
		payload = response.build()
		for response_address in response_addresses:
			logger.debug("Sending DHCP PXE response to: %s", response_address)
			# Reply to the UDP port the request was sent from.
			self._transport.sendto(payload, (response_address, address[1]))

	def error_received(self, exception: Exception) -> None:
		"""Log transport errors (e.g. failed sends); the endpoint stays open."""
		logger.warning("error_received: %r", exception)

	def connection_lost(self, exception: Exception | None) -> None:
		"""Log that the transport has been closed."""
		logger.debug("connection_lost: exception=%r", exception)


class ProxyDHCPServer(Thread):
	"""ProxyDHCP server thread.

	Owns a private asyncio event loop that is run on this thread and serves
	``ProxyDHCPServerProtocol`` on a UDP socket bound to ``0.0.0.0:<config.port>``.
	- ``start()`` blocks until the socket is bound (or raises ``RuntimeError``).
	- ``stop()`` shuts the loop down and joins the thread.
	- Can also be used as a context manager (start on enter, stop on exit).
	"""

	def __init__(self, opsi_agent: OpsiAgent) -> None:
		Thread.__init__(self)
		self.opsi_agent = opsi_agent
		# Created here, but only run (and closed) on the server thread in _run().
		self._loop: AbstractEventLoop = asyncio.new_event_loop()
		self._transport: DatagramTransport | None = None
		# Set while the UDP endpoint is bound and the loop is serving requests.
		self._running = Event()
		# Exception that terminated run(); reported by start().
		self._error: Exception | None = None
		self.service_client = self.opsi_agent.opsi_service.get_service_client()
		logger.info("Using %r as PXE boot server", self.boot_server_config.pxe_boot.boot_server_address)

	@property
	def config(self) -> ProxyDHCPServerConfig:
		"""The ``proxy_dhcp_server`` section of the boot server plugin configuration."""
		return self.boot_server_config.proxy_dhcp_server  # type: ignore[attr-defined]

	@property
	def boot_server_config(self) -> BootServerPluginConfig:
		"""The boot server plugin configuration (``opsi_agent.config.plugin.boot_server``)."""
		return self.opsi_agent.config.plugin.boot_server  # type: ignore[attr-defined]

	def __str__(self) -> str:
		return f"ProxyDHCPServer(port={self.config.port}, broadcast_addresses={self.config.broadcast_addresses}, running={self.running})"

	__repr__ = __str__

	def __enter__(self) -> ProxyDHCPServer:
		self.start()
		return self

	def __exit__(
		self,
		exc_type: type[BaseException] | None,
		exc_val: BaseException | None,
		exc_tb: TracebackType | None,
	) -> None:
		self.stop()

	@property
	def running(self) -> bool:
		"""True while the UDP endpoint is bound and the event loop is serving requests."""
		return self._running.is_set()

	def stop(self) -> None:
		"""Stop the event loop and wait up to 3 seconds for the server thread to finish.

		The loop and transport are only touched from the server thread (asyncio objects are not
		thread-safe). This method only schedules ``loop.stop`` via ``call_soon_threadsafe``; the
		transport is closed by ``_run`` after the loop has stopped.
		"""
		if not self.is_alive():
			return
		try:
			# Unlike checking loop.is_running(), this also works in the window between binding the
			# endpoint and entering run_forever(): a stop() scheduled before run_forever() makes it
			# exit after its first iteration.
			self._loop.call_soon_threadsafe(self._loop.stop)
		except RuntimeError:
			# The loop has already been closed; the thread is about to finish.
			pass
		self.join(3.0)

	def _run(self) -> None:
		"""Bind the UDP endpoint and serve until the loop is stopped.

		The transport and the event loop are released on every exit path, including a failure to
		bind the port.
		"""
		try:
			endpoint = self._loop.create_datagram_endpoint(
				lambda: ProxyDHCPServerProtocol(self),
				family=AF_INET,
				proto=IPPROTO_UDP,
				local_addr=("0.0.0.0", self.config.port),
				allow_broadcast=True,
			)
			self._transport, _protocol = self._loop.run_until_complete(endpoint)
			self._running.set()
			logger.notice("ProxyDHCPServer running on port %d", self.config.port)

			self._loop.run_forever()
		finally:
			try:
				if self._transport:
					self._transport.close()
					# close() defers connection_lost (which closes the socket) to the next loop
					# iteration; run it now so the port is released before the loop is closed.
					self._loop.run_until_complete(asyncio.sleep(0))
			finally:
				self._loop.close()
				self._running.clear()

	def run(self) -> None:
		"""Thread entry point: run the server, recording and logging any exception for ``start()``."""
		try:
			self._run()
		except Exception as err:
			self._error = err
			logger.error(err, exc_info=True)

	def start(self) -> None:
		"""Start the server thread and wait up to about 5 seconds until it is serving.

		Raises:
			RuntimeError: if the server is not running after waiting (e.g. the port could not be bound).
		"""
		super().start()
		for _ in range(5):
			# Stop waiting as soon as the server runs, failed, or the thread is gone.
			if self._running.wait(1.0) or self._error or not self.is_alive():
				break
		if not self._running.is_set():
			raise RuntimeError(f"Failed to start {self}: {self._error or 'Unknown error'}")

	def get_host_boot_config(
		self,
		*,
		pxe_client_arch: DHCPClientSystemArchitecture,
		mac_address: MacAddress,
		system_uuid: UUID | None = None,
	) -> BootConfig | None:
		"""Return the boot configuration for a PXE client, or None if it should not be served.

		A client is only served when all of the following hold:
		- this agent has a server role;
		- exactly one opsi host matches (looked up by system UUID first, then by MAC address);
		- exactly one NetbootProduct of that host has the action request ``setup``.

		HTTP boot clients get the boot file as a full ``http://`` URL and no boot server address.

		Raises:
			ValueError: for client architectures other than x86 BIOS and x64 UEFI.
		"""
		if not get_server_role():
			return None

		pxe_boot = self.boot_server_config.pxe_boot
		server_address = pxe_boot.boot_server_address
		if pxe_client_arch.architecture == Architecture.X86 and pxe_client_arch.bios_type == BIOSType.BIOS:
			boot_filename = pxe_boot.boot_filename_x86_bios
		elif pxe_client_arch.architecture == Architecture.X64 and pxe_client_arch.bios_type == BIOSType.UEFI:
			boot_filename = pxe_boot.boot_filename_x64_uefi
		else:
			raise ValueError(f"Unsupported PXE client architecture: {pxe_client_arch!r}")

		logger.info("Looking for host with mac=%s, uuid=%s", mac_address, system_uuid)
		hosts = []
		if system_uuid:
			# Passed as string, consistent with the MAC address lookup below.
			hosts = self.service_client.host_getObjects(systemUUID=str(system_uuid))  # type: ignore[attr-defined]
			if not hosts:
				logger.debug("No host found with uuid %r, searching for mac address", system_uuid)
		if not hosts:
			hosts = self.service_client.host_getObjects(hardwareAddress=str(mac_address))  # type: ignore[attr-defined]
			if not hosts:
				logger.info("No host found with mac address %r or uuid %r", mac_address, system_uuid)
				return None
		if len(hosts) > 1:
			logger.error(
				"Multiple hosts found for mac=%s, uuid=%s: %s", mac_address, system_uuid, [host.id for host in hosts]
			)
			return None

		host_id = hosts[0].id
		relevant_pocs = self.service_client.productOnClient_getObjects(  # type: ignore[attr-defined]
			clientId=host_id, productType="NetbootProduct", actionRequest="setup"
		)
		if not relevant_pocs:
			logger.info("No NetbootProduct found for host %r", host_id)
			return None
		if len(relevant_pocs) > 1:
			logger.warning("Multiple NetbootProducts found for host %s", host_id)
			return None
		logger.notice("NetbootProduct %s is set for host %s. Delivering BootConfig", relevant_pocs[0].productId, host_id)
		if pxe_client_arch.protocol == PXEProtocol.HTTP:
			return BootConfig(pxe_boot_server=None, pxe_boot_filename=f"http://{server_address.exploded}/{boot_filename}")
		return BootConfig(pxe_boot_server=server_address, pxe_boot_filename=boot_filename)
