#!/usr/bin/python3
# pythonfilter -- A python framework for Courier global filters
# Copyright (C) 2003-2008  Gordon Messmer <gordon@dragonsdawn.net>
#
# This file is part of pythonfilter.
#
# pythonfilter is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# pythonfilter is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with pythonfilter.  If not, see <http://www.gnu.org/licenses/>.

"""Use: filterctl start pythonfilter

pythonfilter will be activated within the Courier configuration, and
the courierfilter process will start the program.

"""

##############################
##############################

import os
import resource
import sys
import select
import socket
import time
import traceback
import _thread
import courier.config
import courier.control


##############################
# Config Options
##############################

# courierfilter indicates that it is waiting to be notified of init
# completion by leaving fd #3 open, so probe that descriptor once at
# import time.
try:
    os.fstat(3)
except OSError:
    # fd #3 is not open: there is nothing to notify.
    notify_after_init = 0
else:
    notify_after_init = 1

# Set filter_all to 1 if you do not want users to be able to whitelist
# specific senders
filter_all = 1


class LockedCounter:
    """Integer counter whose updates are serialized by a lock.

    The current value is exposed as the ``count`` attribute.  ``inc`` and
    ``dec`` change it by one while holding ``lock``.
    """

    def __init__(self):
        # Lock guarding every modification of ``count``.
        self.lock = _thread.allocate_lock()
        # Current value of the counter.
        self.count = 0

    def inc(self):
        """Add one to ``count`` while holding the lock."""
        # The context manager releases the lock even if the update raises,
        # so a failed update can never leave the counter locked forever.
        with self.lock:
            self.count += 1

    def dec(self):
        """Subtract one from ``count`` while holding the lock."""
        with self.lock:
            self.count -= 1


def open_config():
    """Locate pythonfilter.conf and return it as an open file object.

    The directories /etc, /usr/local/etc and courier.config.sysconfdir are
    tried in that order; the first readable pythonfilter.conf wins.  If
    opening the file fails with IOError, or if no readable file is found,
    a message is written to stderr and the process exits via sys.exit().
    """
    search_dirs = ('/etc', '/usr/local/etc', courier.config.sysconfdir)
    handle = None
    try:
        for directory in search_dirs:
            candidate = '%s/pythonfilter.conf' % directory
            if not os.access(candidate, os.R_OK):
                continue
            handle = open(candidate)
            break
    except IOError:
        sys.stderr.write('Could not open config file for reading.\n')
        sys.exit()
    if handle is None:
        sys.stderr.write('Could not locate a configuration file in any of: %s\n' %
                         (search_dirs,))
        sys.exit()
    return handle


def run_init_filter(module, module_name):
    """Call the initialization hook(s) that a filter module defines.

    Both spellings are honoured: "initFilter" is tried first, then
    "init_filter".  An AttributeError raised by a hook is reported on
    stderr, together with its traceback, and then ignored.  Any other
    exception propagates to the caller.

    module -- the imported filter module
    module_name -- the module's name, used only in log messages
    """
    for hook_name in ('initFilter', 'init_filter'):
        if not hasattr(module, hook_name):
            continue
        try:
            getattr(module, hook_name)()
        except AttributeError as exc:
            # Log bad modules
            sys.stderr.write('Failed to run "%s" function from %s\n' %
                             (hook_name, module_name))
            sys.stderr.write('Exception : %s:%s\n' % (type(exc), exc))
            sys.stderr.write(''.join(traceback.format_tb(exc.__traceback__)))


def save_do_filter(module, module_name, bypass, filters):
    """Register the filtering function(s) of a module in ``filters``.

    For each of the attribute names "doFilter" and "do_filter" that the
    module provides, a (module_name, function, bypass) tuple is appended
    to ``filters``.  An AttributeError raised while doing so is reported
    on stderr, together with its traceback, and then ignored.

    module -- the imported filter module
    module_name -- the module's name as given in the configuration
    bypass -- set of filter names to bypass, or None
    filters -- list that receives the registered tuples
    """
    for entry_point in ('doFilter', 'do_filter'):
        if not hasattr(module, entry_point):
            continue
        try:
            # Remember the module's name, a reference to its filtering
            # function and its bypass set.
            filters.append((module_name, getattr(module, entry_point), bypass))
        except AttributeError as exc:
            # Log bad modules
            sys.stderr.write('Failed to load "%s" function from %s\n' %
                             (entry_point, module_name))
            sys.stderr.write('Exception : %s:%s\n' % (type(exc), exc))
            sys.stderr.write(''.join(traceback.format_tb(exc.__traceback__)))


def load_filters():
    """Import every filter module listed in pythonfilter.conf.

    Each configuration line names one module of the ``pythonfilter``
    package.  Lines that begin with a hash character, empty lines and
    lines containing only whitespace are ignored.  A line of the form
    "module for a b c" means that filters a, b and c will be bypassed if
    the module returns a 2xx code.

    Returns the list of (module_name, function, bypass) tuples collected
    by save_do_filter, in configuration order.  If a module cannot be
    imported, the error is written to stderr and the process exits via
    sys.exit().
    """
    filters = []
    # The context manager closes the configuration file on every path,
    # including the sys.exit() below.
    with open_config() as config:
        for line in config:
            if line[0] in '#\n':
                continue
            words = line.split()
            if not words:
                # Whitespace-only line: there is no module name to load.
                continue
            module_name = words[0]
            if len(words) > 1 and words[1] == 'for':
                bypass = set(words[2:])
            else:
                bypass = None
            try:
                # __import__ returns the top-level "pythonfilter" package,
                # so walk down the dotted name to reach the module itself.
                module = __import__('pythonfilter.%s' % module_name)
                for component in module_name.split('.'):
                    module = getattr(module, component)
            except ImportError as exc:
                sys.stderr.write('Module "%s" indicated in pythonfilter.conf could not be loaded.'
                                 '  It may be missing, or one of the modules that it requires may'
                                 ' be missing.\n' %
                                 module_name)
                sys.stderr.write('Exception : %s:%s\n' % (type(exc), exc))
                sys.stderr.write(''.join(traceback.format_tb(exc.__traceback__)))
                sys.exit()
            run_init_filter(module, module_name)
            save_do_filter(module, module_name, bypass, filters)
    return filters


def try_unlink(path):
    """Remove the file at ``path``, silently ignoring any OSError."""
    try:
        os.remove(path)
    except OSError:
        # Best effort only: a path that cannot be removed is not an error.
        return


def create_socket():
    """Create the listening unix socket that courierfilter connects to.

    The socket is created in the "allfilters" directory when filter_all
    is set and in "filters" otherwise, both below
    courier.config.localstatedir.  It is first bound to the hidden name
    ".pythonfilter" and then renamed to "pythonfilter".

    Returns a (socket, path) tuple.  On any failure the socket files are
    removed, the reason is written to stderr and the process exits via
    sys.exit().
    """
    if filter_all:
        filter_dir = 'allfilters'
    else:
        filter_dir = 'filters'
    filter_socket_path1 = '%s/%s/.pythonfilter' % (courier.config.localstatedir, filter_dir)
    filter_socket_path = '%s/%s/pythonfilter' % (courier.config.localstatedir, filter_dir)
    filter_socket_check1 = '%s/%s/pythonfilter' % (courier.config.localstatedir, 'filters')
    filter_socket_check2 = '%s/%s/pythonfilter' % (courier.config.localstatedir, 'allfilters')
    # Setup socket for courierfilter connection if filters loaded
    # completely
    try:
        # Remove stale sockets to prevent exceptions
        try_unlink(filter_socket_check1)
        try_unlink(filter_socket_check2)
        try_unlink(filter_socket_path1)
        filter_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        filter_socket.bind(filter_socket_path1)
        os.rename(filter_socket_path1, filter_socket_path)
        os.chmod(filter_socket_path, 0o660)
        filter_socket.listen(64)
    except Exception as exc:
        # If the socket creation failed, remove sockets that might
        # exist, so that courier will deliver mail.
        try_unlink(filter_socket_path1)
        try_unlink(filter_socket_path)
        sys.stderr.write('pythonfilter failed to create socket in %s/%s\n' %
                         (courier.config.localstatedir, filter_dir))
        # Report the underlying cause as well; without it the failure
        # cannot be diagnosed from the log.
        sys.stderr.write('Exception : %s:%s\n' % (type(exc), exc))
        sys.exit()
    return (filter_socket, filter_socket_path)


def thread_time():
    """Return the user plus system CPU time reported by RUSAGE_THREAD."""
    usage = resource.getrusage(resource.RUSAGE_THREAD)
    return usage.ru_utime + usage.ru_stime


def accounting_start(acct, filter):
    """Open an accounting entry for ``filter`` in ``acct``.

    A new [filter, thread_time()] pair is appended to acct[1].  The time
    stored here is only a starting mark; accounting_finish replaces it
    with the CPU time that was used.
    """
    entry = [filter, thread_time()]
    acct[1].append(entry)


def accounting_finish(acct):
    """Close the most recent accounting entry in ``acct``.

    The starting mark stored in the last entry of acct[1] is replaced by
    the difference between the current thread_time and that mark.
    """
    latest = acct[1][-1]
    latest[1] = thread_time() - latest[1]


def log_accounting(acct, control_paths):
    """Write the CPU time accounting line for one message to stderr.

    acct -- [start_time, entries], where each entry holds a filter name
            and the CPU time recorded for it
    control_paths -- control file paths, passed to
                     courier.control.get_lines to fetch the first "M"
                     record, which is logged as the message id
    """
    # Everything spent since accounting began counts towards the message.
    elapsed = thread_time() - acct[0]
    per_filter = ' '.join('%s(%f)' % (entry[0], entry[1]) for entry in acct[1])
    message_ids = courier.control.get_lines(control_paths, 'M', 1)
    sys.stderr.write('CPU TIME ACCOUNTING: %s processed with %f seconds: (%s)\n' %
                     (message_ids[0], elapsed, per_filter))


##############################
# Filter loop processing function
##############################
def process_message(active_socket, filters, active_filters):
    """Filter one message and send the verdict back over the socket.

    The request read from ``active_socket`` consists of the path of the
    message body on the first line, followed by one control file path
    per line and terminated by an empty line.  Paths that do not start
    with "/" are taken relative to courier.config.localstatedir + "/tmp/".

    The filters are run in order until one returns a non-empty reply.  A
    reply starting with "2" from a filter that has a bypass set does not
    stop processing; it adds that set to the names of filters that are
    skipped for the rest of the message.  The final reply, or "200 Ok"
    when no filter produced one, is written to the socket.

    active_socket -- connected socket accepted from courierfilter
    filters -- list of (name, function, bypass) tuples
    active_filters -- counter of running filter threads; it is
                      decremented when this function ends

    The socket is closed, the counter decremented and stderr flushed on
    every exit path, including when an exception propagates (for example
    the IndexError raised by indexing an empty line when the request ends
    before the terminating empty line).
    """
    try:
        # Wrap the socket in a file object so that the request can be
        # read line by line; the "with" closes it once the request has
        # been read, or if reading fails.
        with active_socket.makefile('r') as request:
            body_path = request.readline().strip()
            # Normalize file name:
            if body_path[0] != '/':
                body_path = courier.config.localstatedir + '/tmp/' + body_path
            control_paths = []
            while True:
                control_path = request.readline()
                if control_path == '\n':
                    break
                # Normalize file name:
                if control_path[0] != '/':
                    control_path = (courier.config.localstatedir + '/tmp/' +
                                    control_path)
                control_paths.append(control_path.strip())
        # Prepare a response message, which is blank initially.  If a
        # filter decides that a message should be rejected, then it must
        # return the reason as an SMTP style response: numeric value and
        # text message.  The response can be multiline.
        reply_code = ''
        # Names of filters that will not be run because an earlier module
        # returned a 2XX code and specified a list of filters to bypass.
        bypass = set()
        # CPU-time accounting information.  The first value is the
        # CPU-time used before filtering began.  The second is a list of
        # pairs of module-name and CPU-time values.
        acct = [thread_time(), []]
        i_filter = None
        for i_filter in filters:
            # name = i_filter[0]
            # function = i_filter[1]
            # bypass = i_filter[2]
            if i_filter[0] in bypass:
                continue
            accounting_start(acct, i_filter[0])
            try:
                reply_code = i_filter[1](body_path, control_paths)
            except Exception as exc:
                # A failing filter is logged and treated as "no opinion".
                sys.stderr.write('Uncaught exception in "%s" do_filter function: %s:%s\n' %
                                 (i_filter[0], type(exc), exc))
                sys.stderr.write(''.join(traceback.format_tb(exc.__traceback__)))
                reply_code = ''
            accounting_finish(acct)
            if not isinstance(reply_code, str):
                sys.stderr.write('"%s" do_filter function returned non-string\n' % i_filter[0])
                reply_code = ''
            if reply_code != '':
                if i_filter[2] and reply_code[0] == '2':
                    # A list of filters to bypass was provided, so add
                    # that list to the bypass set and continue filtering.
                    bypass.update(i_filter[2])
                else:
                    break
        # If all modules are ok or no filters are loaded, accept message
        #  else, write back error code and message
        if reply_code == '' or i_filter is None:
            # sendall keeps writing until the whole reply has been sent,
            # whereas send may stop after a partial write.
            active_socket.sendall('200 Ok'.encode())
        else:
            active_socket.sendall(reply_code.encode())
            log_file_codes(i_filter[0], reply_code, control_paths)
        log_accounting(acct, control_paths)
    finally:
        # Release the connection and the thread's slot in the counter no
        # matter how processing ended.  Otherwise a failed request leaks
        # the socket and leaves the counter above zero, which makes the
        # shutdown wait run to its deadline.
        try:
            active_socket.close()
        finally:
            active_filters.dec()
            sys.stderr.flush()


def log_file_codes(module, reply_code, control_paths):
    """Log one reject line per recipient for a non-accepting reply.

    Nothing is logged when ``reply_code`` starts with "2" or "0".  For
    any other reply, the sender and the recipients are read from the
    control files and one line per recipient is written to stderr.

    This function will not log the original list of recipients specified
    in the SMTP session.  The recipients logged are subject to alias
    expansion and also modification of the control files by filters.

    Logging is best effort: any exception raised here is discarded.
    """
    try:
        if reply_code.startswith(('2', '0')):
            return
        sender = courier.control.get_sender(control_paths)
        for recipient in courier.control.get_recipients(control_paths):
            sys.stderr.write('pythonfilter %s reject,from=<%s>,addr=<%s>: %s\n' %
                             (module, sender, recipient, reply_code))
    except Exception:
        # Any error from the above code is ignored entirely
        pass


def wait_for_message(filter_socket, filters, active_filters):
    """Wait for one event on stdin or the filter socket and handle it.

    When the filter socket is readable, the connection is accepted,
    ``active_filters`` is incremented and a new thread running
    process_message is started for it.

    filter_socket -- the listening socket
    filters -- list of loaded filters, handed to process_message
    active_filters -- counter of running filter threads

    Returns False when stdin raised an event, which means it was closed
    and the program needs to exit.  Returns True in every other case,
    including when select or the connection hand-off failed.
    """
    try:
        ready_files = select.select([sys.stdin, filter_socket], [], [])
    except Exception:
        return True
    readable = ready_files[0]
    # If stdin raised an event, it was closed and we need to exit.
    if sys.stdin in readable:
        return False
    if filter_socket in readable:
        active_socket = None
        counted = False
        try:
            active_socket, _addr = filter_socket.accept()
            # Now, hand off control to a new thread and continue listening
            # for new connections
            active_filters.inc()
            counted = True
            # Spawn thread and pass filenames as args
            _thread.start_new_thread(process_message, (active_socket, filters, active_filters))
        except Exception:
            # No thread took ownership of the connection, so undo what was
            # done for it.  Otherwise the counter would stay above zero
            # and the accepted socket would never be closed.
            if counted:
                active_filters.dec()
            if active_socket is not None:
                try:
                    active_socket.close()
                except OSError:
                    pass
            sys.stderr.write('pythonfilter failed to accept connection '
                             'from courierfilter\n')
    return True


def close_socket(filter_socket_path, filter_socket):
    """Dispose of the unix socket: unlink its path and close it.

    filter_socket_path -- filesystem path of the socket
    filter_socket -- the socket object to close

    Raises OSError if the path cannot be unlinked.  The socket is closed
    even in that case.
    """
    try:
        os.unlink(filter_socket_path)
    finally:
        # Close the descriptor whether or not the path could be removed.
        filter_socket.close()


def wait_for_active_filters(active_filters, timeout=10, poll_interval=0.1):
    """Wait until no filter is active or the timeout expires.

    active_filters -- object whose ``count`` attribute is the number of
                      filters still running
    timeout -- maximum number of seconds to wait.  The default is 10, as
               waiting longer might cause problems with "courier restart".
    poll_interval -- seconds to sleep between checks of the counter
    """
    # Use the monotonic clock so that an adjustment of the system time
    # can neither cut the wait short nor stretch it.
    deadline = time.monotonic() + timeout
    while active_filters.count > 0 and time.monotonic() < deadline:
        # Wait for them all to finish
        time.sleep(poll_interval)


def main():
    """Run the filter: load the modules, serve requests, then shut down.

    The filters are loaded and the listening socket is created first.  If
    fd 3 was found open at startup it is closed to notify courierfilter
    that initialization is complete.  Connections are then served until
    wait_for_message reports that stdin was closed.  After that the
    socket is removed and running filters get a chance to finish.
    """
    # Initialize filter system
    active_filters = LockedCounter()
    filters = load_filters()
    sys.stderr.flush()
    filter_socket, filter_socket_path = create_socket()

    # Close fd 3 to notify courierfilter that initialization is complete
    if notify_after_init:
        os.close(3)

    # Listen for connections on socket until stdin is closed
    while wait_for_message(filter_socket, filters, active_filters):
        pass

    # Shut down: stop accepting connections, then wait for active filters
    close_socket(filter_socket_path, filter_socket)
    wait_for_active_filters(active_filters)


# Run the filter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
