#include "atlas_amm.h"
/*
 * Applies one rank-kb update to a single nb-wide column panel of C, looping
 * over the M dimension only.  For every M block it does:
 *    [copy block of A into a]  ->  amm kernel  ->  [write block of c to C]
 * The copy of A is done only if rkinf->a2blk is non-NULL; otherwise the
 * workspace a is used as it stands.  The write to the original C is done
 * only if rkinf->blk2c is non-NULL; otherwise only the workspace c is updated.
 * rkinf->nfmb full blocks of mb rows are processed first, followed by one
 * remainder block of rkinf->mbL rows when mbL is nonzero.
 */
void Mjoin(PATL,ammmM) /* C <= beta*C + A*B, B/C nb-wide colpan */
(
   rkinfo_t *rkinf,
   int nb,             /* # of cols to do */
   int nnu,            /* CEIL(NB/nu) */
   ammkern_t amm,      /* amm kern to use for this kb */
   TYPE *a,            /* workspace for A */
   ATL_CINT inca,      /* gap between blocks (>= mb*kb) */
   const TYPE *b,      /* access-major kbXnb workspace for B */
   TYPE *c,            /* access-major mbXnb workspace for C */
   ATL_CINT incc,      /* gap between M blocks (0 or >=mb*nb) */
   const TYPE *A,      /* col/row-major (original) A */
   TYPE *C,            /* col-major (original) C */
   TYPE beta           /* scale factor for C */
)
{
   const size_t lda=rkinf->lda;
   const size_t ldc=rkinf->ldc;
   const size_t incAm=rkinf->incAm;
   ATL_CINT kb=rkinf->kb;
   ATL_CINT mb=rkinf->mb, nmu=rkinf->nmu;
   ATL_CINT nfmb=rkinf->nfmb, mbL=rkinf->mbL;
   cm2am_t cpyA=rkinf->a2blk;       /* NULL: A already present in a */
   ablk2cmat_t putC=rkinf->blk2c;   /* NULL: leave result in c only */
   TYPE *aStart=a;                  /* first A block, passed as "next" at end */
   ATL_INT nleft;

/*
 * Full-sized M blocks
 */
   for (nleft=nfmb; nleft > 0; nleft--)
   {
      TYPE *aNext = a + inca;
      TYPE *cNext = c + incc;

      if (cpyA)
         cpyA(kb, mb, ATL_rone, A, lda, a);
/*
 *    Last three args are the A/B/C blocks of the following iteration
 *    (presumably used by the kernel for prefetch -- confirm with kernel API)
 */
      amm(nmu, nnu, kb, a, b, c, aNext, b, cNext);
      if (putC)
      {
         putC(mb, nb, ATL_rone, c, beta, C, ldc);
         C += mb;      /* C only advances when it is actually written */
      }
      A += incAm;
      a = aNext;
      c = cNext;
   }
/*
 * Remainder M block, if any
 */
   if (!mbL)
      return;
   if (cpyA)
      cpyA(kb, mbL, ATL_rone, A, lda, a);
   amm(rkinf->nmuL, nnu, kb, a, b, c, aStart, b, c);
   if (putC)
      putC(mbL, nb, ATL_rone, c, beta, C, ldc);
}

/*
 * Computes the complete update for one nb-wide column panel of C by looping
 * over K (here) and M (inside ammmM).  The K loop runs the full kb-sized
 * blocks described by rkinf with writes to the original C disabled
 * (rkinf->blk2c is temporarily set to NULL and restored afterwards), so the
 * partial sums accumulate in the workspace c.  C is written either by a
 * peeled final full block, or by the K-cleanup block when that block can
 * accumulate directly on top of c.
 * Side effects on the info structs: rkinf->blk2c is cleared and restored,
 * and krinf->kb is overwritten with kbL when gemm-kernel K-cleanup is used.
 */
static void Mjoin(PATL,ammmKM)/* computes full answer for one col-panel of C */
(                      /* by looping over both M&K */
   rkinfo_t *rkinf,    /* amm info for kb-width cols */
   rkinfo_t *krinf,    /* amm info for K remainder cols */
   int nb,             /* # of cols to do */
   int nnu,            /* NB/nu */
   TYPE *a,            /* workspace for A */
   ATL_CINT incam,     /* gap between blocks (>= mb*kb) */
   ATL_CINT incak,     /* 0: reuse same col of a, else a M*K in size */
   TYPE *b,            /* access-major kbXnb workspace for B */
   TYPE *c,            /* access-major mbXnb workspace for C */
   ATL_CINT inccm,     /* 0: write to C, else: write only to c */
   const TYPE *A,      /* col/row-major (original) A */
   const TYPE *B,      /* col/row-major (original) B */
   TYPE *C,            /* col-major (original) C */
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   const SCALAR alpha, /* scale factor for B */
   const SCALAR beta   /* scale factor for C */
)
{
   const size_t incAk=rkinf->incAk, incBk=rkinf->incBk;
   const size_t ldb=rkinf->ldb;
   cm2am_t b2blk=rkinf->b2blk;
   ablk2cmat_t blk2c=rkinf->blk2c;
   ammkern_t amm=rkinf->amm_b0, amm_b1=rkinf->amm_b1;
   ATL_CINT idr=krinf->idx;
   ATL_CINT kb=rkinf->kb, kbL=rkinf->kbL;
   ATL_INT k, nk=rkinf->nfkb, DOPEEL=1;

   ATL_assert(rkinf->idx != -1);  /* don't call w/o more than 1 K block */
/*
 * The last block must be peeled to write C unless the last block is actually
 * a partial K-block, which isn't handled here anyway.  We also must peel and
 * write C if the K-block doesn't use the same C storage, which requires
 * an additional write to the original C to handle.  So, the only time we
 * don't peel an iteration is when the K-cleaner exists, and uses the
 * same format as the mainline K
 */
   if (kbL == 0 || idr == -1) /* no K-clean or GER/GER2 clean forces peel */
      nk--;
   else if (krinf->mu != rkinf->mu || krinf->nu != rkinf->nu)
      nk--;
   else
      DOPEEL=0;

   rkinf->blk2c = NULL;           /* don't write C out until K loop done */
   for (k=0; k < nk; k++, A += incAk, B += incBk)
   {
      b2blk(kb, nb, alpha, B, ldb, b);
      Mjoin(PATL,ammmM)(rkinf, nb, nnu, amm, a, incam, b, c, inccm, A, C, beta);
      amm = amm_b1;   /* first block uses amm_b0, later ones accumulate */
      a += incak;
   }
   rkinf->blk2c = blk2c;  /* next ammmM call should write to C */
/*
 * If we peeled to write C, do that along with last full M block
 */
   if (DOPEEL)
   {
      b2blk(kb, nb, alpha, B, ldb, b);
      Mjoin(PATL,ammmM)(rkinf, nb, nnu, amm, a, incam, b, c, inccm, A, C, beta);
      a += incak;
      A += incAk;
      B += incBk;
   }
/*
 * Do we have K-cleanup to do?
 */
   if (kbL)
   {
      if (idr >= 0)  /* K cleanup uses gemm kernel */
      {
         krinf->kb = kbL;
         krinf->b2blk(kbL, nb, alpha, B, ldb, b);
/*
 *       If we peeled, the sum of the full K blocks (and beta) is already in
 *       C, and c still holds that stale sum, possibly in a different mu/nu
 *       layout.  The cleanup must then overwrite c (amm_b0) and be added to
 *       C with beta=1; accumulating onto c (amm_b1) would count the full-K
 *       contribution twice.  Without a peel, c is the live accumulator, so
 *       we accumulate with amm_b1 and apply the user's beta on the write.
 */
         Mjoin(PATL,ammmM)(krinf, nb, nnu,
                           (DOPEEL) ? krinf->amm_b0 : krinf->amm_b1,
                           a, incam, b, c, inccm, A, C,
                           (DOPEEL)?ATL_rone:beta);
      }
/*
 *    If we use GER1/GER2 for cleanup, beta has already been applied above
 *    (idr == -1 always forces the peel that writes C)
 */
      else if (kbL == 2)  /* use GER2 to clean up */
         Mjoin(PATL,ammm_rk2)(TA, TB, M, nb, alpha, A, rkinf->lda, B, ldb,
                              ATL_rone, C, rkinf->ldc);
      else /* kbL == 1, use GER1 */
      {
         #ifdef TCPLX
            ATL_CINT incA = (TA==AtlasNoTrans || TA==AtlasConj) ? 1:rkinf->lda;
            ATL_CINT incB = (TB==AtlasNoTrans || TB==AtlasConj) ? ldb:1;
         #else
            ATL_CINT incA = (TA==AtlasNoTrans) ? 1:rkinf->lda;
            ATL_CINT incB = (TB==AtlasNoTrans) ? ldb:1;
         #endif
         Mjoin(PATL,ger)(M, nb, alpha, A, incA, B, incB, C, rkinf->ldc);
      }
   }
}


/*
 * Handles the case where all of K fits in the single K block described by
 * krinf: loops over the column panels of B/C, calling ammmM once per panel.
 * A is copied to the workspace during the first full panel only; after that
 * krinf->a2blk is set to NULL (and left NULL on return) so later panels reuse
 * the copy already in a.
 * RETURNS: 0 on success, 1 if the workspace could not be allocated.
 */
static int ATL_ammm_rkK
(
   rkinfo_t *krinf,
   ATL_CSZT N,
   const SCALAR alpha,
   const TYPE *A,
   const TYPE *B,
   TYPE *C,
   const SCALAR beta
)
{
   const size_t ldb=krinf->ldb, ldc=krinf->ldc;
   ATL_CINT NB=krinf->nb, MB=krinf->mb, kb=krinf->kb;
   ATL_CINT nfnb=krinf->nfnb, nnu=krinf->nnu;
   ATL_INT inca, j;
   const size_t incC=NB*ldc, incBn=krinf->incBn;
   size_t szA, szB, szC;
   cm2am_t b2blk = krinf->b2blk;
   TYPE *a, *b, *c;
   void *vp;

/*
 * With more than one column panel, keep every M block of A in the workspace
 * so it is copied only once; otherwise one block's worth of space is reused
 */
   if (N > NB)
   {
      szA = (MB*krinf->nfmb + krinf->MBL)*kb;
      inca = MB*kb;
   }
   else
   {
      szA = MB*kb;
      inca = 0;
   }
   szB = NB*kb;
   szC = MB*NB;
   vp = malloc(3*ATL_Cachelen + ATL_MulBySize(szA+szB+szC));
   ATL_assert(vp);
   if (!vp)
      return(1);
   a = ATL_AlignPtr(vp);
   b = a + szA;
   b = ATL_AlignPtr(b);
   c = b + szB;
   c = ATL_AlignPtr(c);

   if (nfnb)
   {
/*
 *    First panel copies A; remaining panels reuse the copy
 */
      b2blk(kb, NB, alpha, B, ldb, b);
      Mjoin(PATL,ammmM)(krinf, NB, nnu, krinf->amm_b0, a, inca, b, c, 0, A,
                        C, beta);
      krinf->a2blk = NULL;
      C += incC;
      B += incBn;
      for (j=1; j < nfnb; j++, C += incC, B += incBn)
      {
         b2blk(kb, NB, alpha, B, ldb, b);
         Mjoin(PATL,ammmM)(krinf, NB, nnu, krinf->amm_b0, a, inca, b, c, 0, A,
                           C, beta);
      }
   }
/*
 * Partial column panel, if any
 */
   if (krinf->nbL)
   {
      b2blk(kb, krinf->nbL, alpha, B, ldb, b);
      Mjoin(PATL,ammmM)(krinf, krinf->nbL, krinf->nnuL, krinf->amm_b0, a, inca,
                        b, c, 0, A, C, beta);
   }
   free(vp);
   return(0);  /* caller returns this value, so it must be defined */
}

/*
 * GEMM driver using access-major kernels with loop order N (outer), K, M
 * (inner): C <= beta*C + alpha * op(A) * op(B).
 * K < 3 is passed straight to Mjoin(PATL,ammm).  Otherwise GetBestKBInfo
 * fills in kbinf (full K blocks) and krinf (K remainder); if kbinf.idx is -1
 * the whole problem is handed to ATL_ammm_rkK using krinf alone.
 * NOTE(review): kbinf.a2blk is never cleared between column panels, so A
 * appears to be recopied into the workspace for every panel even though the
 * workspace is sized to hold all of A when N > nb -- confirm this is intended.
 * RETURNS: 0 on success, nonzero if workspace allocation failed.
 */
int Mjoin(PATL,ammmNKM)
(
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   ATL_CSZT K,
   const SCALAR alpha,
   const TYPE *A,
   ATL_CSZT lda,
   const TYPE *B,
   ATL_CSZT ldb,
   const SCALAR beta,
   TYPE *C,
   ATL_CSZT ldc
)
{
   rkinfo_t kbinf, krinf;
   size_t szA, szB, szC, incak, inccm, incBn, incC;
   ATL_INT j, mb, nfmb, nb, mbL, kb, kbL, incam, nfnb, nfkb;
   ATL_INT nnu, nmblks, nkblks;
   TYPE *a, *b, *c;
   void *vp;
   void Mjoin(PATL,GetBestKBInfo)
      (rkinfo_t*, rkinfo_t*, enum ATLAS_TRANS, enum ATLAS_TRANS,
       ATL_CSZT, ATL_CSZT, ATL_CSZT, size_t, size_t, size_t,
       const SCALAR, const SCALAR);

   if (K < 3)
   {
      Mjoin(PATL,ammm)(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      return(0);
   }
   Mjoin(PATL,GetBestKBInfo)(&kbinf, &krinf, TA, TB, M, N, K, lda, ldb, ldc,
                             alpha, beta);
   if (kbinf.idx == -1)  /* no full-K-block info: single K block via krinf */
      return(ATL_ammm_rkK(&krinf, N, alpha, A, B, C, beta));

   mb = kbinf.mb;
   nb = kbinf.nb;
   kb = kbinf.kb;
   mbL = kbinf.mbL;
   kbL = kbinf.kbL;
   nfmb = kbinf.nfmb;
   nfnb = kbinf.nfnb;
   nfkb = kbinf.nfkb;
   nmblks = (mbL) ? nfmb+1 : nfmb;
   nkblks = (kbL) ? nfkb+1 : nfkb;
   nnu = kbinf.nnu;
   incBn = kbinf.incBn;
   incC = nb*ldc;
/*
 * With more than one K block, c must hold a whole column panel of C so that
 * each M block keeps its own partial sum across the K loop
 */
   if (K > kb)
   {
      inccm = mb*nb;
      szC = (nfmb*mb+kbinf.MBL)*nb;
   }
   else
   {
      szC = mb*nb;
      inccm = 0;
   }
/*
 * With more than one column panel, a is sized to hold every M and K block
 * of A; otherwise a single block's worth of space is reused
 */
   if (N > nb)
   {
      incam = mb*kb;
      incak = incam*nmblks;
      szA = incak * nkblks;
   }
   else
   {
      incam = incak = 0;
      szA = mb * kb;
   }

   szB = kb*nb;
   vp = malloc(3*ATL_Cachelen + ATL_MulBySize(szA+szB+szC));
   ATL_assert(vp);
   if (!vp)          /* same failure code as ATL_ammm_rkK, if assert returns */
      return(1);
   a = ATL_AlignPtr(vp);
   b = a + szA;
   b = ATL_AlignPtr(b);
   c = b + szB;
   c = ATL_AlignPtr(c);
   for (j=0; j < nfnb; j++, C += incC, B += incBn)
      Mjoin(PATL,ammmKM)(&kbinf, &krinf, nb, nnu, a, incam, incak, b,
                         c, inccm, A, B, C, TA, TB, M, N, alpha, beta);
/*
 * Partial column panel, if any
 */
   if (kbinf.nbL)
   {
      Mjoin(PATL,ammmKM)(&kbinf, &krinf, kbinf.nbL, kbinf.nnuL, a, incam, incak,
                         b, c, inccm, A, B, C, TA, TB, M, N, alpha, beta);
   }
   free(vp);
   return(0);
}
