#!/usr/local/bin/python3.12

from __future__ import print_function
from datetime import datetime
from json import loads
import re
import os
import time
import sys
import socket
import signal
import posix
import errno
import optparse
import random
import subprocess

VERSION = "0.93-22" # Automatically filled in.

# Seed the PRNG used for probabilistic sampling (Options.sample).
random.seed()

# For Zeek-bundled installation, explicitly add the Python path we've
# installed ourselves under, so we find the SubnetTree module:
ZEEK_PYTHON_DIR = '/usr/local/lib/zeek/python'
if os.path.isdir(ZEEK_PYTHON_DIR):
    sys.path.append(os.path.abspath(ZEEK_PYTHON_DIR))
else:
    # No bundled path exists; remember that by clearing the constant.
    ZEEK_PYTHON_DIR = None

import SubnetTree

class IntervalUpdate:
    """Plain attribute container for one packet/connection record.

    Instances are created empty and populated ad hoc by readPcap() and
    parseConnLine() with fields such as pkts, bytes, payload, frags,
    src_ip, src_port, dst_ip, dst_port, prot, service, state, start, end.
    """
    pass

class IntervalList:
    """Ordered list of per-time-slot statistics.

    self.ints holds one Interval per slot (or None for slots without
    traffic); self.start is the epoch time of the first slot, or -1
    until the first update arrives (see findInterval()).
    """

    def __init__(self):
        self.ints = []
        self.start = -1

    def finish(self):
        """Rebase each interval's slot-relative times to absolute epoch
        times and apply the sampling factor to its counters."""
        for i in self.ints:
            if i:
                i.start += self.start
                i.end += self.start
                i.applySampleFactor()

    def writeR(self, file, top_ports):
        """Write all intervals to 'file' as an R data frame.

        Empty slots are emitted as zero-filled intervals spanning
        Options.ilen seconds so the frame has no gaps.
        """
        # Context manager so the output file is closed even if a write
        # fails part-way through (the previous code leaked the handle).
        with open(file, "w") as out:
            Interval.makeRHeader(out, top_ports)
            next_start = self.start

            for i in self.ints:
                if i:
                    i.formatForR(out, top_ports)
                    next_start = i.end
                else:
                    empty = Interval()
                    empty.start = next_start
                    empty.end = next_start + Options.ilen
                    empty.formatForR(out, top_ports)
                    next_start = empty.end

class Interval:
    """Aggregated traffic statistics over one span of time.

    Counters are filled in by update(); each per-key tally counts either
    bytes (with -b) or one unit per packet/connection.  start/end keep
    their sentinel values (1e20/0) until the first adjusting update.
    """

    def __init__(self):
        self.start = 1e20
        self.end = 0
        self.bytes = 0
        self.payload = 0
        self.pkts = 0
        self.frags = 0
        self.updates = 0
        self.ports = {}
        self.prots = {}
        self.servs = {}
        self.srcs = {}
        self.dsts = {}
        self.states = {}

    def update(self, iupdate, adjusttime=True):
        """Fold one IntervalUpdate into this interval's counters.

        With adjusttime, also widen [self.start, self.end] to cover the
        update's time span; fixed-boundary interval slots pass
        adjusttime=False.
        """

        self.updates += 1
        self.pkts += iupdate.pkts
        self.bytes += iupdate.bytes
        self.payload += iupdate.payload
        self.frags += iupdate.frags

        # Per-key tallies count bytes with -b, otherwise one unit.
        if Options.bytes:
            incr = iupdate.bytes
        else:
            incr = 1

        # For packets, we need to look at the source port, too.
        # A port is tallied if it is privileged (<1024), explicitly
        # requested via Ports, always stored (storeports), or when we
        # track everything (no Ports set and not saving memory).
        if not Options.conns:
            if ( iupdate.src_port < 1024 ) or \
                ( not Ports and not Options.save_mem ) or \
                ( Ports and iupdate.src_port in Ports ) or \
                Options.storeports:
                try:
                    self.ports[iupdate.src_port] += incr
                except KeyError:
                    self.ports[iupdate.src_port] = incr

        if ( iupdate.dst_port < 1024 ) or \
            ( not Ports and not Options.save_mem ) or \
            ( Ports and iupdate.dst_port in Ports ) or \
            Options.storeports:
            try:
                self.ports[iupdate.dst_port] += incr
            except KeyError:
                self.ports[iupdate.dst_port] = incr

        try:
            self.prots[iupdate.prot] += incr
        except KeyError:
            self.prots[iupdate.prot] = incr

        try:
            self.servs[iupdate.service] += incr
        except KeyError:
            self.servs[iupdate.service] = incr

        try:
            self.states[iupdate.state] += incr
        except KeyError:
            self.states[iupdate.state] = incr

        if adjusttime:
            if iupdate.start < self.start:
                self.start = iupdate.start

            if iupdate.end > self.end:
                self.end = iupdate.end

        # Per-address tallies are memory-hungry; skipped when saving
        # memory or producing R output.
        if not Options.save_mem and not Options.R:
            try:
                self.srcs[iupdate.src_ip] += incr
            except KeyError:
                self.srcs[iupdate.src_ip] = incr

            try:
                self.dsts[iupdate.dst_ip] += incr
            except KeyError:
                self.dsts[iupdate.dst_ip] = incr

    def applySampleFactor(self):
        """Scale every counter by the sampling factor Options.factor
        (no-op when the factor is 1)."""
        if Options.factor == 1:
            return

        self.bytes *= Options.factor
        self.payload *= Options.factor
        self.pkts *= Options.factor
        self.frags *= Options.factor
        self.updates *= Options.factor

        for i in self.ports.keys():
             self.ports[i] *= Options.factor
        for i in self.prots.keys():
             self.prots[i] *= Options.factor
        for i in self.servs.keys():
             self.servs[i] *= Options.factor
        for i in self.srcs.keys():
             self.srcs[i] *= Options.factor
        for i in self.dsts.keys():
             self.dsts[i] *= Options.factor
        for i in self.states.keys():
             self.states[i] *= Options.factor

    def format(self, conns=False, title=""):
        """Render this interval as human-readable text.

        'conns' selects connection-summary wording.  When Options.verbose
        is set, a top-N table of ports/sources/destinations/services/
        protocols (plus states, for connections) is appended.
        """
        def fmt(tag, count, total=-1, sep=" - "):
            # With a total, render count as a percentage of it;
            # otherwise render the absolute value via formatVal().
            if total >= 0:
                try:
                    return "%s %5.1f%%%s" % (tag, (float(count) / total) * 100, sep)
                except ZeroDivisionError:
                    return "%s (??%%)%s" % (tag, sep)

            return "%s %s%s" % (tag, formatVal(count), sep)

        s = "\n>== %s === %s - %s\n   - " % (title, isoTime(self.start), isoTime(self.end))

        if not conns:
            # Information for packet traces.
            s += fmt("Bytes", self.bytes) + \
                 fmt("Payload", self.payload) + \
                 fmt("Pkts", self.pkts) + \
                 fmt("Frags", self.frags, self.pkts)

            try:
                mbit = self.bytes * 8 / 1024.0 / 1024.0 / (self.end - self.start)
            except ZeroDivisionError:
                mbit = 0

            s += "MBit/s %8.1f - " % mbit

        else:
            # Information for connection summaries.
            s += fmt("Connections", self.pkts) + \
                 fmt("Payload", self.payload)

        if Options.factor != 1:
            s += "Sampling %.2f%% -" % ( 100.0 / Options.factor )

        if Options.verbose:
            ports = topx(self.ports)
            srcs = topx(self.srcs)
            dsts = topx(self.dsts)
            prots = topx(self.prots)
            servs = topx(self.servs)

            # Abbreviate common service names so the column stays narrow.
            servs = [ (count, svc.replace("icmp-", "i-").replace("netbios", "nb")) for count, svc in servs ]

            # Default column widths for IP addresses.
            srcwidth = 18
            dstwidth = 18

            # Check all IP addrs to see if column widths need to be increased
            # (due to the presence of long IPv6 addresses).
            src_over = 0
            dst_over = 0
            for i in range(Options.topx):
                for dict in (srcs, dsts):
                    try:
                        item = inet_ntox(dict[i][1])
                    except IndexError:
                        continue

                    # Note: 15 is longest possible IPv4 address.
                    oversize = len(item) - 15
                    if oversize > 0:
                        if dict is srcs:
                            src_over = max(src_over, oversize)
                        elif dict is dsts:
                            dst_over = max(dst_over, oversize)

            # Increase column widths, if necessary.
            srcwidth += src_over
            dstwidth += dst_over

            s += "\n     %-5s        | %-*s        | %-*s        | %-18s | %1s |" \
                % ("Ports", srcwidth, "Sources", dstwidth, "Destinations", "Services", "Protocols")

            if conns:
                s += " States        |"
                states = (topx(self.states), 6)
            else:
                states = ({}, 0)

            s += "\n"

            addrs = []

            for i in range(Options.topx):

                s += "     "

                # One (tally dict, column width) pair per output column.
                for (dict, length) in ((ports, 5), (srcs, srcwidth), (dsts, dstwidth), (servs, 11), (prots, 2), states):
                    try:
                        item = None
                        if dict is srcs or dict is dsts:
                            item = inet_ntox(dict[i][1])
                            if Options.resolve:
                                # Remember the address and tag the cell
                                # with a footnote number (#1, #2, ...).
                                addrs += [dict[i][1]]
                                item += "#%d" % len(addrs)
                        else:
                            item = str(dict[i][1])

                        s += fmt("%-*s" % (length, item), dict[i][0], (Options.bytes and self.bytes or self.pkts), sep=" | ")
                    except:
                        # Fewer than Options.topx entries in this column
                        # (IndexError); emit padding instead.
                        # NOTE(review): the bare except hides any other
                        # error here as well.
                        s += " " * length + "        | "

                s += "\n"

            if Options.resolve:
                # Footnotes: reverse-resolve the collected addresses.
                s += "\n        "
                for i in range(1, len(addrs)+1):
                    s +=  "#%d=%s  " % (i, gethostbyaddr(inet_ntox(addrs[i-1])))
                    if i % 3 == 0:
                        s += "\n        "

            s += "\n"


        return s

    def makeRHeader(f, top_ports):
        """Write the R data-frame column-name header line to 'f'.

        The column set must stay in sync with formatForR(): fixed
        counters, one column per state, the top-N port ranks, ports
        0-1023, and (unless save_mem) one column per high port from
        'top_ports'.
        """
        f.write("start end count bytes payload frags srcs dsts prot.tcp prot.udp prot.icmp ")
        f.write("%s " % " ".join(["state.%s" % s.lower() for s in States]))
        f.write("%s " % " ".join(["top.port.%d" % (i+1) for i in range(0,Options.topx)]))
        f.write(" ".join(["port.%d" % i for i in range(0,1024)]))
        if not Options.save_mem:
            f.write(" %s" % " ".join(["port.%d" % p[1] for p in top_ports if p[1] >= 1024]))
        f.write("\n")

    # Pre-decorator style staticmethod registration.
    makeRHeader = staticmethod(makeRHeader)

    def formatForR(self, f, top_ports):
        """Write this interval as one R data-frame row to 'f'; the column
        order must match makeRHeader()."""
        f.write("%.16g %.16g %s %s %s %s " % (self.start, self.end, self.pkts, self.bytes, self.payload, self.frags))
        f.write("%s %s " % (len(self.srcs), len(self.dsts)))
        f.write("%s %s %s " % (self.prots.get(6, 0), self.prots.get(17, 0), self.prots.get(1, 0)))
        f.write("%s " % " ".join([str(self.states.get(i, 0)) for i in States]))
        f.write("%s " % " ".join([str(p[1]) for p in topx(self.ports, True)]))
        f.write(" ".join([str(self.ports.get(i, 0)) for i in range(0,1024)]))
        if not Options.save_mem:
            f.write(" %s" % " ".join([str(self.ports.get(p[1], 0)) for p in top_ports if p[1] >= 1024]))
        f.write("\n")


    def __str__(self):
        """Human-readable rendering in connection-summary style."""
        return self.format(True)

def topx(dict, fill_if_empty=False):
    """Return up to Options.topx (count, value) pairs from 'dict',
    largest counts first.

    Entries whose count is zero are dropped.  With fill_if_empty, the
    result is padded with (0, 0) pairs so it always contains exactly
    Options.topx entries (used for the fixed-width R output).
    """
    top = sorted([ (count, val) for val, count in dict.items() ], reverse=True)

    # Filter out zero counts.  (The previous unpacking swapped the tuple
    # elements and filtered on the dict key instead of the count.)
    top = [(count, val) for (count, val) in top if count != 0]

    if fill_if_empty and len(top) < Options.topx:
        top += [(0,0)] * Options.topx

    return top[:Options.topx]

def findInterval(time, intervals):
    """Return the Interval covering epoch 'time' within 'intervals',
    creating slots of Options.ilen seconds as needed.

    Slots store times relative to intervals.start until
    IntervalList.finish() rebases them to absolute epoch times.
    """

    if intervals.start < 0:
        # First update: align the list's base time to a slot boundary.
        intervals.start = int(time / Options.ilen) * Options.ilen

    i = (time - intervals.start) / Options.ilen
    idx = int(i)

    # Interval may be earlier than current start
    if i < 0:

        if float(idx) != i:
            # minus 1 since we will multiply by -1 (int() truncates
            # towards zero, leaving idx one slot short for negative i)
            idx -= 1

        idx *= -1

        # Shift existing intervals' relative times up by the number of
        # slots we are about to prepend.
        for j in intervals.ints:
            if j:
                j.start += Options.ilen * idx
                j.end += Options.ilen * idx

        intervals.ints = ([None] * idx) + intervals.ints
        intervals.start = int(time / Options.ilen) * Options.ilen
        idx = 0

    # Interval may be later than current end
    while idx >= len(intervals.ints):
        intervals.ints += [None]

    if not intervals.ints[idx]:
        interv = Interval()
        interv.start = float(idx * Options.ilen)
        interv.end =  float((idx+1) * Options.ilen)
        intervals.ints[idx] = interv
        return interv

    return intervals.ints[idx]

def isoTime(t):
    """Render epoch time 't' as YYYY-MM-DD-HH-MM-SS in local time.

    The sentinel values 0 and 1e20 (unset end/start of an Interval)
    render as "N/A".
    """
    if t in (0, 1e20):
        return "N/A"
    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(t))

def iso2Epoch(ts):
    """Convert an ISO8601 UTC timestamp of the form
    '%Y-%m-%dT%H:%M:%S.%fZ' into seconds since the Unix epoch."""
    parsed = datetime.strptime(ts, '%Y-%m-%dT%H:%M:%S.%fZ')
    return (parsed - datetime(1970, 1, 1)).total_seconds()

def readPcap(file):
    """Read a pcap trace via the external 'ipsumdump' tool and fold
    every packet into the global statistics (Total, the per-interval
    lists and, when local networks are configured, Incoming/Outgoing).
    """
    global Total
    global Incoming
    global Outgoing

    # Output columns: f[0] timestamp, f[1] src IP, f[2] dst IP,
    # f[3] src port, f[4] dst port, f[5] length, f[6] protocol,
    # f[7] fragment, f[8] payload length, f[9] TCP seq number.
    # --tcp-seq is required: the sanity check below expects 10 columns
    # and --chema mode reads f[9].
    # NOTE(review): 'file' is interpolated into a shell command
    # (shell=True); filenames containing shell metacharacters are unsafe.
    proc = subprocess.Popen("ipsumdump -r %s --timestamp --src --dst --sport --dport --length --protocol --fragment --payload-length --tcp-seq -Q" % file, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    for line in proc.stdout:
        line = line.decode()

        if line.startswith("!"):
            continue

        # Probabilistic sampling.
        if Options.sample > 0 and random.random() > Options.sample:
            continue

        f = line.split()

        if len(f) < 10:
            print("Ignoring corrupt line: '%s'" % line.strip(), file=sys.stderr)
            continue

        try:
            time = float(f[0])
        except ValueError:
            print("Ignoring corrupt line: '%s'" % line.strip(), file=sys.stderr)
            continue

        if time < Options.mintime or time > Options.maxtime:
            continue

        if Options.chema:
            # Chema mode: TCP only, and skip packets with seq == 0.
            if f[6] != "T" or f[9] == "0":
                continue

        if Options.tcp and f[6] != "T":
            continue

        if Options.udp and f[6] != "U":
            continue

        iupdate = IntervalUpdate()
        iupdate.pkts = 1
        iupdate.bytes = int(f[5])
        iupdate.payload = int(f[8])
        iupdate.service = ""

        try:
            iupdate.src_ip = inet_xton(f[1])
        except socket.error:
            # Unparsable address: fall back to the all-zeros address.
            iupdate.src_ip = unspecified_addr(f[1])

        if iupdate.src_ip in ExcludeNets:
            continue

        try:
            iupdate.src_port = int(f[3])
        except:
            iupdate.src_port = 0

        try:
            iupdate.dst_ip = inet_xton(f[2])
        except socket.error:
            iupdate.dst_ip = unspecified_addr(f[2])

        if iupdate.dst_ip in ExcludeNets:
            continue

        try:
            iupdate.dst_port = int(f[4])
        except:
            iupdate.dst_port = 0

        try:
            iupdate.prot = Protocols[f[6]]
        except KeyError:
            iupdate.prot = 0
        iupdate.state = 0
        iupdate.start = time
        iupdate.end = time

        # ipsumdump prints "." in the fragment column for unfragmented
        # packets.
        if f[7] != ".":
            iupdate.frags = 1
        else:
            iupdate.frags = 0

        if Options.external:
            # Skip strictly internal traffic.
            if iupdate.src_ip in LocalNetsIntervals and iupdate.dst_ip in LocalNetsIntervals:
                continue

        Total.update(iupdate)

        if Options.ilen > 0:
            interval = findInterval(time, TotalIntervals)
            interval.update(iupdate, adjusttime=False)

        if Options.localnets:
            # Classify by which endpoint lies in a local network:
            # source local -> outgoing, destination local -> incoming,
            # neither -> count as non-local.
            try:
                LocalNetsIntervals[iupdate.src_ip].update(iupdate)
                Outgoing.update(iupdate)
                if Options.ilen > 0:
                    interval = findInterval(time, OutgoingIntervals)
                    interval.update(iupdate, adjusttime=False)
            except KeyError:
                try:
                    LocalNetsIntervals[iupdate.dst_ip].update(iupdate)
                    Incoming.update(iupdate)
                    if Options.ilen > 0:
                        interval = findInterval(time, IncomingIntervals)
                        interval.update(iupdate, adjusttime=False)
                except KeyError:
                    global NonLocalCount
                    NonLocalCount += 1
                    if NonLocalCount < Options.topx:
                        NonLocalConns[(iupdate.src_ip, iupdate.dst_ip)] = 1

    status = proc.wait()
    if status != 0:
        print("ipsumdump returned exit status of %d" % status, file=sys.stderr)


# ipsumdump protocol letters and conn.log protocol names, mapped to the
# corresponding IP protocol numbers.
Protocols = { "T": 6, "tcp": 6, "U": 17, "udp": 17, "I": 1, "icmp": 1 }
# Zeek connection states, in the fixed column order used for R output.
States = ["OTH", "REJ", "RSTO", "RSTOS0", "RSTR", "RSTRH", "S0", "S1", "S2", "S3", "SF", "SH", "SHR"]

def readConnSummaries(file):
    """Feed every non-metadata line of the conn.log 'file' through
    parseConnLine().

    On EINTR/EAGAIN the file is re-read from the beginning; any other
    IOError is reported to stderr and processing of the file stops.
    """
    # Determine the field separator, unset field string, and field indices
    # for the specified conn.log file.
    (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator) = getLogInfo(file)

    while True:
        try:
            # Context manager guarantees the handle is closed even when
            # parsing raises (the previous code leaked it).
            with open(file) as conn_log:
                for line in conn_log:
                    # Skip log metadata lines.
                    if line[0] != "#":
                        parseConnLine(line, field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

        except IOError as e:
            if e.errno == errno.EINTR or e.errno == errno.EAGAIN:
                continue

            print(e, file=sys.stderr)

        return

def getLogInfo(file):
    """Inspect a conn.log file and return its parsing parameters.

    Returns (field_sep, unset_field, idx, max_idx_1, is_json,
    scope_separator): the column separator, the string used for unset
    fields, a mapping from the field names this script needs to column
    indices, the highest needed index plus one, whether the log is JSON,
    and the character separating name scopes (e.g. '.' in "id.orig_h").

    Side effect: when Options.conn_version is 0 ("guess"), it is set to
    the detected version (1 or 2).
    """
    is_json = False
    scope_separator = "."

    if Options.conn_version == 0:
        with open(file, "r") as fin:
            line = fin.readline()

            # Guess the conn.log version by checking for a metadata line.
            if line[0] == "#":
                Options.conn_version = 2
            elif line[0] == "{":
                Options.conn_version = 2
                is_json = True
                f = loads(line)
                if "id.orig_h" not in f:
                    # Detect a non-default scope separator, e.g.
                    # "id_orig_h" instead of "id.orig_h".
                    pattern = re.compile("id(.)orig_h$")
                    for field in f:
                        m = pattern.match(field)
                        if m:
                            scope_separator = m.group(1)
            else:
                # Guess the conn.log version by looking at the number of
                # fields we have.
                m = line.split()

                if len(m) < 15:
                    Options.conn_version = 1
                else:
                    Options.conn_version = 2

    if Options.conn_version == 1:
        # Field names needed by this script, listed here in same order as
        # found in bro version 1.x conn.log.
        field_names = ("ts", "duration", "id.orig_h", "id.resp_h", "service", "id.orig_p", "id.resp_p", "proto", "orig_bytes", "resp_bytes", "conn_state")

        idx = {}
        for field in field_names:
            idx[field] = len(idx)

        # max_idx_1 is max. index value plus 1
        max_idx_1 = len(field_names)

        field_sep = " "
        unset_field = "?"

        return (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

    # Field names needed by this script, listed here in same order as
    # found in conn.log.
    field_names = ("ts", "uid", "id.orig_h", "id.orig_p", "id.resp_h", "id.resp_p", "proto", "service", "duration", "orig_bytes", "resp_bytes", "conn_state")

    field_sep = "\t"
    unset_field = "-"
    idx = {}

    # Scan the leading metadata lines ("#separator", "#unset_field",
    # "#fields") to override the defaults above.
    with open(file, "r") as fin:
        firstline = True

        for line in fin:
            if firstline:
                firstline = False
                if line[0] == "{":
                    is_json = True
                    f = loads(line)
                    if "id.orig_h" not in f:
                        # Non-default scope separator (see above).
                        pattern = re.compile("id(.)orig_h$")
                        for field in f:
                            m = pattern.match(field)
                            if m:
                                scope_separator = m.group(1)
                    break

            if line[0] != "#":
                break

            # Remove trailing '\n' so that it's not included in last item of
            # results from split().
            if line[-1] == "\n":
                line = line[:-1]

            if line.startswith("#separator"):
                try:
                    field_sep = line.split()[1]
                    # The separator is written hex-escaped, e.g. "\x09".
                    if field_sep.startswith("\\x"):
                        field_sep = chr(int(field_sep[2:], 16))
                except (IndexError, ValueError):
                    # If no value found, then just use default.
                    print("Ignoring bad '#separator' line", file=sys.stderr)

            elif line.startswith("#unset_field"):
                try:
                    unset_field = line.split(field_sep)[1]
                except IndexError:
                    # If no value found, then just use default.
                    print("Ignoring bad '#unset_field' line", file=sys.stderr)

            elif line.startswith("#fields"):
                fields = line.split(field_sep)[1:]

                if "id.orig_h" not in fields:
                    # Either the "#fields" line is corrupt, or we're using a
                    # non-default field scope separator.
                    pattern = re.compile("id(.)orig_h$")
                    for field in fields:
                        m = pattern.match(field)
                        if m:
                            scope_separator = m.group(1)

                max_idx_1 = 0
                idx = {}
                for field in field_names:
                    try:
                        # Use original field name in "idx" (even if there is a
                        # non-default field scope separator).
                        idx[field] = fields.index(field.replace(".", scope_separator))
                    except ValueError as err:
                        # If any field is missing, then just use defaults.
                        idx = {}
                        print("Ignoring bad '#fields' line: %s" % err, file=sys.stderr)
                        break

                    max_idx_1 = max(max_idx_1, idx[field])

                max_idx_1 += 1

    # If no fields metadata was found, then just use default values.
    if not idx:
        # max_idx_1 is max. index value plus 1
        max_idx_1 = len(field_names)

        for field in field_names:
            idx[field] = len(idx)

    return (field_sep, unset_field, idx, max_idx_1, is_json, scope_separator)

def parseConnLine(line, field_sep, unset_field, idx, max_idx_1, is_json, scope_separator):
    """Parse one conn.log entry (TSV or JSON) and fold it into the
    global totals and interval lists.

    The parsing parameters come from getLogInfo() for the current file.
    Corrupt or filtered-out lines are skipped, with a message on stderr
    where appropriate.
    """
    global Total, Incoming, Outgoing, LastOutputTime, BaseTime

    # Probabilistic sampling.
    if Options.sample > 0 and random.random() > Options.sample:
        return

    # Remove trailing '\n' so that it's not included in last item of
    # results from split().
    if line[-1] == "\n":
        line = line[:-1]

    if is_json:
        f = loads(line)
    else:
        f = line.split(field_sep, max_idx_1)
        if len(f) < max_idx_1:
            print("Ignoring corrupt line: %s" % line, file=sys.stderr)
            return

    if is_json:
        proto_val = f["proto"]
    else:
        proto_val = f[idx["proto"]]

    if Options.tcp and proto_val != "tcp":
        return

    if Options.udp and proto_val != "udp":
        return

    try:
        if is_json:
            # Check for ISO8601
            if "Z" not in str(f["ts"]):
                time = float(f["ts"])
            else:
                time = iso2Epoch(f["ts"])
        else:
            time = float(f[idx["ts"]])
    except ValueError:
        print("Invalid starting time on line: %s" % line, file=sys.stderr)
        return

    if is_json:
        try:
            duration_str = f["duration"]
        except KeyError:
            # JSON logs omit unset fields entirely.
            duration_str = unset_field
    else:
        duration_str = f[idx["duration"]]

    if duration_str != unset_field:
        try:
            duration = float(duration_str)
        except ValueError:
            # The default unset/empty field string can be changed from "-"
            # and in that case, it's hard to know if this exception is due
            # to that or because we're looking at the wrong column entirely,
            # so just print an error and continue with the assumption of
            # an unset/empty duration column.
            print("Invalid duration on line: %s" % line, file=sys.stderr)
            duration = 0
    else:
        duration = 0

    if time < Options.mintime or (time + duration) > Options.maxtime:
        return

    if not BaseTime:
        BaseTime = time
        LastOutputTime = time

    if time - LastOutputTime > 3600:
        # print("%d hours processed" % int((time - BaseTime) / 3600), file=sys.stderr)
        LastOutputTime = time

    if is_json:
        try:
            orig_bytes_str = f["orig_bytes"]
        except KeyError:
            orig_bytes_str = unset_field

        try:
            resp_bytes_str = f["resp_bytes"]
        except KeyError:
            resp_bytes_str = unset_field
    else:
        orig_bytes_str = f[idx["orig_bytes"]]
        resp_bytes_str = f[idx["resp_bytes"]]

    try:
        bytes_orig = int(orig_bytes_str)
    except ValueError:
        bytes_orig = 0

    try:
        bytes_resp = int(resp_bytes_str)
    except ValueError:
        bytes_resp = 0

    iupdate = IntervalUpdate()
    iupdate.pkts = 1 # no. connections
    iupdate.bytes = bytes_orig + bytes_resp

    try:
        if is_json:
            iupdate.src_ip = inet_xton(f["id" + scope_separator + "orig_h"])
            iupdate.src_port = int(f["id" + scope_separator + "orig_p"])
            iupdate.dst_ip = inet_xton(f["id" + scope_separator + "resp_h"])
            iupdate.dst_port = int(f["id" + scope_separator + "resp_p"])
        else:
            iupdate.src_ip = inet_xton(f[idx["id.orig_h"]])
            iupdate.src_port = int(f[idx["id.orig_p"]])
            iupdate.dst_ip = inet_xton(f[idx["id.resp_h"]])
            iupdate.dst_port = int(f[idx["id.resp_p"]])

        if iupdate.src_ip in ExcludeNets:
            return
        if iupdate.dst_ip in ExcludeNets:
            return
        iupdate.prot = Protocols[proto_val]

    except (KeyError, ValueError):
        # NOTE(review): a malformed IP address raises OSError from
        # inet_xton(), which is not caught here (unlike in readPcap()).
        print("Ignoring corrupt line: %s" % line, file=sys.stderr)
        return

    try:
        if is_json:
            iupdate.service = f["service"]
        else:
            iupdate.service = f[idx["service"]]

        # Strip a trailing "?" marker from the service name.
        if iupdate.service[-1] == "?":
            iupdate.service = iupdate.service[:-1]
    except (KeyError, IndexError):
        iupdate.service = unset_field

    iupdate.frags = 0
    if is_json:
        iupdate.state = f["conn_state"]
    else:
        iupdate.state = f[idx["conn_state"]]
    iupdate.start = time
    iupdate.end = time + duration

    payload_orig = bytes_orig
    payload_resp = bytes_resp

    if duration:
        # Payload rates above 700 MBit/s in either direction are treated
        # as bogus and zeroed (see the "Bro bug" notes below).
        bytes_to_mbps = 8 / (1024 * 1024 * duration)

        if payload_orig * bytes_to_mbps > 700:
            # Bandwidth exceed due to Bro bug.
            if Options.conn_version == 2:
                if is_json:
                    uid = f["uid"]
                else:
                    uid = f[idx["uid"]]
                print("UID %s originator exceeds bandwidth" % uid, file=sys.stderr)
            else:
                print("%.6f originator exceeds bandwidth" % time, file=sys.stderr)
            payload_orig = 0

        if payload_resp * bytes_to_mbps > 700:
            # Bandwidth exceed due to Bro bug.  (This branch checks the
            # responder side; the message used to say "originator".)
            if Options.conn_version == 2:
                if is_json:
                    uid = f["uid"]
                else:
                    uid = f[idx["uid"]]
                print("UID %s responder exceeds bandwidth" % uid, file=sys.stderr)
            else:
                print("%.6f responder exceeds bandwidth" % time, file=sys.stderr)
            payload_resp = 0

    iupdate.payload = payload_orig + payload_resp

    if Options.external:
        # Skip strictly internal traffic.
        if iupdate.src_ip in LocalNetsIntervals and iupdate.dst_ip in LocalNetsIntervals:
            return

    Total.update(iupdate)

    if Options.ilen > 0:
        interval = findInterval(time, TotalIntervals)
        interval.update(iupdate, adjusttime=False)

    if Options.localnets:
        # Classify by which endpoint lies in a local network: source
        # local -> outgoing, destination local -> incoming, neither ->
        # count as non-local.
        try:
            LocalNetsIntervals[iupdate.src_ip].update(iupdate)
            Outgoing.update(iupdate)
            if Options.ilen > 0:
                interval = findInterval(time, OutgoingIntervals)
                interval.update(iupdate, adjusttime=False)
        except KeyError:
            try:
                LocalNetsIntervals[iupdate.dst_ip].update(iupdate)
                Incoming.update(iupdate)
                if Options.ilen > 0:
                    interval = findInterval(time, IncomingIntervals)
                    interval.update(iupdate, adjusttime=False)
            except KeyError:
                global NonLocalCount
                NonLocalCount += 1
                if NonLocalCount < Options.topx:
                    NonLocalConns[(iupdate.src_ip, iupdate.dst_ip)] = 1

# Reverse-DNS results keyed by IP string (see gethostbyaddr()).
Cache = {}

def gethostbyaddr( ip, timeout = 5, default = "<???>" ):
    """Reverse-resolve 'ip' with a hard timeout, caching the result.

    The DNS lookup runs in a forked child so that a hung query can be
    killed after 'timeout' seconds; 'default' is returned when the
    lookup fails or times out.
    """

    try:
        return Cache[ip]
    except LookupError:
        pass

    host = default
    ( pin, pout ) = os.pipe()

    pid = os.fork()

    if not pid:
        # Child: resolve and send the hostname back through the pipe.
        os.close( pin )
        try:
            host = socket.gethostbyaddr( ip )[0]
        except socket.herror:
            pass

        host = host.encode()

        os.write( pout, host )
        # Exit status 127 tells the parent a result was written.
        posix._exit(127)

    #Parent
    os.close( pout )

    # Kill the child if the lookup takes longer than 'timeout' seconds.
    signal.signal( signal.SIGALRM, lambda sig, frame: os.kill( pid, signal.SIGKILL ) )
    signal.alarm( timeout )

    try:
        childpid, status = os.waitpid(pid, 0)

        # Only trust the pipe contents if the child signalled success.
        if os.WIFEXITED(status) and os.WEXITSTATUS(status) == 127:
            host = os.read(pin, 8192)
            host = host.decode()
    except OSError:
        # If the child process is killed while waitpid() is waiting, then
        # only Python 2 (not Python 3) raises OSError.
        pass

    signal.alarm( 0 )

    os.close( pin )

    Cache[ip] = host

    return host

def formatVal(val):
    """Format a numeric value with a g/m/k suffix, e.g. 1500 -> "1.5k".

    Values below 1 (such as zero counts) fall through the loop; they are
    returned via str() so that callers always get a string back (the
    previous code returned the raw number there).
    """
    for (prefix, unit, factor) in (("", "g", 1e9), ("", "m", 1e6), ("", "k", 1e3), (" ", "", 1e0)):
        if val >= factor:
            return "%s%3.1f%s" % (prefix, val / factor, unit)
    return str(val) # Only reached for val < 1.

def readNetworks(file):
    """Parse a networks file into a list of (CIDR, description) tuples.

    Blank lines and lines starting with '#' are skipped.  The first
    whitespace-separated field is the network; the remaining fields are
    joined into its description (which may be empty).
    """

    nets = []

    # Context manager so the file is closed even if parsing raises
    # (the previous code leaked the handle).
    with open(file) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            fields = line.split()
            nets += [(fields[0], " ".join(fields[1:]))]

    return nets

def inet_xton(ipstr):
    """Convert an IPv4/IPv6 address string to its packed binary form.

    A colon in the string marks the address as IPv6.
    """
    family = socket.AF_INET6 if ':' in ipstr else socket.AF_INET
    return socket.inet_pton(family, ipstr)

def inet_ntox(ipaddr):
    """Convert a packed binary address back to its string form.

    A 4-byte value is rendered as IPv4; any other length as IPv6.
    """
    family = socket.AF_INET if len(ipaddr) == 4 else socket.AF_INET6
    return socket.inet_ntop(family, ipaddr)

def unspecified_addr(ipstr):
    """Return the packed all-zeros ("unspecified") address in the same
    address family as ipstr: "::" for IPv6, "0.0.0.0" for IPv4."""
    if ':' in ipstr:
        return socket.inet_pton(socket.AF_INET6, "::")
    return socket.inet_pton(socket.AF_INET, "0.0.0.0")


####### Main

# Running totals over the whole input, plus the incoming/outgoing split
# (the split is only meaningful when -l/--local-nets is given).
Total = Interval()
Incoming = Interval()
Outgoing = Interval()

# Per-interval summaries, one Interval per -i time slot.
TotalIntervals = IntervalList()
IncomingIntervals = IntervalList()
OutgoingIntervals = IntervalList()

# NOTE(review): maintained by the reader functions (not visible in this
# chunk) — presumably the first timestamp seen and the last interval
# rollover time; confirm against readPcap/readConnSummaries.
BaseTime = None
LastOutputTime = None

# cidr -> (description, Interval) for each configured local network; the
# SubnetTree maps addresses to those same Interval objects for fast
# longest-prefix lookup.
LocalNets = {}
LocalNetsIntervals = SubnetTree.SubnetTree(True)
# Connection endpoints where neither side was local: first few pairs, plus
# a total count.
NonLocalConns = {}
NonLocalCount = 0

# Set of ports to restrict analysis to (-p), or None for all ports.
Ports = None

# CIDRs excluded from analysis entirely (-E).
ExcludeNets = SubnetTree.SubnetTree(True)

# Command-line interface definition. Note that several options (ilen,
# mintime, maxtime) are parsed as strings here and normalized to numeric
# values further below.
optparser = optparse.OptionParser(usage="%prog [options] <pcap-file>|<conn-summaries>", version=VERSION)
optparser.add_option("-b", "--bytes", action="store_true", dest="bytes", default=False,
                     help="count fractions in terms of bytes rather than packets/connections")
optparser.add_option("-c", "--conn-summaries", action="store_true", dest="conns", default=False,
                     help="input file contains Zeek connection summaries")
optparser.add_option("--conn-version", action="store", type="int", dest="conn_version", default=0,
                     help="when used with -c, specify '1' for use with Bro version 1.x connection logs, or '2' for use with Bro 2.x format. '0' tries to guess the format")
optparser.add_option("-C", "--chema", action="store_true", dest="chema", default=False,
                     help="for packets: include only TCP, ignore when seq==0")
optparser.add_option("-e", "--external", action="store_true", dest="external", default=False,
                     help="ignore strictly internal traffic")
optparser.add_option("-E", "--exclude-nets", action="store", type="string", dest="excludenets", default=None,
                     help="excludes CIDRs in file from analysis")
optparser.add_option("-i", "--intervals", action="store", type="string", dest="ilen", default="0",
                     help="create summaries for time intervals of given length (seconds, or use suffix of 'h' for hours, or 'm' for minutes)")
optparser.add_option("-l", "--local-nets", action="store", type="string", dest="localnets", default=None,
                     help="differentiate in/out based on CIDRs in file")
optparser.add_option("-n", "--topn", action="store", type="int", dest="topx", default=10,
                     help="show top <n>")
optparser.add_option("-p", "--ports", action="store", type="string", dest="ports", default=None,
                     help="include only ports listed in file")
optparser.add_option("-P", "--write-ports", action="store", type="string", dest="storeports", default=None,
                     help="write top total/incoming/outgoing ports into file")
optparser.add_option("-r", "--resolve-host-names", action="store_true", dest="resolve", default=False,
                     help="resolve host names")
optparser.add_option("-R", "--R", action="store", type="string", dest="R", default=None, metavar="tag",
                     help="write output suitable for R into files <tag.*>")
optparser.add_option("-s", "--sample-factor", action="store", type="int", dest="factor", default=1,
                     help="sample factor of input")
optparser.add_option("-S", "--do-sample", action="store", type="float", dest="sample", default=-1.0,
                     help="sample input with probability (0.0 < prob < 1.0)")
optparser.add_option("-m", "--save-mem", action="store_true", dest="save_mem", default=False,
                     help="do not make memory-expensive statistics")
optparser.add_option("-t", "--tcp", action="store_true", dest="tcp", default=False,
                     help="include only TCP")
optparser.add_option("-u", "--udp", action="store_true", dest="udp", default=False,
                     help="include only UDP")
optparser.add_option("-U", "--min-time", action="store", type="string", dest="mintime", default=None,
                     help="minimum time in ISO format (e.g. 2005-12-31-23-59-00)")
optparser.add_option("-v", "--verbose", action="store_true", dest="verbose", default=False,
                     help="show top-n for every interval")
optparser.add_option("-V", "--max-time", action="store", type="string", dest="maxtime", default=None,
                     help="maximum time in ISO format")

(Options, args) = optparser.parse_args()

# Positional argument is the input file (pcap trace or connection
# summaries); "-" (the default) means stdin.
if len(args) > 2:
    optparser.error("Wrong number of arguments")

file = "-"

if len(args) > 0:
    file = args[0]

# -e is only meaningful relative to a set of local networks.
if Options.external and not Options.localnets:
    print("Need -l for -e.", file=sys.stderr)
    sys.exit(1)

if Options.topx < 0:
    print("Top-n value cannot be negative", file=sys.stderr)
    sys.exit(1)

# If reading pcap traces, then ipsumdump is required.
if not Options.conns:
    proc = subprocess.Popen("ipsumdump -v", shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if proc.wait() != 0:
        print("Can't read pcap trace: 'ipsumdump' is required.", file=sys.stderr)
        sys.exit(1)

# Make per-interval summaries.
# NOTE: at this point Options.ilen is still the raw CLI string, so even
# the default "0" is truthy; this normalization always runs and leaves
# Options.ilen as an int number of seconds.
if Options.ilen:
    if Options.ilen.endswith("m"):
        Options.ilen = int(Options.ilen[:-1]) * 60
    elif Options.ilen.endswith("h"):
        Options.ilen = int(Options.ilen[:-1]) * 60 * 60
    else:
        Options.ilen = int(Options.ilen)

    if Options.ilen < 0:
        print("Interval length cannot be negative", file=sys.stderr)
        sys.exit(1)


# Read local networks.
if Options.localnets:

    # Each local net gets its own Interval for per-network statistics; the
    # SubnetTree indexes the same Interval object by prefix.
    for (net, txt) in readNetworks(Options.localnets):
        try:
            i = Interval()
            LocalNetsIntervals[net] = i
            LocalNets[net] = (txt, i)
        except KeyError:
            print("Can't parse local network '%s'" % net, file=sys.stderr)

# Read networks to exclude.
if Options.excludenets:
    for (net, txt) in readNetworks(Options.excludenets):
        try:
            ExcludeNets[net] = txt
        except KeyError:
            print("Can't parse exclude network '%s'" % net, file=sys.stderr)

# Read ports file; Ports acts as a set of allowed port numbers.
if Options.ports:
    Ports = {}
    for line in open(Options.ports):
        Ports[int(line.strip())] = 1

# Parse time-range if given. Both bounds are rewritten in place from ISO
# strings to epoch seconds (0 / 1e20 when unbounded).
if Options.mintime:
    Options.mintime = time.mktime(time.strptime(Options.mintime, "%Y-%m-%d-%H-%M-%S"))
else:
    Options.mintime = 0

if Options.maxtime:
    Options.maxtime = time.mktime(time.strptime(Options.maxtime, "%Y-%m-%d-%H-%M-%S"))
else:
    Options.maxtime = 1e20

if Options.factor <= 0:
    print("Sample factor must be > 0", file=sys.stderr)
    sys.exit(1)

# -S overrides -s: sampling with probability p scales counts up by 1/p.
if Options.sample > 0:
    if Options.sample > 1.0:
        print("Sample probability cannot be > 1", file=sys.stderr)
        sys.exit(1)
    Options.factor = 1.0 / Options.sample

if file == "-":
    file = "/dev/stdin"

# Read the input. Ctrl-C merely stops reading; the statistics gathered so
# far are still reported below.
try:
    if Options.conns:
        readConnSummaries(file)
    else:
        readPcap(file)
except KeyboardInterrupt:
    pass

# Finalize interval summaries (rebases their relative timestamps and
# applies the sampling factor) and scale the overall totals likewise.
TotalIntervals.finish()
IncomingIntervals.finish()
OutgoingIntervals.finish()

Total.applySampleFactor()
Incoming.applySampleFactor()
Outgoing.applySampleFactor()

# Union of the top-n ports across total/incoming/outgoing, keyed by port
# so duplicates collapse (later entries win); sorted by port number.
unique = {}
for (count, port) in topx(Total.ports) + topx(Incoming.ports) + topx(Outgoing.ports):
    unique[port] = (count, port)

top_ports = sorted(unique.values(), key=lambda x: x[1])

# -P: dump the selected top ports, one port number per line.
if Options.storeports:
    f = open(Options.storeports, "w")
    for p in top_ports:
        f.write("%s\n" % p[1])
    f.close()

# -R: write machine-readable output for R into <tag>.* files and exit
# without producing the plain-text report.
if Options.R:
    file = open(Options.R + ".dat", "w")

    file.write("tag ")
    Interval.makeRHeader(file, top_ports)
    file.write("total ")
    Total.formatForR(file, top_ports)

    file.write("incoming ")
    Incoming.formatForR(file, top_ports)
    file.write("outgoing ")
    Outgoing.formatForR(file, top_ports)

    for (net, data) in LocalNets.items():

        (txt, i) = data

        # Only report local nets that actually saw traffic; rebase their
        # relative timestamps and apply the sampling factor first.
        if i.updates:
            file.write("%s " % net.replace(" ", "_"))
            i.start += TotalIntervals.start
            i.end += TotalIntervals.start
            i.applySampleFactor()
            i.formatForR(file, top_ports)

    file.close()

    TotalIntervals.writeR(Options.R + ".total.dat", top_ports)
    IncomingIntervals.writeR(Options.R + ".incoming.dat", top_ports)
    OutgoingIntervals.writeR(Options.R + ".outgoing.dat", top_ports)

    sys.exit(0)

# Plain-text report: per-interval summaries first (only populated when -i
# was given) ...
for i in TotalIntervals.ints:
    if i:
        print(i.format(conns=Options.conns))

# ... then force verbose mode so the overall summaries below always show
# their full top-n detail.
Options.verbose = True

print(Total.format(conns=Options.conns, title="Total"))

# NOTE(review): `locals` shadows the builtin of the same name for the rest
# of the script.
locals = list(LocalNets.keys())

# Apply the sampling factor to each per-network counter before reporting.
for net in locals:
    (txt, i) = LocalNets[net]
    if i.updates:
        i.applySampleFactor()

if locals:

    type = "packets"
    if Options.conns:
        type = "connections"

    # Rank local networks by traffic volume (packet/connection count).
    locals.sort(key=lambda x: LocalNets[x][1].pkts, reverse=True)

    print("\n>== Top %d local networks by number of %s\n" % (Options.topx, type))

    for i in range(min(len(locals), Options.topx)):
        print("    %2d %5s  %-16s %s " % (i+1, formatVal(LocalNets[locals[i]][1].pkts), locals[i], LocalNets[locals[i]][0]))
    print()

    # Connections where neither endpoint was inside a local network.
    if len(NonLocalConns):
        print("\n>== %d %s did not have any local address. Here are the first %d:\n" % (NonLocalCount, type, Options.topx))

        for (src,dst) in sorted(NonLocalConns.keys()):
            print("    %s <-> %s" % (inet_ntox(src), inet_ntox(dst)))

# Incoming/outgoing split is only meaningful with -l.
if Options.localnets:
    print(Incoming.format(conns=Options.conns, title="Incoming"))
    print(Outgoing.format(conns=Options.conns, title="Outgoing"))

# Per-network detail sections for networks that saw traffic.
for net in locals:
    (txt, i) = LocalNets[net]

    if i.updates:
        print(i.format(conns=Options.conns, title=net + " " + txt))

print("First: %16s (%.6f) Last: %s %.6f" % (isoTime(Total.start), Total.start, isoTime(Total.end), Total.end))
