/* Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#ifndef EVALUATOR_H
#define EVALUATOR_H

#include "util.h"

/* Class Evaluator: stores the best solution during the algorithm loop and
   provides information about the solution accuracy */

class Evaluator {
public:
  Evaluator(const FloatMatrix& mu, const FloatVect& rdiag,
    const FloatVect& targetCoord = FloatVect()) :
    newSolFlag(false), inputErrorDefined(false), mu(mu),
    rdiag(rdiag), targetCoord(targetCoord)
  {
    maxDRdiag.resize(rdiag.size());
    maxDMu.resize(rdiag.size());
  }

  virtual ~Evaluator() {};

  /* When the enumeration is finished, returns true if solCoord
     is the shortest vector or false if it is not prooved
     The default implementation always returns false */
  virtual bool certifiedSol();

  /* When the enumeration is finished, returns maxError such that
    n(solCoord) <= (1 + maxError) * lambda_1(L)^2
    where n(solCoord) is the exact value of
    ||solCoord_0 b_0 + ... + solCoord_(d-1) b_(d-1)||^2 */
  virtual bool getMaxError(Float& maxError);

  /* Called when a solution is found during the enumeration
     Input: newSolCoord = coordinates of the solution in Gram-Schmidt basis
     newPartialDist = estimated distance between the solution and the
     orthogonal projection of target on the lattice space
     maxDist = current bound of the algorithm
     Output: maxDist can be decreased */
  virtual void evalSol(const FloatVect& newSolCoord,
    const double& newPartialDist, double& maxDist);

  FloatVect solCoord;           // Coordinates of the solution in the lattice
  bool newSolFlag;              // Set to true when solCoord is updated

  /* To enable error estimation, the caller must set
     inputErrorDefined=true and fill maxDRdiag and maxDMu */
  bool inputErrorDefined;
  FloatVect maxDRdiag, maxDMu;  // Error bounds on input parameters
  Float lastPartialDist;        // Approx. squared norm of the last solution

protected:
  bool getMaxErrorAux(const Float& maxDist, bool boundOnExactVal, Float& maxDE);

  const FloatMatrix& mu;
  const FloatVect& rdiag;
  const FloatVect& targetCoord;
};

/* Class SmartEvaluator: can "sometimes" certify the solution without doing
   exact computations with the integer matrix */

class SmartEvaluator : public Evaluator {
public:
  SmartEvaluator(const FloatMatrix& mu, const FloatVect& rdiag,
    const FloatVect& targetCoord = FloatVect()) :
    Evaluator(mu, rdiag, targetCoord), certified(false) {}
  virtual bool certifiedSol();
  virtual bool getMaxError(Float& maxError);
  virtual void evalSol(const FloatVect& newSolCoord,
    const double& newPartialDist, double& maxDist);

private:
  bool certified;
};

/* Class ExactEvaluator: ensures that the solution is the optimal vector
   (the exact coefficients of the input matrix are needed) */

class ExactEvaluator : public Evaluator {
public:
  ExactEvaluator(const IntMatrix& matrix, const IntVect& target,
    const Float& normFactor, const FloatMatrix& mu, const FloatVect& rdiag,
    const FloatVect& targetCoord = FloatVect()) :
    Evaluator(mu, rdiag, targetCoord), matrix(matrix), target(target),
    normFactor(normFactor) {}

  virtual bool certifiedSol();    // Always returns 'true'
  virtual bool getMaxError(Float& maxError);
  virtual void evalSol(const FloatVect& newSolCoord,
    const double& newPartialDist, double& maxDist);
  inline Integer& getSolDist() {
    return solDist;
  }

private:
  void updateMaxDist(double& maxDist);

  const IntMatrix& matrix;  // matrix of the lattice
  const IntVect& target;    // target for CVP, empty vector for SVP
  const Float& normFactor;  // Normalization factor in Solver (power of 2)
  Integer solDist;          // Exact norm/distance of the last vector
};

#endif
