/*
   $Id: bact_classify.cpp,v 1.1.1.1 2004/06/23 05:00:42 taku-ku Exp $;

   Copyright (C) 2003 Taku Kudo, All rights reserved.
   This is free software with ABSOLUTELY NO WARRANTY.

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.
  
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
  
   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
   02111-1307, USA
*/

#include <unistd.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "mmap.h"
#include "common.h"
#include "darts.h"

// Cursor helper for walking a raw byte buffer: returns the position the
// cursor currently points at and advances the cursor by `size` bytes so
// the next read starts right after the region just handed out.
static inline char *read_ptr (char **ptr, size_t size)
{
  char *const start = *ptr;
  *ptr = start + size;
  return start;
}

// Reads one value of type T from the byte cursor *ptr and advances the
// cursor past it.  The bytes are copied with memcpy, so the source does
// not need to be aligned for T.
template <class T> static inline void read_static (char **ptr, T& value)
{
  const char *src = *ptr;
  *ptr += sizeof (T);
  memcpy (&value, src, sizeof (T));
}

// Strict-weak-ordering comparator for std::pair<T1, T2> that orders by the
// `second` member in DESCENDING order (largest second value first).  The
// `first` member is ignored, so pairs with equal `second` compare
// equivalent.  Intended for std::sort / std::stable_sort.
//
// It deliberately does not derive from std::binary_function: that adaptor
// base is deprecated since C++11 and removed in C++17, and the original
// base listed its template arguments in the wrong order anyway.
template <typename T1, typename T2>
struct pair_2nd_cmp {
  bool operator () (const std::pair <T1, T2> &x1,
                    const std::pair <T1, T2> &x2) const
  {
    return x1.second > x2.second;
  }
};

/*
  BactClassifier -- scores one tree instance against a compiled pattern model.

  Model file layout, as read by open():
    unsigned int  size    byte length of the double-array image that follows
    char[size]            Darts::DoubleArray image mapping pattern strings to ids
    double        bias    score returned when no pattern matches
    double[]      alpha   weight of each pattern, indexed by its id

  classify() parses an instance with str2node(), looks up in the double
  array every pattern that can be grown from the parsed tree, and returns
  bias plus the sum of the weights of the DISTINCT matching pattern ids.
  With setRule(true) the matched pattern strings and their weights are
  also kept so that printRules() can list them.

  Weights are read directly from the memory mapping, so the object must
  stay alive (and open) for as long as classify() is called.
*/
class BactClassifier
{
private:

  MeCab::Mmap<char> mmap;
  // Start of the weight table inside the mapping.  It follows an unsigned
  // int length, the double-array image and the bias, so it is not
  // guaranteed to be aligned for double; weight() reads it with memcpy.
  const char *alpha;
  double bias;
  Darts::DoubleArray da;
  std::vector <int>  result;             // pattern ids matched by the current instance
  std::vector <node_t> tree;             // current instance, as parsed by str2node()
  std::map <std::string, double> rules;  // pattern -> weight; filled only when userule
  bool userule;

  // Weight of pattern `id`, copied out byte-wise so a misaligned table
  // does not cause undefined behavior.
  double weight (int id) const
  {
    double w;
    std::memcpy (&w, alpha + static_cast<size_t>(id) * sizeof (double),
                 sizeof (double));
    return w;
  }

  // Records one matched pattern: its id always, and (when rule output is
  // enabled) its text and weight.  map::insert keeps the first entry for a
  // pattern string that is seen more than once.
  void add_match (const std::string &pattern, int id)
  {
    if (userule)
      rules.insert (std::make_pair (pattern, weight (id)));
    result.push_back (id);
  }

  // Recursively grows the serialized pattern `prefix`, whose last node is
  // tree[pos], one node at a time, and records every grown pattern that is
  // a key of the double array.
  //
  // Pass d == -1 tries the children of tree[pos]; each later pass tries the
  // following siblings of tree[pos] and then moves pos up to its parent.
  // Every pass ends by appending " )" to the pattern, so at most depth+1
  // passes run (fewer if pos becomes -1 or the closed pattern leaves the
  // trie).
  //
  // trie_pos / str_pos are the double-array cursor already positioned at
  // the end of `prefix`, so each lookup resumes instead of re-scanning the
  // whole string.  traverse() results follow the Darts convention: -2 means
  // the string left the trie (prune), -1 means it is only a prefix of stored
  // keys, anything else is a pattern id.
  //
  // `size` is the node count of `prefix`; it is only passed along.
  void project (std::string prefix,
                unsigned int size, // # of nodes
                int depth,
                int pos,
                size_t trie_pos,
                size_t str_pos)
  {
    for (int d = -1 ; d < depth && pos != -1; ++d) {
      int start    = (d == -1) ? tree[pos].child : tree[pos].sibling;
      int newdepth = depth - d;

      for (int l = start; l != -1; l = tree[l].sibling) {
        std::string item = prefix + " " + tree[l].val.key();
        // Work on copies so every candidate resumes from the same cursor.
        size_t new_trie_pos = trie_pos;
        size_t new_str_pos  = str_pos;
        int id = da.traverse (item.c_str(), new_trie_pos, new_str_pos);
        if (id == -2) continue;
        if (id != -1) add_match (item, id);
        project (item, size+1, newdepth, l, new_trie_pos, new_str_pos);
      }

      if (d != -1) pos = tree[pos].parent;

      // Close the current node before looking at the next level up.
      prefix += " )";
      int id = da.traverse (prefix.c_str(), trie_pos, str_pos);
      if (id == -2) break;
    }
  }

public:

  BactClassifier(): alpha(0), bias(0.0), userule(false) {}

  // Enables or disables collection of matched rules for printRules().
  void setRule(bool t)
  {
    userule = t;
  }

  // Memory-maps the model file and sets up the double array, bias and
  // weight table.  Returns false only if the mapping fails.
  // NOTE(review): the file is not checked to be large enough for the
  // length it declares; a truncated model would be read out of bounds.
  bool open (const char *file)
  {
    if (! mmap.open (file)) return false;

    char *ptr = mmap.begin ();
    unsigned int size = 0;
    read_static<unsigned int>(&ptr, size);
    da.setArray (ptr);
    ptr += size;
    read_static<double>(&ptr, bias);
    alpha = ptr;

    return true;
  }

  // Scores one instance given in the textual tree format accepted by
  // str2node().  Returns bias + sum of weights of distinct matched
  // patterns; the matched ids (and rules, if enabled) replace those of the
  // previous call.
  double classify (const char *line)
  {
    result.clear ();
    tree.clear ();
    rules.clear ();
    double r = bias;

    str2node (line, tree);

    for (unsigned int i = 0; i < tree.size(); ++i) {
      // Single-node pattern rooted at node i, then grow it from there.
      int id = da.exactMatchSearch (tree[i].val.key().c_str());
      if (id == -2) continue;
      if (id >= 0) add_match (tree[i].val.key(), id);
      project (tree[i].val.key(), 1, 0, i, 0, 0);
    }

    // The same pattern can be reached from several roots; count it once.
    std::sort (result.begin(), result.end());
    result.erase (std::unique (result.begin(), result.end()), result.end());

    for (unsigned int i = 0; i < result.size(); ++i) r += weight (result[i]);

    return r;
  }

  // Writes the bias line followed by one "rule: <weight> <pattern>" line per
  // rule collected by the last classify() call, highest weight first.
  // Rules with equal weight keep the map's (pattern-string) order.
  std::ostream &printRules (std::ostream &os)
  {
    std::vector <std::pair <std::string, double> > tmp (rules.begin(), rules.end());

    std::stable_sort (tmp.begin(), tmp.end(), pair_2nd_cmp<std::string, double>());

    // The "__DFAULT__" spelling is kept as-is: output consumers may match it.
    os << "rule: " << bias << " __DFAULT__" << std::endl;

    for (std::vector <std::pair <std::string, double> >::const_iterator it = tmp.begin();
         it != tmp.end(); ++it)
      os << "rule: " << it->second << " " << it->first << std::endl;

    return os;
  }
};

#define OPT " testdata modelfile"

// Returns num/den, or 0.0 when den is zero, so the summary prints 0
// instead of nan/inf when a category (or the whole input) is empty.
static double safe_ratio (double num, double den)
{
  return den == 0.0 ? 0.0 : num / den;
}

// Usage: prog [-v level] testdata modelfile
// Classifies every "label<TAB or SPACE>tree" line of testdata ("-" = stdin)
// with the model and prints accuracy / precision / recall / F1.
// -v 1 prints "label score", -v 2 also the tree, -v 3+ also the matched rules.
int main (int argc, char **argv)
{
  unsigned int verbose = 0;

  int opt;
  while ((opt = getopt(argc, argv, "v:")) != -1) {
    switch(opt) {
    case 'v':
      verbose = atoi (optarg);
      break;
    default:
      std::cout << "Usage: " << argv[0] << OPT << std::endl;
      return -1;
    }
  }

  // The data and model files are the last two arguments; make sure both
  // are really present after the options (argc alone does not tell us).
  if (argc - optind < 2) {
    std::cout << "Usage: " << argv[0] << OPT << std::endl;
    return -1;
  }

  const char *testfile  = argv[argc - 2];
  const char *modelfile = argv[argc - 1];

  // Stack-owned stream: closed automatically on every return path.
  std::ifstream ifs;
  std::istream *is = &std::cin;
  if (strcmp (testfile, "-") != 0) {
    ifs.open (testfile);
    if (! ifs) {
      std::cerr << argv[0] << " " << testfile << " No such file or directory" << std::endl;
      return -1;
    }
    is = &ifs;
  }

  BactClassifier bc;

  if (verbose >= 3) bc.setRule (true);

  if (! bc.open (modelfile)) {
    std::cerr << argv[0] << " " << modelfile << " No such file or directory" << std::endl;
    return -1;
  }

  std::string line;
  std::vector<char> buf;   // writable, NUL-terminated copy of `line` for tokenize()
  char *column[4];
  unsigned int all = 0;
  unsigned int correct = 0;
  unsigned int res_a = 0;  // system p, answer p
  unsigned int res_b = 0;  // system p, answer n
  unsigned int res_c = 0;  // system n, answer p
  unsigned int res_d = 0;  // system n, answer n

  while (std::getline (*is, line)) {

    // Skip blank lines and ';' comment lines.
    if (line.empty () || line[0] == '\0' || line[0] == ';') continue;

    // tokenize() takes a mutable char* (presumably it writes terminators in
    // place), so give it a private copy instead of casting away the
    // constness of line.c_str().
    buf.assign (line.begin (), line.end ());
    buf.push_back ('\0');
    if (2 != tokenize (&buf[0], "\t ", column, 2)) {
      std::cerr << "Format Error: " << line << std::endl;
      return -1;
    }

    int y = atoi (column[0]);
    double dist = bc.classify (column[1]);

    if (verbose == 1) {
      std::cout << y << " " << dist << std::endl;
    } else if (verbose == 2) {
      std::cout << y << " " << dist << " " << column[1] << std::endl;
    } else if (verbose >= 3) {
      std::cout << "<instance>" << std::endl;
      std::cout << y << " " << dist << " " << column[1] << std::endl;
      bc.printRules (std::cout);
      std::cout << "</instance>" << std::endl;
    }

    all++;
    if (dist > 0) {
      if (y > 0) correct++;
      if (y > 0) res_a++; else res_b++;
    } else {
      if (y < 0) correct++;
      if (y > 0) res_c++; else res_d++;
    }
  }

  double prec = safe_ratio (res_a, res_a + res_b);
  double rec  = safe_ratio (res_a, res_a + res_c);

  std::printf ("Accuracy:   %.5f%% (%u/%u)\n", 100.0 * safe_ratio (correct, all), correct, all);
  std::printf ("Precision:  %.5f%% (%u/%u)\n", 100.0 * prec, res_a, res_a + res_b);
  std::printf ("Recall:     %.5f%% (%u/%u)\n", 100.0 * rec, res_a, res_a + res_c);
  std::printf ("F1:         %.5f%%\n",         100.0 * safe_ratio (2 * rec * prec, prec + rec));
  std::printf ("System/Answer p/p p/n n/p n/n: %u %u %u %u\n", res_a, res_b, res_c, res_d);

  return 0;
}
