// Copyright 2020-2023, Collabora, Ltd.
// SPDX-License-Identifier: BSL-1.0
/*!
 * @file
 * @brief  Main hub of the remote driver.
 * @author Jakob Bornecrantz <jakob@collabora.com>
 * @ingroup drv_remote
 */

#include "r_internal.h"

#include "util/u_var.h"
#include "util/u_misc.h"
#include "util/u_debug.h"
#include "util/u_space_overseer.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#if defined(XRT_OS_WINDOWS)
#include <winsock2.h>
#include <ws2tcpip.h>
#include "xrt/xrt_windows.h"
#else
#include <unistd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#endif

// NOTE(review): these feature-test macros are defined after the system
// headers have already been included in this file, so they presumably cannot
// influence what those headers declare - confirm whether they should be moved
// above the includes.
#ifndef _BSD_SOURCE
#define _BSD_SOURCE // same, but for musl // NOLINT
#endif

#ifndef __USE_MISC
#define __USE_MISC // SOL_TCP on C11
#endif

// Define the printf-style conversion used to print a socket descriptor.
// It is a string literal so it can be concatenated into a format string,
// for example: "socket: " R_SOCKET_FMT
#ifdef XRT_OS_WINDOWS
// On Windows, this is a SOCKET, aka an unsigned long long
#define R_SOCKET_FMT "%llu"
#else
// On non-Windows, this is a file descriptor, aka an int
#define R_SOCKET_FMT "%i"
#endif


/*
 *
 * Small helpers.
 *
 */

// Provides debug_get_log_option_remote_log(), which this file uses to set the
// connection log level. U_LOGGING_INFO is presumably the level used when the
// "REMOTE_LOG" option is not set - confirm against u_debug.h.
DEBUG_GET_ONCE_LOG_OPTION(remote_log, "REMOTE_LOG", U_LOGGING_INFO)

// Logging helpers taking a hub pointer (R), they filter on (R)->rc.log_level.
#define R_TRACE(R, ...) U_LOG_IFL_T((R)->rc.log_level, __VA_ARGS__)
#define R_DEBUG(R, ...) U_LOG_IFL_D((R)->rc.log_level, __VA_ARGS__)
#define R_INFO(R, ...) U_LOG_IFL_I((R)->rc.log_level, __VA_ARGS__)
#define R_WARN(R, ...) U_LOG_IFL_W((R)->rc.log_level, __VA_ARGS__)
#define R_ERROR(R, ...) U_LOG_IFL_E((R)->rc.log_level, __VA_ARGS__)

// Logging helpers taking a connection pointer (RC), they filter on (RC)->log_level.
#define RC_TRACE(RC, ...) U_LOG_IFL_T((RC)->log_level, __VA_ARGS__)
#define RC_DEBUG(RC, ...) U_LOG_IFL_D((RC)->log_level, __VA_ARGS__)
#define RC_INFO(RC, ...) U_LOG_IFL_I((RC)->log_level, __VA_ARGS__)
#define RC_WARN(RC, ...) U_LOG_IFL_W((RC)->log_level, __VA_ARGS__)
#define RC_ERROR(RC, ...) U_LOG_IFL_E((RC)->log_level, __VA_ARGS__)


/*
 *
 * Platform socket wrapper functions.
 *
 */

#if defined(XRT_OS_WINDOWS)

/*!
 * Windows: closes the given socket with closesocket, the result is
 * deliberately not looked at.
 */
static inline void
socket_close(r_socket_t id)
{
	(void)closesocket(id);
}

/*!
 * Windows: creates an IPv4 TCP stream socket and returns whatever socket()
 * returned, failure checking is left to the caller.
 */
static inline r_socket_t
socket_create(void)
{
	const r_socket_t created = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);

	return created;
}

/*!
 * Windows: sets SO_REUSEADDR on the socket to the value of @p flag.
 *
 * @return The result of setsockopt.
 */
static inline int
socket_set_opt(r_socket_t id, int flag)
{
	const int status = setsockopt(id, SOL_SOCKET, SO_REUSEADDR, (const char *)&flag, sizeof(flag));

	return status;
}

/*!
 * Windows: receives up to (size - current) bytes into @p ptr using recv.
 *
 * @return The result of recv.
 */
static inline ssize_t
socket_read(r_socket_t id, void *ptr, size_t size, size_t current)
{
	char *dst = (char *)ptr;
	const int remaining = (int)(size - current);

	return recv(id, dst, remaining, 0);
}

/*!
 * Windows: sends up to (size - current) bytes from @p ptr using send.
 *
 * @return The result of send.
 */
static inline ssize_t
socket_write(r_socket_t id, void *ptr, size_t size, size_t current)
{
	const char *src = (const char *)ptr;
	const int remaining = (int)(size - current);

	return send(id, src, remaining, 0);
}

#elif defined(XRT_OS_UNIX)

/*!
 * Unix: closes the given file descriptor, the result of close is
 * deliberately not looked at.
 */
static inline void
socket_close(r_socket_t id)
{
	(void)close(id);
}

/*!
 * Unix: creates an IPv4 stream socket and returns whatever socket()
 * returned, failure checking is left to the caller.
 */
static inline r_socket_t
socket_create(void)
{
	const r_socket_t created = socket(AF_INET, SOCK_STREAM, 0);

	return created;
}

/*!
 * Unix: sets SO_REUSEADDR on the socket to the value of @p flag.
 *
 * @return The result of setsockopt.
 */
static inline int
socket_set_opt(r_socket_t id, int flag)
{
	const int status = setsockopt(id, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(flag));

	return status;
}

/*!
 * Unix: reads up to (size - current) bytes into @p ptr using read.
 *
 * @return The result of read.
 */
static inline ssize_t
socket_read(r_socket_t id, void *ptr, size_t size, size_t current)
{
	const size_t remaining = size - current;

	return read(id, ptr, remaining);
}

/*!
 * Unix: writes up to (size - current) bytes from @p ptr using write.
 *
 * @return The result of write.
 */
static inline ssize_t
socket_write(r_socket_t id, void *ptr, size_t size, size_t current)
{
	const size_t remaining = size - current;

	return write(id, ptr, remaining);
}

#endif // XRT_OS_UNIX


/*
 *
 * Helper functions.
 *
 */

static r_socket_t
setup_accept_fd(struct r_hub *r)
{
	struct sockaddr_in server_address = {0};
#if defined(XRT_OS_WINDOWS)
	// Initialize Winsock.
	WSADATA wsaData;
	if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
		int error = WSAGetLastError();
		R_ERROR(r, "Failed to do WSAStartup %d", error);
		return error;
	}
#endif
	r_socket_t ret = socket_create();

	if (ret < 0) {
		R_ERROR(r, "socket: " R_SOCKET_FMT, ret);
		goto cleanup;
	}

	r->accept_fd = ret;

	int flag = 1;
	ret = socket_set_opt(r->accept_fd, flag);
	if (ret < 0) {
		R_ERROR(r, "setsockopt: " R_SOCKET_FMT, ret);
		socket_close(r->accept_fd);
		r->accept_fd = -1;
		goto cleanup;
	}

	server_address.sin_family = AF_INET;
	server_address.sin_addr.s_addr = htonl(INADDR_ANY);
	server_address.sin_port = htons(r->port);

	ret = bind(r->accept_fd, (struct sockaddr *)&server_address, sizeof(server_address));
	if (ret < 0) {
		R_ERROR(r, "bind: " R_SOCKET_FMT, ret);
		socket_close(r->accept_fd);
		r->accept_fd = -1;
		goto cleanup;
	}

	R_INFO(r, "Listen address %s on port %d", inet_ntoa(server_address.sin_addr), r->port);

	listen(r->accept_fd, 5);

	return 0;
cleanup:
#if defined(XRT_OS_WINDOWS)
	WSACleanup();
#endif
	return ret;
}

/*!
 * Waits until the given socket is readable, polling with select using a one
 * second timeout for as long as the hub thread helper reports it is running.
 *
 * @param r      The hub, its thread helper is queried each iteration.
 * @param socket The socket to wait on.
 * @return true if select reported the socket as ready, false if the socket
 *         is negative, select failed or the thread is no longer running.
 */
static bool
wait_for_read_and_to_continue(struct r_hub *r, r_socket_t socket)
{
	// To be more robust, refuse negative descriptors up front.
	if (socket < 0) {
		return false;
	}

	fd_set read_fds;
	int status = 0;

	while (os_thread_helper_is_running(&r->oth) && status == 0) {
		// Both the timeout and the set can be modified by select, rebuild them.
		struct timeval one_second = {.tv_sec = 1, .tv_usec = 0};
		FD_ZERO(&read_fds);
		FD_SET(socket, &read_fds);

		status = select((int)socket + 1, &read_fds, NULL, NULL, &one_second);
	}

	if (status < 0) {
		R_ERROR(r, "select: %i", status);
		return false;
	}

	// Zero here means the loop ended because the thread was asked to stop.
	return status > 0;
}

static r_socket_t
do_accept(struct r_hub *r)
{
	struct sockaddr_in addr = {0};
	r_socket_t ret = 0;
	if (!wait_for_read_and_to_continue(r, r->accept_fd)) {
		R_ERROR(r, "Failed to wait for id " R_SOCKET_FMT, r->accept_fd);
		return -1;
	}

	socklen_t addr_length = (socklen_t)sizeof(addr);
	ret = accept(r->accept_fd, (struct sockaddr *)&addr, &addr_length);
	if (ret < 0) {
		R_ERROR(r, "accept: " R_SOCKET_FMT, ret);
		return ret;
	}

	r_socket_t conn_fd = ret;

	int flags = 1;
	ret = socket_set_opt(r->accept_fd, flags);
	if (ret < 0) {
		R_ERROR(r, "setsockopt: " R_SOCKET_FMT, ret);
		socket_close(conn_fd);
		return ret;
	}

	r->rc.fd = conn_fd;

	R_INFO(r, "Connection received! " R_SOCKET_FMT, r->rc.fd);

	return 0;
}

/*!
 * Reads one complete `struct r_remote_data` from the hub's connection,
 * looping over partial reads until `sizeof(*data)` bytes have arrived.
 * Before every read it waits for the socket to become readable, which also
 * lets the loop end when the hub thread is asked to stop.
 *
 * @param r    The hub, its `rc.fd` is read from.
 * @param data Destination, filled in piece by piece.
 * @return 0 when a full packet was read, a negative value if waiting or
 *         reading failed or the peer disconnected.
 */
static ssize_t
read_one(struct r_hub *r, struct r_remote_data *data)
{
	struct r_remote_connection *conn = &r->rc;
	uint8_t *const bytes = (uint8_t *)data;
	const size_t total = sizeof(*data);

	for (size_t done = 0; done < total;) {
		if (!wait_for_read_and_to_continue(r, conn->fd)) {
			return -1;
		}

		const ssize_t got = socket_read(conn->fd, bytes + done, total, done);
		if (got == 0) {
			// Zero bytes from a readable socket is treated as a disconnect.
			R_INFO(r, "Disconnected!");
			return -1;
		}
		if (got < 0) {
#if defined(XRT_OS_WINDOWS)
			RC_ERROR(conn, "recv: %i", WSAGetLastError());
#else
			RC_ERROR(conn, "read: %zi", got);
#endif
			return got;
		}

		done += (size_t)got;
	}

	return 0;
}

/*!
 * Entry point of the hub thread: sets up the listening socket and then, for
 * as long as the thread helper says it is running, accepts one connection at
 * a time, sends the reset and latest state to it and keeps `r->latest`
 * updated with every packet read until the read fails.
 *
 * All helper results are compared against zero instead of being stored in a
 * r_socket_t and tested with `< 0`: that type is unsigned on Windows, so the
 * old checks could never see a failure there.
 *
 * @param ptr The `struct r_hub` this thread serves.
 * @return Always NULL.
 */
static void *
run_thread(void *ptr)
{
	struct r_hub *r = (struct r_hub *)ptr;

	if (setup_accept_fd(r) != 0) {
		R_INFO(r, "Leaving thread");
		return NULL;
	}

	while (os_thread_helper_is_running(&r->oth)) {
		R_INFO(r, "Listening on port '%i'.", r->port);

		if (do_accept(r) != 0) {
			R_INFO(r, "Leaving thread");
			return NULL;
		}

		r_remote_connection_write_one(&r->rc, &r->reset);
		r_remote_connection_write_one(&r->rc, &r->latest);

		while (true) {
			struct r_remote_data data;

			if (read_one(r, &data) != 0) {
				break;
			}

			r->latest = data;
		}

		/*
		 * This connection is finished, close it so the descriptor is
		 * not leaked when the next accept overwrites r->rc.fd. The
		 * destroy function only touches the sockets after this
		 * thread has been stopped and waited for.
		 */
		socket_close(r->rc.fd);
		r->rc.fd = -1;
	}

	R_INFO(r, "Leaving thread");

	return NULL;
}

static xrt_result_t
r_hub_system_devices_get_roles(struct xrt_system_devices *xsysd, struct xrt_system_roles *out_roles)
{
	struct r_hub *r = (struct r_hub *)xsysd;

	struct xrt_system_roles roles = XRT_SYSTEM_ROLES_INIT;
	roles.generation_id = 1;
	roles.left = r->left_index;
	roles.right = r->right_index;

	*out_roles = roles;

	return XRT_SUCCESS;
}

/*!
 * Implementation of the destroy function of the system devices: stops the
 * hub thread, destroys every device slot, closes both sockets and frees the
 * hub. On Windows Winsock is cleaned up as the last step.
 *
 * @param xsysd The system devices, which is the hub itself; freed by this
 *              call.
 */
static void
r_hub_system_devices_destroy(struct xrt_system_devices *xsysd)
{
	struct r_hub *hub = (struct r_hub *)xsysd;

	R_DEBUG(hub, "Destroying");

	// The thread has to be gone before anything it uses is torn down.
	os_thread_helper_stop_and_wait(&hub->oth);

	// Every slot of the array is passed on, not only the used ones.
	uint32_t slot = 0;
	while (slot < ARRAY_SIZE(hub->base.xdevs)) {
		xrt_device_destroy(&hub->base.xdevs[slot]);
		slot++;
	}

	// With the thread stopped nothing else in this file uses the sockets.
	if (hub->accept_fd >= 0) {
		socket_close(hub->accept_fd);
		hub->accept_fd = -1;
	}

	if (hub->rc.fd >= 0) {
		socket_close(hub->rc.fd);
		hub->rc.fd = -1;
	}

	free(hub);

#if defined(XRT_OS_WINDOWS)
	// Clean up Winsock.
	WSACleanup();
#endif
}


/*
 *
 * 'Exported' create function.
 *
 */

/*!
 * Creates the remote hub together with its three devices (head, left and
 * right controller) and a space overseer, and starts the hub thread that
 * listens for a remote connection on @p port.
 *
 * @param port       TCP port stored in the hub and used by the hub thread.
 * @param view_count Stored in the hub as `view_count`.
 * @param broadcast  Passed on to u_space_overseer_create.
 * @param out_xsysd  Receives the hub as system devices on success.
 * @param out_xso    Receives the created space overseer on success.
 * @return XRT_SUCCESS on success, XRT_ERROR_ALLOCATION if the hub could not
 *         be allocated or the thread helper could not be initialized or
 *         started; the out parameters are not written on failure.
 */
xrt_result_t
r_create_devices(uint16_t port,
                 uint32_t view_count,
                 struct xrt_session_event_sink *broadcast,
                 struct xrt_system_devices **out_xsysd,
                 struct xrt_space_overseer **out_xso)
{
	struct r_hub *r = U_TYPED_CALLOC(struct r_hub);
	int ret;

	// Everything below dereferences the hub, bail out if there is none.
	if (r == NULL) {
		return XRT_ERROR_ALLOCATION;
	}

	r->base.destroy = r_hub_system_devices_destroy;
	r->base.get_roles = r_hub_system_devices_get_roles;
	r->origin.type = XRT_TRACKING_TYPE_RGB;
	r->origin.initial_offset = (struct xrt_pose)XRT_POSE_IDENTITY;
	r->reset.head.center = (struct xrt_pose)XRT_POSE_IDENTITY;
	r->reset.head.center.position.y = 1.6f;
	r->reset.left.active = true;
	r->reset.left.hand_tracking_active = true;
	r->reset.left.pose.position.x = -0.2f;
	r->reset.left.pose.position.y = 1.3f;
	r->reset.left.pose.position.z = -0.5f;
	r->reset.left.pose.orientation.w = 1.0f;

	r->reset.right.active = true;
	r->reset.right.hand_tracking_active = true;
	r->reset.right.pose.position.x = 0.2f;
	r->reset.right.pose.position.y = 1.3f;
	r->reset.right.pose.position.z = -0.5f;
	r->reset.right.pose.orientation.w = 1.0f;
	r->latest = r->reset;
	r->rc.log_level = debug_get_log_option_remote_log();
	r->gui.hmd = true;
	r->gui.left = true;
	r->gui.right = true;
	r->port = port;
	r->view_count = view_count;
	// Mark both sockets as not open, the destroy function checks for this.
	r->accept_fd = -1;
	r->rc.fd = -1;

	snprintf(r->origin.name, sizeof(r->origin.name), "Remote Simulator");

	ret = os_thread_helper_init(&r->oth);
	if (ret != 0) {
		R_ERROR(r, "Failed to init threading!");
		r_hub_system_devices_destroy(&r->base);
		return XRT_ERROR_ALLOCATION;
	}

	ret = os_thread_helper_start(&r->oth, run_thread, r);
	if (ret != 0) {
		R_ERROR(r, "Failed to start thread!");
		r_hub_system_devices_destroy(&r->base);
		return XRT_ERROR_ALLOCATION;
	}


	/*
	 * Setup system devices.
	 */

	struct xrt_device *head = r_hmd_create(r);
	struct xrt_device *left = r_device_create(r, true);
	struct xrt_device *right = r_device_create(r, false);

	// The left/right indices are the slots the devices are stored in.
	r->base.xdevs[r->base.xdev_count++] = head;
	r->left_index = (int32_t)r->base.xdev_count;
	r->base.xdevs[r->base.xdev_count++] = left;
	r->right_index = (int32_t)r->base.xdev_count;
	r->base.xdevs[r->base.xdev_count++] = right;

	r->base.static_roles.head = head;
	r->base.static_roles.hand_tracking.left = left;
	r->base.static_roles.hand_tracking.right = right;


	/*
	 * Space overseer.
	 */

	struct u_space_overseer *uso = u_space_overseer_create(broadcast);
	struct xrt_space_overseer *xso = (struct xrt_space_overseer *)uso;
	assert(uso != NULL);

	struct xrt_space *root = xso->semantic.root; // Convenience
	struct xrt_space *offset = NULL;
	u_space_overseer_create_offset_space(uso, root, &r->origin.initial_offset, &offset);

	for (uint32_t i = 0; i < r->base.xdev_count; i++) {
		u_space_overseer_link_space_to_device(uso, offset, r->base.xdevs[i]);
	}

	// Unreference now
	xrt_space_reference(&offset, NULL);

	// Set root as stage space.
	xrt_space_reference(&xso->semantic.stage, root);

	// Local 1.6 meters up.
	struct xrt_pose local_offset = {XRT_QUAT_IDENTITY, {0.0f, 1.6f, 0.0f}};
	u_space_overseer_create_offset_space(uso, root, &local_offset, &xso->semantic.local);

	// Local floor at the same place as local except at floor height.
	struct xrt_pose local_floor_offset = local_offset;
	local_floor_offset.position.y = 0.0f;
	u_space_overseer_create_offset_space(uso, root, &local_floor_offset, &xso->semantic.local_floor);

	// Make view space be the head pose.
	u_space_overseer_create_pose_space(uso, head, XRT_INPUT_GENERIC_HEAD_POSE, &xso->semantic.view);


	/*
	 * Setup variable tracker.
	 */

	u_var_add_root(r, "Remote Hub", true);
	// u_var_add_gui_header(r, &r->gui.hmd, "MHD");
	u_var_add_pose(r, &r->latest.head.center, "head.center");
	// u_var_add_gui_header(r, &r->gui.left, "Left");
	u_var_add_bool(r, &r->latest.left.active, "left.active");
	u_var_add_pose(r, &r->latest.left.pose, "left.pose");
	// u_var_add_gui_header(r, &r->gui.right, "Right");
	u_var_add_bool(r, &r->latest.right.active, "right.active");
	u_var_add_pose(r, &r->latest.right.pose, "right.pose");

	/*
	 * Done now.
	 */

	*out_xsysd = &r->base;
	*out_xso = xso;

	return XRT_SUCCESS;
}


/*
 *
 * 'Exported' connection functions.
 *
 */

/*!
 * Connects to a remote hub: sets the log level of the connection, parses the
 * IPv4 address, creates a socket, connects it and sets the socket option on
 * it. On success the connected socket is stored in `rc->fd`.
 *
 * @param rc      Connection to initialize, `fd` is only written on success.
 * @param ip_addr Dotted-decimal IPv4 address, or the literal "localhost".
 * @param port    TCP port to connect to.
 * @return 0 on success, a non-zero value on failure: -1 for address, socket,
 *         connect and setsockopt failures, and on Windows the code returned
 *         by WSAStartup if that call fails.
 */
r_socket_t
r_remote_connection_init(struct r_remote_connection *rc, const char *ip_addr, uint16_t port)
{
	struct sockaddr_in addr = {0};
	r_socket_t sock_fd;
	int flags = 1;
	int ret;

	// Set log level.
	rc->log_level = debug_get_log_option_remote_log();

#if defined(XRT_OS_WINDOWS)
	// Initialize Winsock, WSAStartup returns the error code directly.
	WSADATA wsaData;
	ret = WSAStartup(MAKEWORD(2, 2), &wsaData);
	if (ret != 0) {
		RC_ERROR(rc, "Failed to do WSAStartup %i", ret);
		return ret;
	}
#endif

	// Address
	addr.sin_family = AF_INET;
	addr.sin_port = htons(port);

	// inet_pton only parses numeric addresses and does not resolve host
	// names, so "localhost" is mapped to "127.0.0.1" by hand first.
	const char *numeric_addr = (strcmp("localhost", ip_addr) == 0) ? "127.0.0.1" : ip_addr;

	// Only 1 means success: 0 is an unparsable address and -1 an error.
	ret = inet_pton(AF_INET, numeric_addr, &addr.sin_addr);
	if (ret != 1) {
		RC_ERROR(rc, "Failed to do inet pton for %s: %i", ip_addr, ret);
		ret = -1;
		goto cleanup;
	}

	sock_fd = socket_create();
#if defined(XRT_OS_WINDOWS)
	if (sock_fd == INVALID_SOCKET) {
		RC_ERROR(rc, "Failed to create socket %i", WSAGetLastError());
		ret = -1; // Do not return the stale inet_pton result.
		goto cleanup;
	}
#else
	if (sock_fd < 0) {
		RC_ERROR(rc, "Failed to create socket: " R_SOCKET_FMT, sock_fd);
		ret = -1; // Do not return the stale inet_pton result.
		goto cleanup;
	}
#endif

	ret = connect(sock_fd, (struct sockaddr *)&addr, sizeof(addr));
	// If connect operation succeed, both Windows and POSIX returns 0.
	if (ret != 0) {
#if defined(XRT_OS_WINDOWS)
		RC_ERROR(rc, "Failed to connect id " R_SOCKET_FMT " and addr %s with failure %d", sock_fd,
		         inet_ntoa(addr.sin_addr), WSAGetLastError());
#else
		RC_ERROR(rc, "Failed to connect id " R_SOCKET_FMT " and addr %s with failure %d", sock_fd,
		         inet_ntoa(addr.sin_addr), ret);
#endif
		socket_close(sock_fd);
		ret = -1;
		goto cleanup;
	}

	ret = socket_set_opt(sock_fd, flags);
	if (ret < 0) {
		RC_ERROR(rc, "Failed to setsockopt: %i", ret);
		socket_close(sock_fd);
		ret = -1;
		goto cleanup;
	}

	rc->fd = sock_fd;

	return 0;

cleanup:
#if defined(XRT_OS_WINDOWS)
	WSACleanup();
#endif
	return ret;
}

/*!
 * Reads one complete `struct r_remote_data` from the connection, looping
 * over partial reads until `sizeof(*data)` bytes have arrived. Unlike the
 * hub's own reader this one does not wait with select before reading.
 *
 * @param rc   Connection whose `fd` is read from.
 * @param data Destination, filled in piece by piece.
 * @return 0 when a full packet was read, the negative read result on error
 *         and -1 if the read returned zero bytes.
 */
int
r_remote_connection_read_one(struct r_remote_connection *rc, struct r_remote_data *data)
{
	uint8_t *const bytes = (uint8_t *)data;
	const size_t total = sizeof(*data);

	for (size_t done = 0; done < total;) {
		const ssize_t got = socket_read(rc->fd, bytes + done, total, done);

		if (got < 0) {
			RC_ERROR(rc, "read: %zi", got);
			return got;
		}

		if (got == 0) {
			// Zero bytes is treated as the peer having disconnected.
			RC_INFO(rc, "Disconnected!");
			return -1;
		}

		done += (size_t)got;
	}

	return 0;
}

/*!
 * Writes one complete `struct r_remote_data` to the connection, looping over
 * partial writes until `sizeof(*data)` bytes have been sent.
 *
 * @param rc   Connection whose `fd` is written to.
 * @param data The packet to send, it is not modified.
 * @return 0 when the full packet was written, the negative write result on
 *         error and -1 if the write returned zero bytes.
 */
int
r_remote_connection_write_one(struct r_remote_connection *rc, const struct r_remote_data *data)
{
	// The socket wrapper takes a non-const pointer, hence the cast.
	uint8_t *const bytes = (uint8_t *)data;
	const size_t total = sizeof(*data);

	for (size_t done = 0; done < total;) {
		const ssize_t sent = socket_write(rc->fd, bytes + done, total, done);

		if (sent < 0) {
			RC_ERROR(rc, "write: %zi", sent);
			return sent;
		}

		if (sent == 0) {
			RC_INFO(rc, "Disconnected!");
			return -1;
		}

		done += (size_t)sent;
	}

	return 0;
}
