/* Copyright (C) 2005-2008 Damien Stehle.
Copyright (C) 2007 David Cade.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */


#ifndef HEURISTIC_EARLYRED_H
#define HEURISTIC_EARLYRED_H





/* Variant of the heuristic<ZT,FT> reduction with "early reduction".
 *
 * Besides size-reducing the current row kappa, BabaiCall() periodically
 * size-reduces every row kappa..d-1 (d = number of rows of B) against the
 * rows zeros+1..kappa-1.  To support this, GSO() and Babai() take the index
 * `red` of the row being reduced instead of always working on row kappa.
 *
 * All state (B, appB, appSP, mu, r, s, kappa, halfplus, onedotfive) is
 * inherited from heuristic<ZT,FT>.
 */
template<class ZT,class FT> class heuristic_early_red:public heuristic<ZT,FT>
{
protected:
  /* Dispatches the size-reduction for the current stage.
   *
   * If newvec > newvecmax, newvecmax is doubled, newvec is reset to 0 and
   * Babai() is run on every row from d-1 down to kappa; otherwise Babai()
   * is run on row kappa only.  The vector length handed to Babai() is
   * min(var_k, n).
   *
   *   alpha            per-row values; alpha[row] is passed as Babai's `a`
   *   zeros, kappamax  forwarded unchanged to Babai()
   *   ztmp, tmp, rtmp  scratch variables (clobbered)
   *   max, max2, max3  progress trackers, forwarded to Babai()
   *   newvec/newvecmax counter and threshold controlling when the early
   *                    reduction pass is triggered (maintained by the caller;
   *                    presumably newvec counts newly reached vectors)
   *
   * Returns kappa as soon as one Babai() call reports failure (non-zero),
   * 0 otherwise.
   */
  virtual inline int BabaiCall(int*alpha,int zeros,int kappamax,int var_k,
                               Z_NR<ZT>&ztmp,FP_NR<FT>& tmp, FP_NR<FT>& rtmp,
                               FP_NR<FT>& max,FP_NR<FT>&max2,FP_NR<FT>&max3,
                               int& newvec,int& newvecmax,int n)
  {
    /* Read the row count on every call.  A function-local static would be
       initialised only once and shared by every instance of this template
       instantiation, leaving a stale dimension for any later basis. */
    const int d=this->B->GetNumRows();
    /* Vector length used by the reduction: min(var_k, n). */
    const int dim=var_k<=n?var_k:n;

    if (newvec>newvecmax)
      {
        newvecmax*=2;newvec=0;
#ifdef VERBOSE
        cerr<<"Doing Early-reduction. \n";
#endif
        for (int target=d-1;target>=this->kappa;target--)
          {
            //  int var_k2=var_k<=target+1+shift?target+1+shift:var_k;
            if (Babai (alpha[target], zeros, kappamax, dim, ztmp,
                       tmp,rtmp,max,max2,max3,target))
              return this->kappa;
          }
      }
    else
      if (Babai (alpha[this->kappa], zeros, kappamax, dim, ztmp,tmp,rtmp,max,max2,max3,this->kappa))
        return this->kappa;
    return 0;
  }

  /* Computes the Gram-Schmidt coefficients of row `red` against the rows
   * aa..kappa-1.
   *
   * For each j in [aa, kappa):
   *   - appSP(red,j) is (re)computed if it is NaN: first with
   *     fpScalarProduct() on the approximate basis appB and, if that call
   *     returns false, exactly with ScalarProduct() on B;
   *   - r(red,j)  = appSP(red,j) - sum_{k=zeros+1}^{j-1} mu(j,k)*r(red,k);
   *   - mu(red,j) = r(red,j) / r(j,j).
   *
   * On return `max` holds max_j |mu(red,j)| over the processed j (0 if
   * none).  ztmp, tmp and rtmp are scratch.  The parameters `a` and
   * `kappamax` are not used here.
   */
  virtual inline void GSO(int a, int zeros, int kappamax, int n,
                          Z_NR<ZT>& ztmp, FP_NR<FT>& tmp, FP_NR<FT>& rtmp, FP_NR<FT>& max,
                          int aa,int red)
  {
    max.set(0.0);
    for (int j=aa; j<this->kappa; j++)
    {
      if (this->appSP->Get(red,j).is_nan())
        {
          /* Fall back to the exact integer scalar product when the
             floating-point one reports failure. */
          if (!(fpScalarProduct (this->appSP->Get(red,j),
                                 this->appB->GetVec(red),
                                 this->appB->GetVec(j), n)))
            {
              ScalarProduct (ztmp, this->B->GetVec(red), this->B->GetVec(j), n);
              this->appSP->Get(red,j).set_z(ztmp);
            }
        }


#ifdef DEBUG

      /* Report the row actually being processed (red), which differs from
         kappa during early reduction. */
      printf("\n          j is %d\n", j);
      printf("          appSP[%d][%d] is: ", red, j);
      this->appSP->Get(red,j).print();
      printf(", which approximates ");
      ScalarProduct (ztmp, this->B->GetVec(red), this->B->GetVec(j), n);
      printf("\n");
      this->B->print( red+1, n);
      ztmp.print();
      printf("\n          Norm of B[%d]^2: ", j);
      ScalarProduct (ztmp, this->B->GetVec(j), this->B->GetVec(j), n);
      ztmp.print();
      printf("\n          Norm of B[%d]^2: ",red);
      ScalarProduct (ztmp, this->B->GetVec(red), this->B->GetVec(red), n);
      ztmp.print();

#endif

      /* The three branches evaluate the same formula
           r(red,j) = appSP(red,j) - sum_{k=zeros+1}^{j-1} mu(j,k)*r(red,k)
         for a sum with >= 2 terms, exactly 1 term, and no term. */
      if (j > zeros+2)
        {
          tmp.mul(this->mu->Get(j,zeros+1), this->r->Get(red,zeros+1));
          rtmp.sub(this->appSP->Get(red,j), tmp);
          for (int k=zeros+2; k<j-1; k++)
            {
              tmp.mul(this->mu->Get(j,k), this->r->Get(red,k));
              rtmp.sub(rtmp, tmp);
            }
          tmp.mul(this->mu->Get(j,j-1), this->r->Get(red,j-1));
          this->r->Get(red,j).sub(rtmp, tmp);
        }
      else if (j==zeros+2)
        {
          tmp.mul(this->mu->Get(j,zeros+1),this->r->Get(red,zeros+1));
          this->r->Get(red,j).sub(this->appSP->Get(red,j), tmp);
        }
      else this->r->Get(red,j).set(this->appSP->Get(red,j));

      this->mu->Get(red,j).div( this->r->Get(red,j), this->r->Get(j,j));

      /* Track the largest |mu(red,j)| seen in this pass. */
      rtmp.abs(this->mu->Get(red,j));
      if (max.cmp(rtmp)<0)
        max.set(rtmp);
    }
  }

public:
  /* Size-reduces row `red` of B against the rows zeros+1..kappa-1.
   *
   * Each iteration of the main loop
   *   1. recomputes mu(red,j) and r(red,j) with GSO();
   *   2. from the third iteration on, gives up (returns kappa) when max3 is
   *      NaN or max3 <= rtmp, where rtmp = max2.mul_2ui(10) (presumably
   *      2^10 * max2) and max2 / max3 hold the `max` values produced by
   *      GSO() one and two iterations earlier;
   *   3. for j = kappa-1 down to zeros+1, if |mu(red,j)| > halfplus,
   *      subtracts X * B[j] from B[red] (X = mu(red,j) rounded) and updates
   *      mu(red,k) for zeros < k < j accordingly;
   *   4. if anything was subtracted, refreshes appB[red], invalidates the
   *      cached scalar products involving row red, and loops again.
   *
   * When red == kappa, the partial sums s[zeros+1..kappa-1] and
   * r(kappa,kappa) are updated before returning.
   *
   *   a                first column whose GSO data must be recomputed on
   *                    the first iteration (clamped to at least zeros+1)
   *   zeros            columns <= zeros are skipped
   *   kappamax         upper row bound for invalidating appSP(i,red)
   *   n                number of vector coordinates processed
   *   ztmp, tmp, rtmp  scratch variables (clobbered)
   *   max, max2, max3  progress trackers, see step 2; updated in place
   *   red              index of the row to reduce
   *
   * Returns 0 on success, kappa when the reduction does not make the
   * expected progress.
   */
  virtual int Babai (int a, int zeros, int kappamax, int n,
                     Z_NR<ZT>& ztmp, FP_NR<FT>& tmp, FP_NR<FT>& rtmp,
                     FP_NR<FT>&max,FP_NR<FT>& max2,FP_NR<FT>&max3,int red)
  {
    int i, j, k, test, aa, sg, ll;
    signed long xx;
    long int expo;

#ifdef DEBUG
    int loops=0;
#endif

    typename FP_NR<FT>::assoc_int X;


    /* First column to (re)compute in GSO(): never below zeros+1. */
    aa = (a > zeros) ? a : zeros+1;

#ifdef DEBUG
    printf("\nappSP: \n");
    this->appSP->print(kappamax+1, kappamax+1);
    printf("\nr: \n");
    this->r->print(kappamax+1, kappamax+1);
    printf("\n          STARTING BABAI WITH k=%d\n", this->kappa);
    printf("\nappB: \n");
    this->appB->print(kappamax+1, n);
    loops = 0;
    printf("\nmu: \n");
    this->mu->print(kappamax+1, kappamax+1);
    printf ("\n\na is %d, zeros is %d, aa is %d\n", a, zeros, aa);
    this->B->print(kappamax+1, n);
#endif

    ll=0;   /* iteration counter of the main loop */
    do
      {
        ll++;
        test=0;   /* set to 1 as soon as row red is modified */


#ifdef DEBUG
        if (loops++ > LOOPS_BABAI)
          {
            printf("INFINITE LOOP?\n");
            abort();
          }
#endif


        /* ************************************** */
        /* Step2: compute the GSO for stage kappa */
        /* ************************************** */
        /* Shift the history of max |mu| before GSO() overwrites `max`. */
        max3.set(max2);
        max2.set(max);
        GSO(a,zeros,kappamax,n,ztmp,tmp,rtmp,max,aa,red);

        if(ll>=3)
          {
            /* Bail out unless the maximum from two iterations ago exceeds
               the previous one scaled by mul_2ui(.,10). */
            rtmp.mul_2ui(max2,10);
            if( max3.is_nan() || max3.cmp(rtmp)<=0)
              {
#ifdef VERBOSE
                cerr << "unexpected behaviour -> exit\n";
#endif
                return this->kappa;
              }
#ifdef DEBUG
            cout << "\nrtmp=";rtmp.print();cout<<"\nmax2=";max2.print();cout<<"\nmax3=";max3.print();
            cout<<"\n";
#endif

          }

#ifdef DEBUG
        if (loops <=LOOPS_BABAI)
          {
            printf("\nmu :\n");
            this->mu->print(this->kappa+1, this->kappa+1);
            printf("\nr :\n");
            this->r->print(this->kappa+1, this->kappa+1);
          }
#endif

        /* **************************** */
        /* Step3--5: compute the X_j's  */
        /* **************************** */

        for (j=this->kappa-1; j>zeros; j--)
          {
            /* test of the relaxed size-reduction condition */

            tmp.abs(this->mu->Get(red,j));

#ifdef DEBUG
            if (loops<=LOOPS_BABAI)
              {
                fflush(stdout);
                fflush(stderr);
                printf( "tmp is : ");
                tmp.print(); printf( "\n");
                fflush(stdout);
                fflush(stderr);
              }
#endif

            if ((tmp.cmp(this->halfplus)) > 0)
              {
                test = 1;
                /* we consider separately the case X = +-1 */
                if ((tmp.cmp(this->onedotfive))<=0 )
                  {
#ifdef DEBUG
                    printf(" X is pm1\n");
#endif

                    sg =  this->mu->Get(red,j).sgn();
                    if ( sg >=0 )   /* in this case, X is 1 */
                      {
                        for (k=zeros+1; k<j; k++)
                          this->mu->Get(red,k).sub(
                                               this->mu->Get(red,k),
                                               this->mu->Get(j,k));

                        for (i=0; i<n; i++)
                          this->B->Get(red,i).sub(
                                              this->B->Get(red,i),
                                              this->B->Get(j,i));
                      }

                    else          /* otherwise X is -1 */
                      {
                        for (k=zeros+1; k<j; k++)
                          this->mu->Get(red,k).add(
                                               this->mu->Get(red,k),
                                               this->mu->Get(j,k));

                        for (i=0; i<n; i++)
                          this->B->Get(red,i).add(
                                              this->B->Get(red,i),
                                              this->B->Get(j,i));
                      }
                  }

                else   /* we must have |X| >= 2 */
                  {
                    /* tmp becomes X, the rounded value of mu(red,j). */
                    tmp.rnd(this->mu->Get(red,j));

                    for (k=zeros+1; k<j; k++)
                      {
                        rtmp.mul(tmp, this->mu->Get(j,k));
                        this->mu->Get(red,k).sub( this->mu->Get(red,k), rtmp);
                      }


                    if (tmp.exp() < CPU_SIZE-2)
                      /* X is stored in a long signed int */

                      {
                        xx=tmp.get_si();

#ifdef DEBUG
                        if (loops<=LOOPS_BABAI)
                          {
                            printf("          xx[%d] is %ld\n", j, xx);
                            printf("          and tmp was ");
                            tmp.print(); printf("\n");
                          }
#endif

                        /* B[red] -= xx * B[j], using the unsigned
                           multiply-accumulate matching the sign of xx. */
                        for (i=0; i<n; i++)
                          {
                            if (xx > 0)
                              {
                                this->B->Get(red,i).submul_ui(
                                                          this->B->Get(j,i),
                                                          (unsigned long int) xx);
                              }
                            else
                              {
                                this->B->Get(red,i).addmul_ui(
                                                          this->B->Get(j,i),
                                                          (unsigned long int) -xx);
                              }
                          }
                      }



                    else
                      {
                        /* X does not fit in a long: get_z_exp() yields an
                           integer X and an exponent expo (presumably with
                           tmp ~ X * 2^expo), so that
                           B[red] -= X * (B[j] * 2^expo). */
                        expo = get_z_exp (X, tmp);
                        for (i=0; i<n; i++)
                          {
                            ztmp.mul_2exp(this->B->Get(j,i), expo);
                            this->B->Get(red,i).submul(ztmp,X);
                          }
                      }
                  }
              }
          }


        if (test)   /* Anything happened? */
          {
            for (i=0 ; i<n; i++)     /* update appB[red] */
              this->appB->Get(red,i).set_z(this->B->Get(red,i));
            /* Row red changed: recompute its GSO from column zeros+1 and
               drop the cached scalar products that involve it. */
            /* NOTE(review): for red > kappa the entries
               appSP(red, kappa+1..red) are not invalidated here; confirm
               that the caller never reads them stale. */
            aa = zeros+1;
            for (i=zeros+1; i<=this->kappa; i++)
              this->appSP->Get(red,i).set_nan();
            for (i=red+1; i<=kappamax; i++)
              this->appSP->Get(i,red).set_nan();
          }
        else if (red!=this->kappa)
          {
            /* Early-reduced rows do not keep their cached scalar products
               against columns zeros+1..kappa. */
            for (i=zeros+1; i<=this->kappa; i++)
              this->appSP->Get(red,i).set_nan();
          }


#ifdef DEBUG
        if (loops<=LOOPS_BABAI)
          {
            printf("          test is %d\n", test);
            printf("\nmu: \n");
            this->mu->print(this->kappa+1, this->kappa+1);
            printf("\nr: \n");
            this->r->print(this->kappa+1, this->kappa+1);
          }
#endif

      }
    while (test);


    /* Only the regular stage (row kappa itself) updates s[] and
       r(kappa,kappa). */
    if (red==this->kappa)
      {
        if (this->appSP->Get(this->kappa,this->kappa).is_nan())
          {
            fpNorm (this->appSP->Get(this->kappa,this->kappa), this->appB->GetVec(this->kappa), n);
          }
        this->s[zeros+1].set(this->appSP->Get(this->kappa,this->kappa));


        /* the last s[kappa]=r[kappa,kappa] is computed only if kappa increases */
        for (k=zeros+1; k<this->kappa-1; k++)
          {
            tmp.mul( this->mu->Get(this->kappa,k), this->r->Get(this->kappa,k));
            this->s[k+1].sub(this->s[k], tmp);
          }
        this->r->Get(this->kappa,this->kappa).set(this->s[this->kappa-1]);

      }

#ifdef DEBUG
    printf("          Number of loops is %d\n", loops);
#endif

    return 0;
  }

  /* Forwards all arguments unchanged to the heuristic<ZT,FT> base class
     constructor (basis B, precision, eta, delta). */
  heuristic_early_red(ZZ_mat<ZT>*B,int precision=0,double eta=0.51,double delta=0.99):heuristic<ZT,FT>(B,precision,eta,delta)
  {
  }
};





#endif


