/*#############################################################################
#                                                                             #
# fireperf - A network benchmarking tool                                      #
# Copyright (C) 2021 IPFire Development Team                                  #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
#############################################################################*/

#include <errno.h>
#include <netinet/tcp.h>
#include <stdlib.h>
#include <signal.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/random.h>
#include <unistd.h>

#include "client.h"
#include "logging.h"
#include "main.h"
#include "util.h"

// Set to one when the timeout has expired.
// This flag is written from the signal handler below and polled by the
// main loop, so it has to be a volatile sig_atomic_t to be accessed safely.
static volatile sig_atomic_t timeout_expired = 0;

// Signal handler which records that the timeout has expired.
// Any signal other than SIGALRM is ignored.
static void handle_SIGALRM(int signal) {
	switch (signal) {
		// Terminate after timeout has expired
		case SIGALRM:
			timeout_expired = 1;
			break;

		default:
			break;
	}
}

// Zero-filled buffer of SOCKET_SEND_BUFFER_SIZE bytes. It is used as the
// payload by send_data_to_server() when no random pool has been passed.
const char ZERO[SOCKET_SEND_BUFFER_SIZE] = { 0 };

// A heap-allocated buffer of random bytes from which payload slices are
// taken (see fireperf_random_pool_create() and
// fireperf_random_pool_get_slice()).
struct fireperf_random_pool {
	// Start of the buffer; allocated with malloc() and filled by getrandom()
	char* data;
	// Length of the buffer in bytes
	size_t size;
};

// Releases a random pool and the data buffer it owns.
// Passing NULL is allowed and does nothing.
static void fireperf_random_pool_free(struct fireperf_random_pool* pool) {
	if (!pool)
		return;

	// free(NULL) is a no-op, so the data pointer does not need to be checked
	free(pool->data);
	free(pool);
}

// Allocates a pool of size bytes and fills it with data from getrandom().
//
// conf is only used for logging.
//
// Returns the new pool, which has to be released with
// fireperf_random_pool_free(), or NULL if memory could not be allocated
// or the random data could not be read.
static struct fireperf_random_pool* fireperf_random_pool_create(struct fireperf_config* conf, size_t size) {
	struct fireperf_random_pool* pool = calloc(1, sizeof(*pool));
	if (!pool)
		return NULL;

	pool->size = size;

	// Allocate the data array
	pool->data = malloc(pool->size);
	if (!pool->data)
		goto ERROR;

	// getrandom() may return fewer bytes than requested,
	// so keep going until the whole pool has been filled
	size_t offset = 0;
	while (offset < pool->size) {
		ssize_t bytes = getrandom(pool->data + offset, pool->size - offset, 0);
		if (bytes < 0) {
			// Try again if we have been interrupted by a signal
			if (errno == EINTR)
				continue;

			// Never add a negative result to the offset
			goto ERROR;
		}

		offset += (size_t)bytes;
	}

	DEBUG(conf, "Allocated random pool of %zu bytes(s)\n", pool->size);

	return pool;

ERROR:
	fireperf_random_pool_free(pool);

	return NULL;
}

// Returns a pointer to size consecutive bytes at a random position
// inside the pool.
//
// The returned memory is owned by the pool and must not be freed.
// Returns NULL if the pool holds fewer than size bytes.
static const char* fireperf_random_pool_get_slice(struct fireperf_random_pool* pool, size_t size) {
	if (size > pool->size)
		return NULL;

	// A slice may start at any offset from 0 up to and including
	// (pool->size - size). Counting the last position as well avoids a
	// modulo by zero when the slice is exactly as large as the pool.
	size_t positions = pool->size - size + 1;

	// random() never returns a negative value
	size_t offset = (size_t)random() % positions;

	return pool->data + offset;
}

// Opens a non-blocking TCP socket, applies the keepalive and send buffer
// settings and starts connecting to the server at conf->address.
//
// The port is picked randomly out of the conf->listening_sockets ports
// starting at conf->port.
//
// Since the socket is non-blocking, the connection might still be in
// progress when this function returns.
//
// Returns the file descriptor of the socket, or -1 on error.
static int open_connection(struct fireperf_config* conf) {
	// Open a new socket
	int fd = socket(AF_INET6, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
	if (fd < 0) {
		ERROR(conf, "Could not open socket: %s\n", strerror(errno));
		goto ERROR;
	}

	// Choose a random port
	int port = conf->port + (random() % conf->listening_sockets);

	DEBUG(conf, "Opening socket %d (port %d)...\n", fd, port);

	// Define the peer
	struct sockaddr_in6 peer = {
		.sin6_family = AF_INET6,
		.sin6_addr = conf->address,
		.sin6_port = htons(port),
	};

	// Enable keepalive
	int flags = 1;
	int r = setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, (void*)&flags, sizeof(flags));
	if (r) {
		ERROR(conf, "Could not set SO_KEEPALIVE on socket %d: %s\n",
			fd, strerror(errno));
		goto ERROR;
	}

	// Set socket buffer sizes
	flags = SOCKET_SEND_BUFFER_SIZE;
	r = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, (void*)&flags, sizeof(flags));
	if (r) {
		ERROR(conf, "Could not set send buffer size on socket %d: %s\n",
			fd, strerror(errno));
		goto ERROR;
	}

	// Set keepalive interval
	if (conf->keepalive_interval) {
		DEBUG(conf, "Setting keepalive interval to %d\n", conf->keepalive_interval);

		r = setsockopt(fd, SOL_TCP, TCP_KEEPINTVL,
			(void*)&conf->keepalive_interval, sizeof(conf->keepalive_interval));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPINTVL on socket %d: %s\n",
				fd, strerror(errno));
			goto ERROR;
		}

		DEBUG(conf, "Setting keepalive idle interval to %d\n", conf->keepalive_interval);

		// NOTE(review): The value applied here is always 1, although the
		// message above logs conf->keepalive_interval - confirm which
		// one is intended
		flags = 1;
		r = setsockopt(fd, SOL_TCP, TCP_KEEPIDLE,
			(void*)&flags, sizeof(flags));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPIDLE on socket %d: %s\n",
				fd, strerror(errno));
			goto ERROR;
		}
	}

	// Set keepalive count
	if (conf->keepalive_count) {
		DEBUG(conf, "Setting keepalive count to %d\n", conf->keepalive_count);

		r = setsockopt(fd, SOL_TCP, TCP_KEEPCNT,
			(void*)&conf->keepalive_count, sizeof(conf->keepalive_count));
		if (r) {
			ERROR(conf, "Could not set TCP_KEEPCNT on socket %d: %s\n",
				fd, strerror(errno));
			goto ERROR;
		}
	}

	// Connect to the server. EINPROGRESS is expected because the socket
	// is non-blocking and the handshake completes in the background.
	r = connect(fd, (const struct sockaddr*)&peer, sizeof(peer));
	if (r && (errno != EINPROGRESS)) {
		ERROR(conf, "Could not connect to server: %s\n", strerror(errno));
		goto ERROR;
	}

	return fd;

ERROR:
	// Zero is a valid file descriptor, too
	if (fd >= 0)
		close(fd);

	return -1;
}

// Sends up to SOCKET_SEND_BUFFER_SIZE bytes to the server through fd and
// adds the number of bytes that have actually been sent to the statistics.
//
// The payload is a random slice of pool, or zeroes if pool is NULL.
//
// A failing send() is not treated as fatal and leaves the statistics
// untouched; disconnected sockets are cleaned up by the caller when it
// receives EPOLLHUP.
//
// Returns 0 on success, or non-zero if no payload could be obtained
// from the pool.
static int send_data_to_server(struct fireperf_config* conf,
		struct fireperf_stats* stats, int fd, struct fireperf_random_pool* pool) {
	const char* buffer = ZERO;

	if (pool) {
		buffer = fireperf_random_pool_get_slice(pool, SOCKET_SEND_BUFFER_SIZE);

		// The pool is smaller than what we would like to send
		if (!buffer) {
			DEBUG(conf, "Could not fetch data from the random pool\n");
			return 1;
		}
	}

	ssize_t bytes_sent = send(fd, buffer, SOCKET_SEND_BUFFER_SIZE, 0);
	if (bytes_sent < 0) {
		// The socket cannot take any more data right now or we have been
		// interrupted. Return instead of spinning; the caller will get
		// back to us when the socket has become writable again.
		if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
			return 0;

		DEBUG(conf, "Could not send data on socket %d: %s\n", fd, strerror(errno));

		// Do not add a negative value to the statistics
		return 0;
	}

	DEBUG(conf, "bytes_sent = %zd\n", bytes_sent);

	// Update statistics
	stats->bytes_sent += bytes_sent;
	stats->total_bytes_sent += bytes_sent;

	return 0;
}

// Called when a connection socket has become writable.
//
// Unless conf->close is set, data is sent to the server and the result of
// send_data_to_server() is returned. Otherwise the socket is closed
// immediately, the number of open connections is decremented and 0
// is returned.
static int handle_connection_ready(struct fireperf_config* conf,
		struct fireperf_stats* stats, int fd, struct fireperf_random_pool* pool) {
	// The regular case: Keep the connection and send some data
	if (!conf->close)
		return send_data_to_server(conf, stats, fd, pool);

	// We have been asked to close the connection straight away
	DEBUG(conf, "Closing connection %d\n", fd);
	close(fd);

	stats->open_connections -= 1;

	return 0;
}

// Runs fireperf in client mode.
//
// Keeps conf->parallel connections to the server open, registers them with
// epollfd and handles the events returned by epoll_wait(): when timerfd
// fires, the statistics are dumped; when a connection has been hung up, it
// is closed; when a connection becomes writable, it is handed to
// handle_connection_ready().
//
// The loop runs until conf->terminated is set, the timeout (conf->timeout,
// if set) has expired, or epoll_wait() has been interrupted by a signal.
//
// Returns 0 on success and non-zero on error.
int fireperf_client(struct fireperf_config* conf, struct fireperf_stats* stats,
		int epollfd, int timerfd) {
	struct fireperf_random_pool* pool = NULL;

	DEBUG(conf, "Launching " PACKAGE_NAME " in client mode\n");

	// Initialize random pool
	if (!conf->zero) {
		pool = fireperf_random_pool_create(conf, CLIENT_RANDOM_POOL_SIZE);
		if (!pool) {
			ERROR(conf, "Could not allocate random data\n");
			return 1;
		}
	}

	// The return code of this function. It stays non-zero until the main
	// loop has been left without an error. Results of called functions
	// are kept in rc so that they can never reset it to zero by accident.
	int r = 1;
	int rc;

	struct epoll_event ev = {
		.events = EPOLLIN,
	};
	struct epoll_event events[EPOLL_MAX_EVENTS];

	// Let us know when the socket is ready for sending data
	if (!conf->keepalive_only)
		ev.events |= EPOLLOUT;

	DEBUG(conf, "Opening %lu connections...\n", conf->parallel);

	// Configure timeout if set
	if (conf->timeout) {
		// Register signal handler
		signal(SIGALRM, handle_SIGALRM);

		alarm(conf->timeout);
	}

	DEBUG(conf, "Entering main loop...\n");

	while (!conf->terminated && !timeout_expired) {
		// Open connections. Failed attempts are retried, but only for
		// as long as we have not been asked to stop.
		while (stats->open_connections < conf->parallel
				&& !conf->terminated && !timeout_expired) {
			int fd = open_connection(conf);
			if (fd < 0)
				continue;

			ev.data.fd = fd;

			if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev)) {
				ERROR(conf, "Could not add socket file descriptor to epoll(): %s\n",
					strerror(errno));

				// Nobody else knows about this socket, yet
				close(fd);
				goto ERROR;
			}

			stats->open_connections++;
			stats->connections++;
		}

		// Do not wait for any more events if we should stop
		if (conf->terminated || timeout_expired)
			break;

		int fds = epoll_wait(epollfd, events, EPOLL_MAX_EVENTS, -1);
		if (fds < 0) {
			// We terminate gracefully when we receive a signal
			if (errno == EINTR)
				break;

			ERROR(conf, "epoll_wait() failed: %s\n", strerror(errno));
			goto ERROR;
		}

		for (int i = 0; i < fds; i++) {
			int fd = events[i].data.fd;

			// What type of event are we handling?

			// Handle timer events
			if (fd == timerfd) {
				uint64_t expirations;

				// Read from the timer to disarm it
				ssize_t bytes_read = read(timerfd, &expirations, sizeof(expirations));
				if (bytes_read <= 0) {
					ERROR(conf, "Could not read from timerfd: %s\n", strerror(errno));
					goto ERROR;
				}

				rc = fireperf_dump_stats(conf, stats, FIREPERF_MODE_CLIENT);
				if (rc) {
					r = rc;
					goto ERROR;
				}

			// Handle connection sockets
			} else {
				// Has the socket been disconnected?
				if (events[i].events & EPOLLHUP) {
					if (epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, NULL)) {
						ERROR(conf, "Could not remove socket file descriptor from epoll(): %s\n",
							strerror(errno));
						goto ERROR;
					}

					close(fd);

					stats->open_connections--;

				} else if (events[i].events & EPOLLOUT) {
					rc = handle_connection_ready(conf, stats, fd, pool);
					if (rc) {
						r = rc;
						goto ERROR;
					}
				}
			}
		}
	}

	// All okay
	r = 0;

ERROR:
	if (pool)
		fireperf_random_pool_free(pool);

	return r;
}
