#include "atlas_misc.h"
#include Mstr(Mjoin(ATLAS_PRE,sqamm_blk.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_ablk2cmat.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_cm2am_a1.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_cm2am_an.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_cm2am_aX.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_flag.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_kern.h))
#include Mstr(Mjoin(ATLAS_PRE,sqamm_perf.h))

/*
 * ATL_MAXIDX: upper kernel-table index used by code in this file.
 * Precedence: ATL_CAMM_MAXINDX if defined; otherwise ATL_AMM_98IDX, but only
 * when ATL_AMM_66IDX is not also defined; otherwise ATL_AMM_NCASES-1.
 * Every replacement list is parenthesized so the macro expands safely when
 * it appears inside a larger expression.
 */
#ifdef ATL_CAMM_MAXINDX
   #define ATL_MAXIDX (ATL_CAMM_MAXINDX)
#elif defined(ATL_AMM_98IDX) && !defined(ATL_AMM_66IDX)
   #define ATL_MAXIDX (ATL_AMM_98IDX)
#endif
#ifndef ATL_MAXIDX
   #define ATL_MAXIDX (ATL_AMM_NCASES-1)
#endif

/*
 * Copies the table entries for kernel case id into *out: the case index,
 * the unrolling factors, minimum kb, flag word and the six kernel pointers.
 * mb, nb and kb are all taken from ATL_AMM_KBs[id].
 */
static INLINE void FillInInfo(amminfo_t *out, int id)
{
/*
 * Kernel pointers: full kernels for beta=0,1,-1, then the K1 variants
 */
   out->amm_b0 = ATL_AMM_KERN_b0[id];
   out->amm_b1 = ATL_AMM_KERN_b1[id];
   out->amm_bn = ATL_AMM_KERN_bn[id];
   out->amm_k1_b0 = ATL_AMM_KERN_K1[id];
   out->amm_k1_b1 = ATL_AMM_KERN_K1_b1[id];
   out->amm_k1_bn = ATL_AMM_KERN_K1_bn[id];
/*
 * Integer description of the case
 */
   out->IDX = id;
   out->flag = ATL_AMM_KFLAG[id];
   out->kbmin = ATL_AMM_KBMINs[id];
   out->ku = ATL_AMM_KUs[id];
   out->nu = ATL_AMM_NUs[id];
   out->mu = ATL_AMM_MUs[id];
   out->nb = out->kb = out->mb = ATL_AMM_KBs[id];  /* one size for all dims */
}

/*
 * Returns the block-to-C copy routine for kernel case id that matches the
 * given alpha/beta encodings.
 * ialp, ibet: 1 selects the _a1/_b1 table, -1 the _an/_bn table; ibet == 0
 * selects _b0; every other value selects the general _aX/_bX table.
 */
static INLINE ablk2cmat_t GetBlk2C(int id, int ialp, int ibet)
{
   switch (ialp)
   {
   case 1:                           /* alpha = 1 tables */
      switch (ibet)
      {
      case 1:
         return(ATL_AMM_BLK2C_a1_b1[id]);
      case 0:
         return(ATL_AMM_BLK2C_a1_b0[id]);
      case -1:
         return(ATL_AMM_BLK2C_a1_bn[id]);
      default:
         return(ATL_AMM_BLK2C_a1_bX[id]);
      }
   case -1:                          /* alpha = -1 tables */
      switch (ibet)
      {
      case 1:
         return(ATL_AMM_BLK2C_an_b1[id]);
      case 0:
         return(ATL_AMM_BLK2C_an_b0[id]);
      case -1:
         return(ATL_AMM_BLK2C_an_bn[id]);
      default:
         return(ATL_AMM_BLK2C_an_bX[id]);
      }
   default:                          /* general alpha tables */
      switch (ibet)
      {
      case 1:
         return(ATL_AMM_BLK2C_aX_b1[id]);
      case 0:
         return(ATL_AMM_BLK2C_aX_b0[id]);
      case -1:
         return(ATL_AMM_BLK2C_aX_bn[id]);
      default:
         return(ATL_AMM_BLK2C_aX_bX[id]);
      }
   }
}

/*
 * Chooses a square-blocked kernel case for SYRK (HERK for complex builds) by
 * minimizing a predicted time built from ATL_sqAMM_TIME, then fills in *out
 * with that case's kernels and copy routines.
 * RETURNS: the blocking factor nb stored in out->nb.
 */
int Mjoin(PATL,GetSyrkInfo)  /* returns nb */
(
   amminfo_t *out,       /* OUTPUT: description of the chosen case */
   int ialp,             /* 1:alpha=1.0, -1:-1.0, else alpha=X */
   enum ATLAS_TRANS TA,
   ATL_CSZT N,           /* size of triangular matrix */
   ATL_CSZT K,           /* K dim of A/A^T */
   int ibet              /* 0:beta=0.0 1:beta=1.0, -1:-1.0, else beta=X */
)
{
   int icase, ibest=0, nb;
   double tbest = ((double)N)*N*K*ATL_sqAMM_TIME[0];

/*
 * Scan the cases in table order, stopping at the first whose block would
 * not fit at least twice into K
 */
   for (icase=0; icase < ATL_AMM_NCASES; icase++)
   {
      const int bf=ATL_AMM_KBs[icase];
      ATL_CSZT nfullN=N/bf, nfullK=K/bf;
      const int nrem=N-nfullN*bf, krem=K-nfullK*bf;
      ATL_SZT ndiag, noffd;
      double t;

      if (bf+bf > K)
         break;
      t = nfullK*ATL_sqAMM_TIME[icase];   /* full K-blocks of one C block */
      if (krem)   /* estimate K-cleanup from largest earlier kb <= krem */
      {
         int j=icase;
         double ratio;

         while (j > 0 && ATL_AMM_KBs[j] > krem)
            j--;
         ratio = bf;
         ratio /= (double)ATL_AMM_KBs[j];
         t += ratio*ATL_sqAMM_TIME[j]*ratio*ratio;
      }
      ndiag = (nrem) ? nfullN+1 : nfullN;   /* # of diagonal blocks */
      noffd = ((ndiag-1)*ndiag)>>1;         /* # of off-diagonal blocks */
      t *= (ndiag+noffd);
      if (t < tbest)
      {
         tbest = t;
         ibest = icase;
      }
   }

   FillInInfo(out, ibest);
   nb = out->nb;
   out->Cblk2cm = GetBlk2C(ibest, 1, ibet);
   out->Cblk2cm_b1 = ATL_AMM_BLK2C_a1_b0[ibest]; /* _b1 is really _b0 here */
/*
 * A's copy always uses the alpha=1 routine; ialp picks the routine for B
 */
   if (TA == AtlasNoTrans)
   {
      out->a2blk = ATL_AMM_AT2BLK_a1[ibest];
      if (ialp == 1)
         out->b2blk = ATL_AMM_BT2BLK_a1[ibest];
      else if (ialp == -1)
         out->b2blk = ATL_AMM_BT2BLK_an[ibest];
      else
         out->b2blk = ATL_AMM_BT2BLK_aX[ibest];
   }
   #ifdef TCPLX
   else if (TA == AtlasConj)  /* means HERK, noTrans */
   {
      out->a2blk = ATL_AMM_AT2BLK_a1[ibest];
      if (ialp == 1)
         out->b2blk = ATL_AMM_BH2BLK_a1[ibest];
      else if (ialp == -1)
         out->b2blk = ATL_AMM_BH2BLK_an[ibest];
      else
         out->b2blk = ATL_AMM_BH2BLK_aX[ibest];
   }
   else if (TA == AtlasConjTrans)  /* Means HERK, HermTrans */
   {
      out->a2blk = ATL_AMM_AC2BLK_a1[ibest];
      if (ialp == 1)
         out->b2blk = ATL_AMM_B2BLK_a1[ibest];
      else if (ialp == -1)
         out->b2blk = ATL_AMM_B2BLK_an[ibest];
      else
         out->b2blk = ATL_AMM_B2BLK_aX[ibest];
   }
   #endif
   else  /* TA == AtlasTrans */
   {
      out->a2blk = ATL_AMM_A2BLK_a1[ibest];
      if (ialp == 1)
         out->b2blk = ATL_AMM_B2BLK_a1[ibest];
      else if (ialp == -1)
         out->b2blk = ATL_AMM_B2BLK_an[ibest];
      else
         out->b2blk = ATL_AMM_B2BLK_aX[ibest];
   }
   return(nb);
}
#ifdef TREAL
/*
 * For TRSM, we need a kernel with mb == kb, but nb can differ.
 * Alpha for A will be -1, for B it will be 1, and Cblk2cm will be for
 * beta=alpha, while Cblk2cm_b1 will be for beta=1.0.
 * RETURNS: mb to use, 0 if ATL_trsmKL_rk4 should be called instead.
 */
#include "atlas_ttypes.h"
#include Mstr(Mjoin(ATLAS_PRE,tsamm_perf.h))
/*
 * Picks the square blocking for an amm-based TRSM by comparing, for every
 * candidate kb, a predicted time built from ATL_tsAMM_TIME (triangular solve)
 * and ATL_sqAMM_TIME (amm update) against the predicted time of solving the
 * whole problem with the unblocked solve.  If no blocked case is predicted
 * to be at least as fast, *out is left untouched and 0 is returned.
 * Otherwise *out describes the chosen case: A's copy uses the alpha=-1
 * routine, B's copy the alpha=1 routine, Cblk2cm is chosen from ialp and
 * Cblk2cm_b1 is the beta=1 routine.
 * RETURNS: mb (== kb) to use, or 0 if no blocked case was selected.
 */
int Mjoin(PATL,GetTrsmInfo)
(
   amminfo_t *out,
   int ialp,            /* 0 alpha=0.0, 1:alpha=1.0, -1:-1.0, else alpha=X */
   enum ATLAS_TRANS TA,
   ATL_CSZT M,           /* size of triangular matrix */
   ATL_CSZT N,           /* NRHS */
   const SCALAR beta     /* not referenced by this routine */
)
{
   #define MAXIDX (ATL_AMM_NCASES-1)
   int ik, mb, ibest, bbest;
   double tslv;   /* predicted time to solve whole prob wt trsmK */
   double tbest;  /* best predicted time so far; starts at tslv */

/*
 * First, find speed of trsmK by taking MB closest to M
 */
   if (M >= ATL_AMM_MAXKB)
      tslv = ATL_tsAMM_TIME[MAXIDX];
   else
   {
      int i;
      for (i=MAXIDX; i > 0; i--)
         if (ATL_AMM_KBs[i] <= M)
            break;
      tslv = ATL_tsAMM_TIME[i];
   }
   tslv = ((1.0*M)*M*N) * tslv;
   tbest = tslv;
   ibest = -1;
   bbest = -1;
/*
 * Now loop over all square block factors, and find the best predicted perf
 */
   for (ik=0; ik < ATL_AMM_NCASES; ik++)
   {
      const int kb = ATL_AMM_KBs[ik];
      int mbp, ndi, nsq;
      double tim, tfl;
      if (kb+kb >= M)  /* don't use blks leading to little amm */
         break;
      ndi = M/kb;              /* # of full diagonal blocks */
      nsq = ((ndi-1)*ndi)>>1;  /* # of full amm blks */
      mbp = M - ndi*kb;        /* partial block at beginning */
      tfl = kb;                /* triangular flops are half  */
      tfl = tfl*kb*N;          /* of amm flops for same kb */
/*
 *    Compute time to do full blocks portion of algorithm
 */
      tim = (ndi*tfl)*ATL_tsAMM_TIME[ik] +             /* slv time */
            (nsq*(tfl+tfl))*ATL_sqAMM_TIME[ik];        /* amm time */
      if (mbp) /* need to find perf of partial block */
      {
         int i;
         const double pfl=(1.0*mbp)*kb*N, mmfl=(2.0*ndi)*ndi*N;
/*
 *       Find the largest case at or below ik whose kb fits in the partial
 *       block.  The test must use the loop index i: testing index ik can
 *       never succeed (mbp < kb), which always left i at 0.
 *       NOTE(review): mmfl does not involve mbp or kb; confirm the intended
 *       flop count for the partial block's amm update.
 */
         for (i=ik; i > 0; i--)
            if (ATL_AMM_KBs[i] <= mbp)
               break;
         tim += pfl * ATL_tsAMM_TIME[i];   /* extra slvtime */
         tim += mmfl * ATL_sqAMM_TIME[i];  /* extra mmtime */
      }
      if (tim <= tbest)
      {
         tbest = tim;
         ibest = ik;
         bbest = kb;
      }
   }
   if (ibest < 0)   /* no blocked case predicted to beat unblocked solve */
      return(0);
   ik = ibest;

   out->IDX = ik;
   out->mb = mb = out->kb = bbest;
   out->nb = ATL_AMM_NBs[ik];
   out->mu = ATL_AMM_MUs[ik];
   out->nu = ATL_AMM_NUs[ik];
   out->ku = ATL_AMM_KUs[ik];
   out->kbmin = ATL_AMM_KBMINs[ik];
   out->flag = ATL_AMM_KFLAG[ik];
   out->amm_b0 = ATL_AMM_KERN_b0[ik];
   out->amm_b1 = ATL_AMM_KERN_b1[ik];
   out->amm_bn = ATL_AMM_KERN_bn[ik];
   out->amm_k1_b0 = ATL_AMM_KERN_K1[ik];
   out->amm_k1_b1 = ATL_AMM_KERN_K1_b1[ik];
   out->amm_k1_bn = ATL_AMM_KERN_K1_bn[ik];
   out->a2blk = (TA == AtlasNoTrans) ?
                ATL_AMM_AT2BLK_an[ik]:ATL_AMM_A2BLK_an[ik];
   out->b2blk = ATL_AMM_B2BLK_a1[ik];
/*
 * ialp selects the beta variant of the alpha=1 block-to-C copy
 */
   if (ialp == 1)
      out->Cblk2cm = ATL_AMM_BLK2C_a1_b1[ik];
   else if (ialp == -1)
      out->Cblk2cm = ATL_AMM_BLK2C_a1_bn[ik];
   else
      out->Cblk2cm = (ialp) ? ATL_AMM_BLK2C_a1_bX[ik]:ATL_AMM_BLK2C_a1_b0[ik];
   out->Cblk2cm_b1 = ATL_AMM_BLK2C_a1_b1[ik];
   out->cm2Cblk = NULL;
   return(mb);
}

/*
 * For TRSM, we're going to get our main parallelism from the N dimension.
 * We know that mb==kb, but nb is independent.  alpha for A will be -1,
 * and for B it will be 1.  Cblk2cm will be beta=0, while Cblk2cm_b1 will be 1.
 * RETURNS: upper bound on useful nthreads to use
 */
#include "atlas_ttypes.h"
/*
 * Fills in *pd with the kernel case, block sizes, block counts and copy
 * routines for the threaded amm-based TRSM, and caps the thread count P by
 * the number of column blocks and by the total number of blocks.
 * RETURNS: the (possibly reduced) value of P.
 */
int Mjoin(PATL,tGetTrsmInfo)
(
   ATL_ttrsm_amm_t *pd,  /* OUTPUT: filled in by this routine */
   int P,                /* requested thread count; reduced copy is returned */
   enum ATLAS_TRANS TA,
   ATL_CSZT M,           /* dimension blocked by mb */
   ATL_CSZT N,           /* dimension blocked by nb */
   const SCALAR beta     /* selects the block-to-C copy routine */
)
{
   #ifdef ATL_CAMM_MAXINDX
      int ik = ATL_CAMM_MAXINDX;
   #else
      int ik = ATL_AMM_NCASES-1;
   #endif
   int mu, nu, ku, nb, nnblks, mb, nmblks, mb0, flag;

/*
 * Get a KB smaller than M
 */
   for (; ik > 0 && ATL_AMM_KBs[ik] > M; ik--);

/*
 * Find a kernel where mb can be set to kb; we know these exist, since we insist
 * on square problems of moderate size
 */
   for (; ik > 0; ik--)
   {
      mu = ATL_AMM_MUs[ik];
      mb = ATL_AMM_KBs[ik];
      if (mb > M)
         continue;
/*
 *    Any kernel can be used if it can be called with MB = KB
 */
      if ((mb/mu)*mu == mb)       /* it is legal to call wt MB=KB */
         break;                   /* so use this kernel */
/*
 *    KRUNTIME kernels can vary their KB, and thus be made legal
 */
      if (ATL_AMM_KRUNTIME(ATL_AMM_KFLAG[ik]))
      {
         ku = ATL_AMM_KUs[ik];
         ku = ATL_lcm(ku, mu);
         mb = (mb/ku)*ku;
         if (mb)
            break;
      }
   }
/*
 * NOTE(review): mb is reloaded from the table here, so any rounding applied
 * to it in the KRUNTIME branch above is discarded; confirm this is intended.
 */
   pd->mb = mb = ATL_AMM_KBs[ik];
   flag = ATL_AMM_KFLAG[ik];   /* flag word tested by KMAJOR/KRUNTIME below */
   nb = ATL_AMM_NBs[ik];
   nu = ATL_AMM_NUs[ik];
   if (P*nb > N)               /* no more threads than column blocks */
      P = (N+nb-1)/nb;
   pd->nb = nb;
   mu = ATL_AMM_MUs[ik];
   ku = ATL_AMM_KUs[ik];
   pd->nmu = mb / mu;
   pd->nnu = nb / nu;
/*
 * Column blocking: nbf is the size of the last (possibly partial) block
 */
   nnblks = N / nb;
   pd->nbf = N - nb*nnblks;
   if (!pd->nbf)
   {
      pd->nbf = nb;
      pd->nnuf = pd->nnu;
   }
   else
   {
      nnblks++;
      pd->nnuf = (pd->nbf+nu-1)/nu;
   }
   pd->nnblks = nnblks;
   pd->amm_b0 = ATL_AMM_KERN_b0[ik];
   pd->amm_b1 = ATL_AMM_KERN_b1[ik];
/*
 * Row blocking: mb0 is the size of the partial block (mb if none), and
 * MB0 is the size the kernel is called with for that block
 */
   nmblks = M/mb;
   mb0 = (M - nmblks*mb);
   if (!mb0)
   {
      pd->MB0 = mb0 = mb;
      pd->nmu0 = pd->nmu;
   }
   else
   {
      nmblks++;
/*
 *    The KMAJOR/KRUNTIME macros must be given the kernel's flag word, as in
 *    the kernel search loop above, not the case index ik.
 */
      if (ATL_AMM_KMAJOR(flag))
      {
         pd->MB0 = ((mb0+ku-1)/ku)*ku;
         if (!ATL_AMM_KRUNTIME(flag))
            pd->amm_b0 = ATL_AMM_KERN_K1[ik];
      }
      else
      {
/*
 *       Use the K1 kernel when the kernel cannot be called directly
 *       with K=mb0
 */
         if (!ATL_AMM_KRUNTIME(flag) || mb0 != (mb0/ku)*ku ||
             mb0 < ATL_AMM_KBMINs[ik])
            pd->amm_b0 = ATL_AMM_KERN_K1[ik];
         pd->MB0 = mb0;
      }
      pd->nmu0 = (mb0+mu-1)/mu;
   }
   pd->mb0 = mb0;
   pd->nmblks = nmblks;
   pd->nxblks = nnblks * nmblks;
   if (P > pd->nxblks)         /* no more threads than total blocks */
      P = pd->nxblks;
   pd->mu = mu;
   pd->nu = nu;
   pd->ku = ATL_AMM_KUs[ik];
   #ifdef TCPLX
      if (TA == AtlasConjTrans)
         pd->a2blk = ATL_AMM_AC2BLK_an[ik];
      else if (TA == AtlasConj)
         pd->a2blk = ATL_AMM_AH2BLK_an[ik];
      else
   #endif
   pd->a2blk = (TA == AtlasNoTrans) ?
               ATL_AMM_AT2BLK_an[ik] : ATL_AMM_A2BLK_an[ik];
   pd->b2blk = ATL_AMM_B2BLK_a1[ik];
/*
 * beta != 0, because then trsm simply zeros X and returns
 */
   if (SCALAR_IS_ONE(beta))
      pd->blk2c = ATL_AMM_BLK2C_a1_b1[ik];
   else if (SCALAR_IS_NONE(beta))
      pd->blk2c = ATL_AMM_BLK2C_a1_bn[ik];
   else
      pd->blk2c = ATL_AMM_BLK2C_a1_bX[ik];
   return(P);
}

/*
 * Selects a kernel case and blocking for the threaded symmetric amm and
 * fills in *out with its kernels and copy routines.  out->Cblk2cm is the
 * block-to-C copy matching both alpha and beta.
 * RETURNS: the beta=0 block-to-C copy routine for the given alpha.
 */
ablk2cmat_t Mjoin(PATL,tGetSyammInfo)
(
   amminfo_t *out,
   const int P,          /* scale you want to use */
   enum ATLAS_TRANS TA,
   ATL_CSZT N,
   ATL_CSZT K,
   const SCALAR alpha,
   const SCALAR beta
)
{
   ablk2cmat_t dblk2cmat;
   int ik = ATL_MAXIDX;
   int nb, k;
/*
 * For small K take the first case whose kb >= K.  The bound is a strict <
 * so that ik can never leave the loop larger than ATL_MAXIDX.
 */
   if (K < ATL_AMM_MAXKB)
   {
      for (ik=0; ik < ATL_MAXIDX && ATL_AMM_KBs[ik] < K; ik++);
   }
   nb = ATL_AMM_MBs[ik];
/*
 * Step down to smaller cases (not below ATL_AMM_66IDX) until the count
 * k = ((N/nb)-1)*(N/nb)/2 reaches P
 */
   #ifdef ATL_AMM_66IDX
      while (ik > ATL_AMM_66IDX)
      {
         k = N / nb;
         k = ((k-1)*k)>>1;
         if (k >= P)
            break;
         nb = ATL_AMM_MBs[--ik];
      }
   #endif
   out->IDX = ik;
   nb = Mmax(nb, ATL_AMM_NBs[ik]);
   if (nb > N)
      nb = N;
   out->mu = ATL_AMM_MUs[ik];
   out->nu = ATL_AMM_NUs[ik];
   out->ku = ATL_AMM_KUs[ik];
/*
 * mb == nb must be a nonzero multiple of both mu and nu
 */
   k = ATL_lcm(out->mu, out->nu);
   nb = (nb > k) ? (nb/k)*k : k;
   out->nb = out->mb = nb;
   out->kb = ATL_AMM_KBs[ik];
   out->kbmin = ATL_AMM_KBMINs[ik];
   out->flag = ATL_AMM_KFLAG[ik];
   out->amm_b0 = ATL_AMM_KERN_b0[ik];
   out->amm_b1 = ATL_AMM_KERN_b1[ik];
   out->amm_bn = ATL_AMM_KERN_bn[ik];
   out->amm_k1_b0 = ATL_AMM_KERN_K1[ik];
   out->amm_k1_b1 = ATL_AMM_KERN_K1_b1[ik];
   out->amm_k1_bn = ATL_AMM_KERN_K1_bn[ik];
   if (TA == AtlasNoTrans)
   {
      out->a2blk = ATL_AMM_AT2BLK_a1[ik];
      out->b2blk =  ATL_AMM_BT2BLK_a1[ik];
   }
   else
   {
      out->a2blk = ATL_AMM_A2BLK_a1[ik];
      out->b2blk =  ATL_AMM_B2BLK_a1[ik];
   }
/*
 * alpha is applied by the block-to-C copy: pick the table from alpha, then
 * the beta variant within it
 */
   if (SCALAR_IS_ONE(alpha))
   {
      dblk2cmat = ATL_AMM_BLK2C_a1_b0[ik];
      if (SCALAR_IS_ONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_a1_b1[ik];
      else if (SCALAR_IS_NONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_a1_bn[ik];
      else if (SCALAR_IS_ZERO(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_a1_b0[ik];
      else
         out->Cblk2cm = ATL_AMM_BLK2C_a1_bX[ik];
   }
   else if (SCALAR_IS_NONE(alpha))
   {
      dblk2cmat = ATL_AMM_BLK2C_an_b0[ik];
      if (SCALAR_IS_ONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_an_b1[ik];
      else if (SCALAR_IS_NONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_an_bn[ik];
      else if (SCALAR_IS_ZERO(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_an_b0[ik];
      else
         out->Cblk2cm = ATL_AMM_BLK2C_an_bX[ik];
   }
   else  /* alpha = X */
   {
      dblk2cmat = ATL_AMM_BLK2C_aX_b0[ik];
      if (SCALAR_IS_ONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_aX_b1[ik];
      else if (SCALAR_IS_NONE(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_aX_bn[ik];
      else if (SCALAR_IS_ZERO(beta))
         out->Cblk2cm = ATL_AMM_BLK2C_aX_b0[ik];
      else
         out->Cblk2cm = ATL_AMM_BLK2C_aX_bX[ik];
   }
   return(dblk2cmat);
}
/*
 * returns cblk2c_b0, cblk2c_b1 is in structure
 */
/*
 * Fills in *out for the variant where mb and nb are N rounded up to
 * multiples of mu and nu.  The case is the first whose kb >= K, or the last
 * case when K >= ATL_AMM_MAXKB.  out->Cblk2cm is set to the alpha=1, beta=1
 * block-to-C copy.
 * RETURNS: the alpha=1, beta=0 block-to-C copy routine for the chosen case.
 */
ablk2cmat_t Mjoin(PATL,tGetSyammInfo_K)
(
   amminfo_t *out,
   const int P,          /* scale you want to use (not referenced here) */
   enum ATLAS_TRANS TA,
   ATL_CSZT N,
   ATL_CSZT K
)
{
   const int ilast = ATL_AMM_NCASES-1;
   const int notrans = (TA == AtlasNoTrans);
   int idx = ilast;
   int mu, nu;

   if (K < ATL_AMM_MAXKB)
   {
      idx = 0;
      while (idx < ilast && ATL_AMM_KBs[idx] < K)
         idx++;
   }
/*
 * Integer description of the case; mb/nb cover all of N
 */
   out->IDX = idx;
   out->mu = ATL_AMM_MUs[idx];
   mu = out->mu;
   out->nu = ATL_AMM_NUs[idx];
   nu = out->nu;
   out->ku = ATL_AMM_KUs[idx];
   out->mb = ((N+mu-1)/mu)*mu;
   out->nb = ((N+nu-1)/nu)*nu;
   out->kb = ATL_AMM_KBs[idx];
   out->kbmin = ATL_AMM_KBMINs[idx];
   out->flag = ATL_AMM_KFLAG[idx];
/*
 * Kernels
 */
   out->amm_b0 = ATL_AMM_KERN_b0[idx];
   out->amm_b1 = ATL_AMM_KERN_b1[idx];
   out->amm_bn = ATL_AMM_KERN_bn[idx];
   out->amm_k1_b0 = ATL_AMM_KERN_K1[idx];
   out->amm_k1_b1 = ATL_AMM_KERN_K1_b1[idx];
   out->amm_k1_bn = ATL_AMM_KERN_K1_bn[idx];
/*
 * Copy routines: alpha=1 for both A and B, transposed variants for NoTrans
 */
   out->a2blk = notrans ? ATL_AMM_AT2BLK_a1[idx] : ATL_AMM_A2BLK_a1[idx];
   out->b2blk = notrans ? ATL_AMM_BT2BLK_a1[idx] : ATL_AMM_B2BLK_a1[idx];
   out->Cblk2cm = ATL_AMM_BLK2C_a1_b1[idx];
   return(ATL_AMM_BLK2C_a1_b0[idx]);
}
#endif
