/*#############################################################################
#                                                                             #
# fireperf - A network benchmarking tool                                      #
# Copyright (C) 2021 IPFire Development Team                                  #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
#############################################################################*/

#include <errno.h>
#include <netinet/tcp.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <sys/epoll.h>
#include <time.h>
#include <unistd.h>

#include "logging.h"
#include "main.h"
#include "server.h"
#include "util.h"

#define SOCKET_BACKLOG   1024

/*
 * Enables TCP keepalive on the socket fd.
 *
 * SO_KEEPALIVE is always switched on. If conf->keepalive_interval is
 * non-zero, TCP_KEEPINTVL is set to that value and TCP_KEEPIDLE is set to a
 * fixed value of 1. If conf->keepalive_count is non-zero, TCP_KEEPCNT is set
 * to it.
 *
 * Returns 0 on success, or 1 as soon as any setsockopt() call fails (the
 * failure is logged).
 */
static int enable_keepalive(struct fireperf_config* conf, int fd) {
	// Enable keepalive
	int flags = 1;
	int r = setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, (void*)&flags, sizeof(flags));
	if (r) {
		ERROR(conf, "Could not set SO_KEEPALIVE on socket %d: %s\n",
			fd, strerror(errno));
		return 1;
	}

	// Set keepalive interval
	if (conf->keepalive_interval) {
		DEBUG(conf, "Setting keepalive interval to %d\n", conf->keepalive_interval);

		r = setsockopt(fd, SOL_TCP, TCP_KEEPINTVL,
			(void*)&conf->keepalive_interval, sizeof(conf->keepalive_interval));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPINTVL on socket %d: %s\n",
				fd, strerror(errno));
			return 1;
		}

		// The idle time before the first probe is a fixed value and does NOT
		// follow conf->keepalive_interval. Log the value that is actually
		// applied so the debug output is not misleading.
		// NOTE(review): confirm whether the idle time was meant to be
		// conf->keepalive_interval rather than a hard-coded 1.
		int idle = 1;

		DEBUG(conf, "Setting keepalive idle time to %d\n", idle);

		r = setsockopt(fd, SOL_TCP, TCP_KEEPIDLE,
			(void*)&idle, sizeof(idle));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPIDLE on socket %d: %s\n",
				fd, strerror(errno));
			return 1;
		}
	}

	// Set keepalive count
	if (conf->keepalive_count) {
		DEBUG(conf, "Setting keepalive count to %d\n", conf->keepalive_count);

		r = setsockopt(fd, SOL_TCP, TCP_KEEPCNT,
			(void*)&conf->keepalive_count, sizeof(conf->keepalive_count));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPCNT on socket %d: %s\n",
				fd, strerror(errno));
			return 1;
		}
	}

	return 0;
}

/*
 * Creates the i-th listening socket.
 *
 * Setup steps:
 *   - Open an IPv6 stream socket with SOCK_NONBLOCK and SOCK_CLOEXEC.
 *   - Set SO_REUSEPORT.
 *   - Apply buffer sizes through set_socket_buffer_sizes().
 *   - Enable keepalive when conf->keepalive_only is set.
 *   - Bind to port conf->port + i on the unspecified address.
 *   - Start listening with a backlog of SOCKET_BACKLOG.
 *
 * Returns the listening file descriptor, or -1 on error. The error is logged
 * and a partially configured socket is closed.
 */
static int create_socket(struct fireperf_config* conf, int i) {
	int r;

	// Open a new socket
	int fd = socket(AF_INET6, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
	if (fd < 0) {
		ERROR(conf, "Could not open socket: %s\n", strerror(errno));

		// Nothing has been opened yet, so skip the cleanup path (no close(-1))
		return -1;
	}

	int flags = 1;
	r = setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &flags, sizeof(flags));
	if (r) {
		ERROR(conf, "Could not set SO_REUSEPORT on socket %d: %s\n",
			fd, strerror(errno));
		goto ERROR;
	}

	// Set socket buffer sizes
	r = set_socket_buffer_sizes(conf, fd);
	if (r)
		goto ERROR;

	// Enable keepalive
	if (conf->keepalive_only) {
		r = enable_keepalive(conf, fd);
		if (r)
			goto ERROR;
	}

	// Members not named here (including sin6_addr) are zero-initialized, so
	// the socket is bound to the unspecified address
	struct sockaddr_in6 addr = {
		.sin6_family = AF_INET6,
		.sin6_port = htons(conf->port + i),
	};

	// Bind it to the selected port. bind() expects a generic struct sockaddr*;
	// cast explicitly instead of relying on glibc's transparent-union prototype.
	r = bind(fd, (struct sockaddr*)&addr, sizeof(addr));
	if (r) {
		ERROR(conf, "Could not bind socket: %s\n", strerror(errno));
		goto ERROR;
	}

	// Listen
	r = listen(fd, SOCKET_BACKLOG);
	if (r) {
		ERROR(conf, "Could not listen on socket: %s\n", strerror(errno));
		goto ERROR;
	}

	DEBUG(conf, "Created listening socket %d\n", fd);

	return fd;

ERROR:
	// Only reached once fd refers to an open socket
	close(fd);

	return -1;
}

/*
 * Accepts one pending connection on the listening socket sockfd.
 *
 * accept() is retried for as long as it fails with EAGAIN/EWOULDBLOCK.
 * NOTE(review): the listening sockets created by create_socket() are
 * non-blocking, so this retry is a busy loop. If no connection is actually
 * pending when accept() runs, it spins until one arrives. Consider returning
 * to the event loop instead.
 * NOTE(review): per accept(2), on Linux the new socket does not inherit
 * O_NONBLOCK from the listening socket, so the returned fd is blocking.
 *
 * Returns the new connection's file descriptor, or -1 on any other error
 * (logged).
 */
static int accept_connection(struct fireperf_config* conf, int sockfd) {
	struct sockaddr_in6 addr;
	socklen_t l;

	int fd = -1;

	// The listening socket is ready, there is a new connection waiting to be accepted
	do {
		// addrlen is a value-result argument, so reset it before every attempt
		l = sizeof(addr);

		// accept() expects a generic struct sockaddr*; cast explicitly instead of
		// relying on glibc's transparent-union prototype
		fd = accept(sockfd, (struct sockaddr*)&addr, &l);
	} while (fd < 0 && (errno == EAGAIN || errno == EWOULDBLOCK));

	if (fd < 0) {
		ERROR(conf, "Could not accept a new connection: %s\n", strerror(errno));
		return -1;
	}

	DEBUG(conf, "New connection accepted on socket %d\n", fd);

	return fd;
}

/*
 * Tells whether fd is one of the listening sockets.
 *
 * Scans the first conf->listening_sockets entries of sockets and returns 1 on
 * the first entry equal to fd, otherwise 0.
 */
static int is_listening_socket(struct fireperf_config* conf, int* sockets, int fd) {
	unsigned int idx = 0;

	while (idx < conf->listening_sockets) {
		if (sockets[idx] == fd)
			return 1;

		idx++;
	}

	return 0;
}

/*
 * Runs fireperf in server mode.
 *
 * Opens conf->listening_sockets listening sockets on consecutive ports
 * starting at conf->port and adds them to epollfd. It then processes epoll
 * events until conf->terminated is set or epoll_wait() is interrupted by a
 * signal. In the event loop:
 *   - Events on a listening socket: accept the connection and register it
 *     for EPOLLIN|EPOLLRDHUP, plus EPOLLOUT unless conf->keepalive_only is
 *     set. The connection counters in stats are updated.
 *   - Events on timerfd: read it to consume the expirations, then dump the
 *     statistics.
 *   - Events on a connection:
 *       - On EPOLLRDHUP, remove it from epoll and close it.
 *       - Otherwise, if conf->close is set, close it immediately.
 *       - Otherwise, pass it to handle_connection_recv() /
 *         handle_connection_send().
 *
 * NOTE(review): this function never adds timerfd to epollfd itself;
 * presumably the caller does. Confirm.
 *
 * Returns 0 after a graceful shutdown, 1 on setup/epoll/accept errors, or
 * the non-zero/negative value reported by fireperf_dump_stats() or the
 * connection handlers. All listening sockets are closed before returning.
 */
int fireperf_server(struct fireperf_config* conf, struct fireperf_stats* stats,
		int epollfd, int timerfd) {
	DEBUG(conf, "Launching " PACKAGE_NAME " in server mode\n");

	int listening_sockets[conf->listening_sockets];

	int r = 1;
	struct epoll_event ev;
	struct epoll_event events[EPOLL_MAX_EVENTS];

	// Mark all slots as unused first. If setup fails part-way through, the
	// cleanup code then only closes sockets that were actually opened, and
	// never reads indeterminate values.
	for (unsigned int i = 0; i < conf->listening_sockets; i++)
		listening_sockets[i] = -1;

	// Create listening sockets
	for (unsigned int i = 0; i < conf->listening_sockets; i++) {
		int sockfd = create_socket(conf, i);
		if (sockfd < 0) {
			// Close the sockets that were already opened
			r = 1;
			goto ERROR;
		}

		listening_sockets[i] = sockfd;

		// Add listening socket to epoll
		ev.events  = EPOLLIN;
		ev.data.fd = sockfd;

		if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev)) {
			ERROR(conf, "Could not add socket file descriptor to epoll(): %s\n",
				strerror(errno));
			r = 1;
			goto ERROR;
		}
	}

	DEBUG(conf, "Entering main loop...\n");

	while (!conf->terminated) {
		int fds = epoll_wait(epollfd, events, EPOLL_MAX_EVENTS, -1);
		if (fds < 0) {
			// We terminate gracefully when we receive a signal
			if (errno == EINTR)
				break;

			ERROR(conf, "epoll_wait() failed: %s\n", strerror(errno));
			r = 1;
			goto ERROR;
		}

		for (int i = 0; i < fds; i++) {
			int fd = events[i].data.fd;

			// The listening socket
			if (is_listening_socket(conf, listening_sockets, fd)) {
				int connfd = accept_connection(conf, fd);
				if (connfd < 0) {
					r = 1;
					goto ERROR;
				}

				// Add the new socket to epoll()
				ev.data.fd = connfd;
				ev.events  = EPOLLIN|EPOLLRDHUP;
				if (!conf->keepalive_only)
					ev.events |= EPOLLOUT;

				if (epoll_ctl(epollfd, EPOLL_CTL_ADD, connfd, &ev)) {
					ERROR(conf, "Could not add socket file descriptor to epoll(): %s\n",
						strerror(errno));

					// The connection is not tracked anywhere else, so don't leak it
					close(connfd);
					r = 1;
					goto ERROR;
				}

				// A connection has been opened
				stats->open_connections++;
				stats->connections++;

			// Handle timer events
			} else if (fd == timerfd) {
				uint64_t expirations;

				// Read from the timer to disarm it
				ssize_t bytes_read = read(timerfd, &expirations, sizeof(expirations));
				if (bytes_read <= 0) {
					ERROR(conf, "Could not read from timerfd: %s\n", strerror(errno));
					r = 1;
					goto ERROR;
				}

				r = fireperf_dump_stats(conf, stats, FIREPERF_MODE_SERVER);
				if (r)
					goto ERROR;

			// Handle any connection events
			} else {
				if (events[i].events & EPOLLRDHUP) {
					DEBUG(conf, "Connection %d has closed\n", fd);

					// Remove the file descriptor from epoll() (best effort, only logged)
					if (epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, NULL)) {
						ERROR(conf, "Could not remove socket file descriptor from epoll(): %s\n",
							strerror(errno));
					}

					// Free up any resources
					close(fd);

					// This connection is now closed
					stats->open_connections--;

					// Skip processing anything else, because it would be pointless
					continue;
				}

				// Close connections immediately when -x is set
				if (conf->close) {
					DEBUG(conf, "Closing connection %d\n", fd);
					close(fd);

					stats->open_connections--;
					continue;
				}

				// Handle incoming data
				if (events[i].events & EPOLLIN) {
					r = handle_connection_recv(conf, stats, fd);
					if (r < 0)
						goto ERROR;
				}

				// Handle outgoing data
				if (events[i].events & EPOLLOUT) {
					r = handle_connection_send(conf, stats, fd);
					if (r < 0)
						goto ERROR;
				}
			}
		}
	}

	// The main loop was left without an error (termination was requested),
	// so don't report whatever value r last held
	r = 0;

ERROR:
	for (unsigned int i = 0; i < conf->listening_sockets; i++) {
		// -1 marks slots whose socket was never created; 0 is a valid descriptor
		if (listening_sockets[i] >= 0)
			close(listening_sockets[i]);
	}

	return r;
}
