# -*- coding: utf-8 -*-

# This file is part of the desktop management solution opsi http://www.opsi.org
# Copyright (c) 2023-2024 uib GmbH <info@uib.de>
# This code is owned by the uib GmbH, Mainz, Germany (uib.de). All rights reserved.
# License: AGPL-3.0

from __future__ import annotations

import asyncio
from asyncio import AbstractEventLoop, DatagramTransport
from io import BytesIO
from ipaddress import IPv4Address, IPv6Address
from pathlib import Path
from socket import AF_INET, AF_INET6, IPPROTO_UDP
from threading import Event, Thread
from types import TracebackType
from typing import IO, TYPE_CHECKING, Callable

from confz import BaseConfig
from opsicommon.logging import get_logger
from py3tftp.file_io import FileReader  # type: ignore[import-untyped]
from py3tftp.netascii import Netascii  # type: ignore[import-untyped]
from py3tftp.protocols import BaseTFTPProtocol, ProtocolException, RRQProtocol  # type: ignore[import-untyped]
from py3tftp.protocols import TFTPServerProtocol as Py3TFTPServerProtocol  # type: ignore[import-untyped]
from py3tftp.tftp_packet import BaseTFTPPacket  # type: ignore[import-untyped]
from pydantic import Field

if TYPE_CHECKING:
	from opsiagent.opsiagent import OpsiAgent

	from . import BootServerPluginConfig

logger = get_logger("opsiagent.plugin.boot_server.tftp_server")


class TFTPServerConfig(BaseConfig):  # type: ignore[metaclass]
	"""Settings of the boot server's TFTP server."""

	# Enable flag; not evaluated in this module (presumably checked by the boot server plugin).
	enabled: bool = Field(False)
	# Local address the listening UDP socket is bound to; the default 0.0.0.0 is the IPv4 wildcard address.
	interface: IPv4Address | IPv6Address = Field(IPv4Address("0.0.0.0"))
	# UDP port to listen on.
	port: int = Field(69, ge=1, le=65535)
	# Root directory from which requested files are served.
	base_dir: Path = Field(Path("/tftpboot"))
	# Handed to py3tftp as the "ack_timeout" option (presumably seconds).
	ack_timeout: float = Field(0.5, ge=0.0)
	# Handed to py3tftp as the "conn_timeout" option (presumably seconds).
	connection_timeout: float = Field(3.0, ge=0.0)
	# Handed to py3tftp as "blksize"; 8..65464 are the block size limits of RFC 2348.
	block_size: int = Field(512, ge=8, le=65464)
	# Handed to py3tftp as "windowsize" (RFC 7440 windowed transfers).
	window_size: int = Field(1, ge=1)


class VirtFileReader(FileReader):
	"""File reader that serves an in-memory byte string instead of a file on disk."""

	def __init__(self, filename: str, data: bytes, *, chunk_size: int = 0, mode: bytes | None = None) -> None:
		"""
		:param filename: Name reported for the virtual file.
		:param data: Complete file content.
		:param chunk_size: Default number of bytes per read.
		:param mode: Transfer mode of the request; ``b"netascii"`` enables netascii conversion.
		"""
		# FileReader.__init__ is deliberately not called: it invokes _open_file() before the buffer
		# below exists, so its netascii wrapping would be applied to the wrong object and then lost.
		# NOTE(review): the base initializer presumably also treats the name as a path on disk.
		self.fname = filename
		self.chunk_size = chunk_size
		self._size = len(data)
		self._f = BytesIO(data)
		self.finished = False
		if mode == b"netascii":
			# Wrap only after the buffer exists so the conversion is really applied.
			self._f = Netascii(self._f)

	def _open_file(self) -> IO:
		"""Return the reader's (possibly netascii-wrapped) in-memory buffer."""
		return self._f

	def file_size(self) -> int:
		"""Return the length of the raw data in bytes."""
		return self._size


class StaticFileReader(FileReader):
	"""File reader for a file on disk whose path is used exactly as given.

	The base class initializer is not called, so no further path processing happens here;
	callers are responsible for keeping the path inside the served directory.
	"""

	def __init__(self, filename: Path, *, chunk_size: int = 0, mode: bytes | None = None) -> None:
		"""
		:param filename: Path of the file to serve.
		:param chunk_size: Default number of bytes per read.
		:param mode: Transfer mode of the request; ``b"netascii"`` enables netascii conversion.
		:raises OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
		"""
		self.fname = filename
		self.chunk_size = chunk_size
		# Define _f before opening: if _open_file() raises (e.g. for a missing file), the half-built
		# instance is still finalized, and cleanup that inspects _f must find the attribute.
		self._f = None
		# Inherited _open_file() presumably opens self.fname for binary reading.
		self._f = self._open_file()
		self.finished = False

		if mode == b"netascii":
			self._f = Netascii(self._f)


class TFTPServerProtocol(Py3TFTPServerProtocol):
	"""Listening protocol of the TFTP server.

	Only read requests (RRQ) are accepted; requested files are served from the configured ``base_dir``.
	Timeouts, block size and window size from the server config are handed to py3tftp as ``extra_opts``.
	"""

	def __init__(self, server: TFTPServer):
		self._server = server
		super().__init__(
			host_interface=self._server.config.interface.exploded,
			loop=self._server.loop,
			extra_opts={
				b"ack_timeout": self._server.config.ack_timeout,
				b"conn_timeout": self._server.config.connection_timeout,
				b"blksize": self._server.config.block_size,
				b"windowsize": self._server.config.window_size,
			},
		)

	def select_protocol(self, packet: BaseTFTPPacket) -> type[BaseTFTPProtocol]:
		"""Return the protocol class that handles the transfer started by ``packet``.

		:raises ProtocolException: For anything but a read request.
		"""
		logger.debug("packet type: %s", packet.pkt_type)
		if packet.is_rrq():
			return RRQProtocol
		raise ProtocolException("Received incompatible request, ignoring.")

	def handle_file(self, filename: bytes | str, chunk_size: int = 0, mode: bytes | None = None) -> FileReader:
		"""Open a requested file below ``base_dir`` for reading.

		:param filename: File name from the read request; leading slashes are ignored.
		:param chunk_size: Default number of bytes per read.
		:param mode: Transfer mode of the request (e.g. ``b"netascii"``).
		:raises FileNotFoundError: If the name is not valid UTF-8, could point outside ``base_dir``
			(``..`` components, absolute or drive-qualified paths) or the file does not exist.
		"""
		if not isinstance(filename, str):
			try:
				filename = filename.decode("utf-8")
			except UnicodeDecodeError as err:
				raise FileNotFoundError(f"Invalid file name: {filename!r}") from err
		logger.debug("Handling file: filename=%s, chunk_size=%s, mode=%s", filename, chunk_size, mode)

		# The name comes from an untrusted client: refuse anything that could escape base_dir.
		# Checked lexically (not via resolve()) so symlinks placed inside base_dir keep working.
		relative_path = Path(filename.lstrip("/"))
		if relative_path.anchor or ".." in relative_path.parts:
			raise FileNotFoundError(f"Access outside of base dir denied: {filename!r}")

		return StaticFileReader(self._server.config.base_dir / relative_path, chunk_size=chunk_size, mode=mode)

	def select_file_handler(self, packet: BaseTFTPPacket) -> Callable:
		"""Return a ``(filename, chunk_size)`` factory that opens the requested file.

		The transfer mode of ``packet`` is bound into the returned callable.

		:raises ProtocolException: For anything but a read request.
		"""
		if packet.is_rrq():
			return lambda filename, chunk_size: self.handle_file(filename, chunk_size, packet.mode)
		raise ProtocolException("Received incompatible request, ignoring.")


class TFTPServer(Thread):
	"""TFTP server running its own asyncio event loop in a dedicated thread.

	Settings are read from ``config.plugin.boot_server.tftp_server`` of the given agent.
	Usable as a context manager: entering starts the server, leaving stops it.
	"""

	def __init__(self, opsi_agent: OpsiAgent) -> None:
		Thread.__init__(self)
		self.opsi_agent = opsi_agent
		self.loop: AbstractEventLoop = asyncio.new_event_loop()
		self._transport: DatagramTransport | None = None
		# Set while the listening endpoint is bound and the loop is serving.
		self._running = Event()
		# Exception that terminated the server thread; reported by start().
		self._error: Exception | None = None
		self.service_client = self.opsi_agent.opsi_service.get_service_client()

	@property
	def config(self) -> TFTPServerConfig:
		return self.boot_server_config.tftp_server  # type: ignore[attr-defined]

	@property
	def boot_server_config(self) -> BootServerPluginConfig:
		return self.opsi_agent.config.plugin.boot_server  # type: ignore[attr-defined]

	def __str__(self) -> str:
		return f"TFTPServer(port={self.config.port}, running={self.running})"

	__repr__ = __str__

	def __enter__(self) -> TFTPServer:
		self.start()
		return self

	def __exit__(
		self,
		exc_type: type[BaseException] | None,
		exc_val: BaseException | None,
		exc_tb: TracebackType | None,
	) -> None:
		self.stop()

	@property
	def running(self) -> bool:
		return self._running.is_set()

	def stop(self) -> None:
		"""Stop the server and wait up to 3 seconds for the thread to finish."""
		if self.loop.is_running():
			# asyncio transports and loops are not thread-safe: perform the shutdown in the loop's thread.
			self.loop.call_soon_threadsafe(self._shutdown)
		if self.is_alive():
			self.join(3.0)

	def _shutdown(self) -> None:
		"""Close the listening transport and stop the loop; must run inside the loop's thread."""
		if self._transport:
			# close() only schedules the connection-lost callback that releases the socket.
			self._transport.close()
		# Queue stop() behind that callback so the socket is closed before run_forever() returns.
		self.loop.call_soon(self.loop.stop)

	def _run(self) -> None:
		"""Bind the listening endpoint and serve until stopped; the loop is always closed afterwards."""
		try:
			interface = self.config.interface
			# The socket family must match the configured address, otherwise binding an IPv6 address fails.
			family = AF_INET6 if isinstance(interface, IPv6Address) else AF_INET
			endpoint = self.loop.create_datagram_endpoint(
				lambda: TFTPServerProtocol(self),
				family=family,
				proto=IPPROTO_UDP,
				local_addr=(interface.exploded, self.config.port),
			)
			self._transport, _protocol = self.loop.run_until_complete(endpoint)
			self._running.set()
			logger.notice("TFTPServer running on port %d", self.config.port)

			self.loop.run_forever()
		finally:
			self._running.clear()
			self.loop.close()

	def run(self) -> None:
		"""Thread entry point: run the server and record/log any exception for start()."""
		try:
			self._run()
		except Exception as err:
			self._error = err
			logger.error(err, exc_info=True)

	def start(self) -> None:
		"""Start the server thread and wait (about 5 seconds at most) until it is serving.

		:raises RuntimeError: If the server did not come up, e.g. because binding the port failed.
		"""
		super().start()
		for _ in range(5):
			if self._running.wait(1.0) or self._error:
				break
		if not self._running.is_set():
			raise RuntimeError(f"Failed to start {self}: {self._error or 'Unknown error'}") from self._error
