/*#############################################################################
#                                                                             #
# fireperf - A network benchmarking tool                                      #
# Copyright (C) 2021 IPFire Development Team                                  #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
#############################################################################*/

#include <arpa/inet.h>
#include <errno.h>
#include <getopt.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/resource.h>
#include <sys/time.h>
#include <sys/timerfd.h>
#include <time.h>
#include <unistd.h>

#include "client.h"
#include "main.h"
#include "logging.h"
#include "server.h"
#include "util.h"

// Parses a textual IPv6 or IPv4 address into *address6.
//
// IPv6 notation is stored as-is. IPv4 dotted notation is stored as an
// IPv4-mapped IPv6 address (::ffff:a.b.c.d).
//
// Returns 0 on success and 1 if string is neither an IPv6 nor an IPv4 address.
static int parse_address(const char* string, struct in6_addr* address6) {
	struct in_addr v4;

	// Native IPv6 notation needs no conversion
	if (inet_pton(AF_INET6, string, address6) == 1)
		return 0;

	// Neither IPv6 nor IPv4: give up
	if (inet_pton(AF_INET, string, &v4) != 1)
		return 1;

	// ::ffff:a.b.c.d is ten zero bytes, two 0xff bytes, then the
	// IPv4 address in network byte order
	memset(address6->s6_addr, 0x00, 10);
	memset(address6->s6_addr + 10, 0xff, 2);
	memcpy(address6->s6_addr + 12, &v4.s_addr, sizeof(v4.s_addr));

	return 0;
}

// Checks that port is a usable port number (1-65535).
//
// Returns 0 if it is. Otherwise prints an error message to stderr and
// returns 2.
static int check_port(int port) {
	if (port <= 0 || port >= 65536) {
		// port is signed and may be negative here, so print it with %d
		// (%u would show e.g. -1 as 4294967295 and is undefined for
		// negative int arguments)
		fprintf(stderr, "Invalid port number: %d\n", port);
		return 2;
	}

	return 0;
}

// Parses a port range of the form "FIRST:LAST" from optarg.
//
// On success, stores FIRST in conf->port and the number of ports in the range
// in conf->listening_sockets.
//
// Returns:
//   0 on success.
//   1 if optarg does not look like a range at all, so the caller can fall
//     back to parsing a single port.
//   2 if it is a range but invalid: out-of-range ports, FIRST > LAST, or
//     trailing characters. An error is printed to stderr in that case.
static int parse_port_range(struct fireperf_config* conf, const char* optarg) {
	int first_port, last_port;
	int consumed = 0;

	// %n records how many characters were matched, so trailing garbage
	// such as "80:90x" or "80:90:100" can be detected below
	int r = sscanf(optarg, "%d:%d%n", &first_port, &last_port, &consumed);
	if (r != 2)
		return 1;

	if (optarg[consumed] != '\0') {
		fprintf(stderr, "Invalid port range: %s\n", optarg);
		return 2;
	}

	// Check if both ports are in range
	r = check_port(first_port);
	if (r)
		return r;

	r = check_port(last_port);
	if (r)
		return r;

	if (first_port > last_port) {
		fprintf(stderr, "Invalid port range: %s\n", optarg);
		return 2;
	}

	conf->port = first_port;
	conf->listening_sockets = (last_port - first_port) + 1;

	return 0;
}

// Parses a single port number from optarg. On success, stores it in
// conf->port and configures exactly one listening socket.
//
// The whole string must be a decimal number in the range 1-65535. Unlike
// atoi(), this rejects empty input, trailing characters (e.g. "80-90" or
// "80abc") and overflow instead of silently converting them.
//
// Returns 0 on success, or 2 on invalid input after printing an error to stderr.
static int parse_port(struct fireperf_config* conf, const char* optarg) {
	char* end = NULL;

	errno = 0;
	long port = strtol(optarg, &end, 10);

	if (end == optarg || *end != '\0' || errno == ERANGE) {
		fprintf(stderr, "Invalid port number: %s\n", optarg);
		return 2;
	}

	// Range-check while the value is still a long, so out-of-range
	// input cannot be truncated into a seemingly valid port
	if (port <= 0 || port >= 65536) {
		fprintf(stderr, "Invalid port number: %ld\n", port);
		return 2;
	}

	conf->port = port;
	conf->listening_sockets = 1;

	return 0;
}

static int set_limits(struct fireperf_config* conf) {
	struct rlimit limit;

	// Increase limit of open files
	limit.rlim_cur = limit.rlim_max = conf->parallel + 128;

	int r = setrlimit(RLIMIT_NOFILE, &limit);
	if (r) {
		ERROR(conf, "Could not set open file limit to %lu: %s\n",
			(unsigned long)limit.rlim_max, strerror(errno));
		return 1;
	}

	return 0;
}

// Parses a complete non-negative decimal number from string into *value.
//
// Returns 0 on success. Returns 1 if the string is empty, has trailing
// characters, or does not fit into an unsigned long.
static int parse_unsigned_arg(const char* string, unsigned long* value) {
	char* end = NULL;

	errno = 0;
	unsigned long v = strtoul(string, &end, 10);

	if (end == string || *end != '\0' || errno == ERANGE)
		return 1;

	*value = v;
	return 0;
}

// Parses the command line into conf.
//
// Returns 0 on success. Returns 1 if getopt_long() rejected an option (it
// prints its own message). Returns a non-zero code (mostly 2) for invalid
// option arguments, after printing an error to stderr.
// Exits the process after printing version information for -V/--version.
static int parse_argv(int argc, char* argv[], struct fireperf_config* conf) {
	static struct option long_options[] = {
		{"client",     required_argument, 0, 'c'},
		{"close",      no_argument,       0, 'x'},
		{"debug",      no_argument,       0, 'd'},
		{"keepalive",  no_argument,       0, 'k'},
		{"parallel",   required_argument, 0, 'P'},
		{"port",       required_argument, 0, 'p'},
		{"server",     no_argument,       0, 's'},
		{"timeout",    required_argument, 0, 't'},
		{"version",    no_argument,       0, 'V'},
		{"zero",       no_argument,       0, 'z'},
		{0, 0, 0, 0},
	};

	int option_index = 0;
	int done = 0;
	unsigned long value = 0;

	// Declared here rather than inside a case label: the original
	// declaration in case 'c' was reused by case 'p', which jumped into its
	// scope past the initializer
	int r = 0;

	while (!done) {
		int c = getopt_long(argc, argv, "c:dkp:st:xzP:V", long_options, &option_index);

		// End
		if (c == -1)
			break;

		switch (c) {
			case 0:
				if (long_options[option_index].flag != 0)
					break;

				printf("option %s", long_options[option_index].name);

				if (optarg)
					printf("  with arg: %s", optarg);

				printf("\n");
				break;

			case '?':
				// getopt_long already printed the error message
				return 1;

			case 'V':
				printf("%s %s\n", PACKAGE_NAME, PACKAGE_VERSION);
				printf("Copyright (C) 2021 The IPFire Project (https://www.ipfire.org/)\n");
				printf("License GPLv3+: GNU GPL version 3 or later <https://gnu.org/licenses/gpl.html>\n");
				printf("This is free software: you are free to change and redistribute it\n");
				printf("There is NO WARRANTY, to the extent permitted by law.\n\n");
				printf("Written by Michael Tremer\n");

				exit(0);
				break;

			case 'c':
				conf->mode = FIREPERF_MODE_CLIENT;

				// Parse the given IP address
				r = parse_address(optarg, &conf->address);
				if (r) {
					fprintf(stderr, "Could not parse IP address %s\n", optarg);
					return 2;
				}
				break;

			case 'd':
				conf->loglevel = LOG_DEBUG;
				break;

			case 'k':
				conf->keepalive_only = 1;
				break;

			case 'P':
				// Reject garbage such as "-P abc" instead of silently using 0
				if (parse_unsigned_arg(optarg, &value)) {
					fprintf(stderr, "Invalid number of parallel connections: %s\n", optarg);
					return 2;
				}

				// Check before assigning so the field's type cannot truncate first
				if (value > MAX_PARALLEL) {
					fprintf(stderr, "Number of parallel connections is too high: %lu\n",
						value);
					return 2;
				}

				conf->parallel = value;
				break;

			case 'p':
				// Try parsing the port range first.
				// If this fails, we try parsing a single port
				r = parse_port_range(conf, optarg);
				if (r == 1)
					r = parse_port(conf, optarg);
				if (r)
					return r;

				break;

			case 's':
				conf->mode = FIREPERF_MODE_SERVER;
				break;

			case 't':
				// Reject garbage such as "-t 10s" instead of partially parsing it
				if (parse_unsigned_arg(optarg, &value)) {
					fprintf(stderr, "Invalid timeout: %s\n", optarg);
					return 2;
				}

				conf->timeout = value;
				break;

			case 'x':
				conf->close = 1;
				break;

			case 'z':
				conf->zero = 1;
				break;

			default:
				done = 1;
				break;
		}
	}

	return 0;
}

// Entry point.
//
// Parses the command line, raises the open file limit, and sets up an epoll
// instance with a 1-second timerfd for printing statistics. It then runs the
// selected client or server mode.
//
// Returns the mode's result, or non-zero on setup errors.
int main(int argc, char* argv[]) {
	struct fireperf_config conf = {
		.keepalive_count = DEFAULT_KEEPALIVE_COUNT,
		.keepalive_interval = DEFAULT_KEEPALIVE_INTERVAL,
		.listening_sockets = DEFAULT_LISTENING_SOCKETS,
		.loglevel = DEFAULT_LOG_LEVEL,
		.mode = FIREPERF_MODE_NONE,
		.port = DEFAULT_PORT,
		.parallel = DEFAULT_PARALLEL,
		.timeout = DEFAULT_TIMEOUT,
	};
	struct fireperf_stats stats = { 0 };

	// Initialised up front so the cleanup below never reads an
	// indeterminate descriptor when jumping there early
	int epollfd = -1;
	int timerfd = -1;
	int r;

	// Parse command line
	r = parse_argv(argc, argv, &conf);
	if (r)
		return r;

	// Initialise random number generator
	srandom(time(NULL));

	// Set limits
	r = set_limits(&conf);
	if (r)
		return r;

	// Initialize epoll()
	epollfd = epoll_create1(0);
	if (epollfd < 0) {
		ERROR(&conf, "Could not initialize epoll(): %s\n", strerror(errno));
		r = 1;
		goto ERROR;
	}

	// Create timerfd() to print statistics
	timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK|TFD_CLOEXEC);
	if (timerfd < 0) {
		ERROR(&conf, "timerfd_create() failed: %s\n", strerror(errno));
		r = 1;
		goto ERROR;
	}

	struct epoll_event ev = {
		.events  = EPOLLIN,
		.data.fd = timerfd,
	};

	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, timerfd, &ev)) {
		ERROR(&conf, "Could not add timerfd to epoll(): %s\n", strerror(errno));
		r = 1;
		goto ERROR;
	}

	// Let the timer ping us once a second
	struct itimerspec timer = {
		.it_interval.tv_sec = 1,
		.it_value.tv_sec = 1,
	};

	r = timerfd_settime(timerfd, 0, &timer, NULL);
	if (r) {
		ERROR(&conf, "Could not set timer: %s\n", strerror(errno));
		r = 1;
		goto ERROR;
	}

	switch (conf.mode) {
		case FIREPERF_MODE_CLIENT:
			r = fireperf_client(&conf, &stats, epollfd, timerfd);
			break;

		case FIREPERF_MODE_SERVER:
			r = fireperf_server(&conf, &stats, epollfd, timerfd);
			break;

		case FIREPERF_MODE_NONE:
		default:
			fprintf(stderr, "No mode selected\n");
			r = 2;
			break;
	}

ERROR:
	// 0 is a valid descriptor number, so compare with >= 0
	if (epollfd >= 0)
		close(epollfd);

	if (timerfd >= 0)
		close(timerfd);

	return r;
}


// Prints the statistics gathered since the last call and resets the
// per-interval counters (connections, bytes_received, bytes_sent).
//
// Calls made less than 0.1s after the previous print are ignored. mode
// selects whether sent (client) or received (server) bytes are reported.
//
// Returns 0 on success and 1 if the current time could not be fetched.
int fireperf_dump_stats(struct fireperf_config* conf, struct fireperf_stats* stats, int mode) {
	struct timespec now;

	// Fetch the time
	int r = clock_gettime(CLOCK_REALTIME, &now);
	if (r) {
		ERROR(conf, "Could not fetch the time: %s\n", strerror(errno));
		return 1;
	}

	double delta = timespec_delta(&now, &stats->last_printed);

	// Called too soon again?
	if (delta < 0.1)
		return 0;

	// Format timestamp
	const char* timestamp = format_timespec(&now);

	INFO( conf, "--- %s --------------------\n", timestamp);
	DEBUG(conf, "  %-20s: %19.4fs\n", "Delta", delta);
	INFO( conf, "  %-20s: %20u\n", "Open Connection(s)", stats->open_connections);
	INFO( conf, "  %-20s: %18.2f/s\n", "New Connections", stats->connections / delta);

	// Show current bandwidth
	char* bps = NULL;
	switch (mode) {
		case FIREPERF_MODE_CLIENT:
			bps = format_size(stats->bytes_sent * 8 / delta, FIREPERF_FORMAT_BITS);
			break;

		case FIREPERF_MODE_SERVER:
			bps = format_size(stats->bytes_received * 8 / delta, FIREPERF_FORMAT_BITS);
			break;

		default:
			break;
	}

	if (bps) {
		INFO( conf, "  %-20s: %18s/s\n", "Current Bandwidth", bps);
		free(bps);
	}

	// Total bytes sent/received
	const char* total_label = NULL;
	char* total_bytes = NULL;
	switch (mode) {
		case FIREPERF_MODE_CLIENT:
			total_label = "Total Bytes Sent";
			total_bytes = format_size(stats->total_bytes_sent, FIREPERF_FORMAT_BYTES);
			break;

		case FIREPERF_MODE_SERVER:
			total_label = "Total Bytes Received";
			total_bytes = format_size(stats->total_bytes_received, FIREPERF_FORMAT_BYTES);
			break;

		default:
			break;
	}

	// format_size() may return NULL (as handled for bps above); passing
	// NULL to %s is undefined, so only print on success
	if (total_bytes) {
		INFO(conf, "  %-20s: %20s\n", total_label, total_bytes);
		free(total_bytes);
	}

	// Remember when this was printed last
	stats->last_printed = now;

	// Reset statistics
	stats->connections = 0;
	stats->bytes_received = 0;
	stats->bytes_sent = 0;

	return 0;
}
