#include "atlas_misc.h"
#include "atlas_amm.h"
#include Mstr(Mjoin(AMM_PRE,_sum.h))

/*
 * ATL_ammmMK: compute one nb-wide column panel of C by walking down the M
 * dimension.  Each of the nfmblks full mb-row blocks, followed (when mbL is
 * nonzero) by the trailing mbL-row block, is handed to ammmK together with
 * the K-loop parameters.  B is copied into workspace only for the first M
 * block (when incBk0 != 0); every later block passes a zero B increment so
 * that copy is reused for the rest of the panel.
 */
static void ATL_ammmMK
(
   amminfo_t *mminf,
   ATL_CSZT nfmblks,         /* number of full mb-row blocks: FLOOR(M/mb) */
   const int mbL,            /* rows in trailing block (M%mb), 0 if none */
   const int nmuL,           /* CEIL(mbL/mu) */
   const int nb,             /* columns in this panel of C */
   const int nnu,            /* CEIL(nb/nu) */
   ATL_CSZT nfkblks,         /* # of kb-sized K blocks before the final one */
   const int kb0,            /* size of the final K block */
   const int KB0,            /* kb0, rounded up to mult of ku if K-major */
   const TYPE *A,
   const size_t lda,
   const size_t incAk,       /* 0: no need to copy A, else incK for copying A */
   const size_t incAm,       /* step through A between M blocks */
   const TYPE *B,
   const size_t ldb,
   const size_t incBk0,      /* 0: no need to copy B, else incK for copying B */
   TYPE *C,
   const size_t ldc,
   const size_t incCm,       /* step through C between M blocks */
   TYPE *a,                  /* workspace holding copied A */
   ATL_CINT inca,      /* size of blocks of A, or 0 to reuse space */
   ATL_CSZT incam,           /* step through a between M blocks */
   TYPE *b,                  /* workspace holding copied B */
   ATL_CINT incb,      /* size of blocks of B, or 0 to reuse space */
   TYPE *rC,
   TYPE *iC,
   const SCALAR alpA,
   const SCALAR alpB,
   const SCALAR alpC,
   const SCALAR beta
)
{
   const int mb = mminf->mb, kb = mminf->kb, mu = mminf->mu;
   const int nmu = (mb + mu - 1) / mu;   /* mu-blocks in a full mb block */
   ablk2cmat_t blk2c = mminf->Cblk2cm;
   ATL_SZT incB = incBk0;
   ATL_SZT left;

   for (left = nfmblks; left != 0; left--)
   {
      Mjoin(PATL,ammmK)(mminf, mb, nmu, nb, nnu, nfkblks, kb, kb0, KB0,
                        A, lda, incAk, B, ldb, incB, blk2c, C, ldc,
                        a, inca, b, incb, rC, iC, alpA, alpB, alpC, beta);
      incB = 0;          /* B now lives in workspace: reuse it below */
      A += incAm;
      C += incCm;
      a += incam;
   }
   if (!mbL)
      return;
/*
 * Trailing partial M block
 */
   Mjoin(PATL,ammmK)(mminf, mbL, nmuL, nb, nnu, nfkblks, kb, kb0, KB0,
                     A, lda, incAk, B, ldb, incB, blk2c, C, ldc,
                     a, inca, b, incb, rC, iC, alpA, alpB, alpC, beta);
}

/*
 * ATL_ammmNMK: multiply op(A) (M x K) by op(B) (K x N) into C over a single
 * K range, one column panel of C at a time (N panels outer; M blocks inside
 * ATL_ammmMK; the K blocks are handed to ammmK).  alpA/alpB/alpC/beta are
 * passed through unchanged to ammmK.
 *
 * The first panel is nbF = N - nnblks*nb columns wide (1 <= nbF <= nb).
 * While computing it, A is copied into workspace `a` (all M blocks, all K
 * blocks).  Every remaining nb-wide panel passes incAk=0 so that copy is
 * reused.
 *
 * nmblks : on entry CEIL(M/mb); nkblks : on entry CEIL(K/kb).  Both are
 *          converted below into "full blocks" + "remainder" form.
 * a, b, c: workspace for the copied A, the copied B block and the C block.
 */
static void ATL_ammmNMK
(
   amminfo_t *mminf,
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   ATL_CSZT K,
   const TYPE *A,
   ATL_CSZT lda,
   const TYPE *B,
   ATL_CSZT ldb,
   TYPE *C,
   ATL_CSZT ldc,
   ATL_SZT nmblks,            /* CEIL(M/mb) */
   ATL_SZT nkblks,            /* CEIL(K/kb) */
   TYPE *a,
   TYPE *b,
   TYPE *c,
   const SCALAR alpA,
   const SCALAR alpB,
   const SCALAR alpC,
   const SCALAR beta
)
{
   const int mb=mminf->mb, nb=mminf->nb, kb=mminf->kb;
   const int mu=mminf->mu, nu=mminf->nu, ku=mminf->ku;
   const int nmu=(mb+mu-1)/mu, nnu=(nb+nu-1)/nu;
   const int inca=mb*kb, incb=kb*nb;
   int mbL=0, nmuL=0, nbF=nb, nnuF=nnu, KB0, kb0;
   /*
    * incam: span in `a` of one M block's copied A (all nkblks K blocks).
    * It must be computed here, before nkblks is decremented below.
    */
   ATL_CSZT incam = nkblks*(inca SHIFT), incCn = nb*(ldc SHIFT);
   ATL_SZT j, incAm, incAk, incBk, incBn, incBnF;
   /*
    * Number of nb-wide panels after the first one.
    * NOTE(review): N==0 would wrap (N-1); presumably callers guarantee N > 0.
    */
   ATL_SZT nnblks = (N-1)/nb;
   #ifdef TCPLX
      /* complex: c is split into iC (first mb*nb) and rC (next mb*nb),
       * presumably imaginary/real parts -- confirm against ammmK */
      TYPE *iC=c, *rC=c+mb*nb;
      const int mb2=mb+mb;
   #else
      /* real: rC and iC both name c and mb2 is mb; these aliases are
       * #undef'd right after this function */
      #define rC c
      #define iC c
      #define mb2 mb
   #endif

/*
 * Split M into nmblks full mb blocks plus an mbL-row remainder (0 if none)
 */
   j = nmblks*mb;
   if (j != M)
   {
      nmblks--;
      mbL = M - j + mb;
      nmuL = (mbL+mu-1)/mu;
   }
/*
 * Width of the first (possibly partial) N panel
 */
   j = nnblks*nb;
   nbF = N - j;
   if (nbF != nb)
      nnuF = (nbF+nu-1)/nu;
/*
 * nkblks becomes the number of kb-sized K blocks before the final block,
 * whose size is kb0 (1 <= kb0 <= kb).  For kernels flagged K-major, KB0 is
 * kb0 rounded up to a multiple of ku.
 */
   nkblks--;
   j = nkblks*kb;
   KB0 = kb0 = K - j;
   #if ATL_AMM_MAXKMAJ > 1
      if (kb0 != kb && ATL_AMMFLG_KMAJOR(mminf->flag))
         KB0 = ((kb0+ku-1)/ku)*ku;
   #endif
/*
 * Strides through the original A and B: incAk/incBk step one K block,
 * incAm steps one M block, incBn/incBnF step one full/first N panel.
 */
   if (IS_COLMAJ(TA))
   {
      incAk = kb*lda;
      incAm = mb;
   }
   else
   {
      incAk = kb;
      incAm = mb*lda;
   }
   if (IS_COLMAJ(TB))
   {
      incBk = kb;
      incBn = nb*ldb;
      incBnF = nbF*ldb;
   }
   else
   {
      incBk = kb*ldb;
      incBn = nb;
      incBnF = nbF;
   }
   #ifdef TCPLX
      /* complex: each element occupies two TYPE entries, so double strides */
      incAk += incAk;
      incAm += incAm;
      incBk += incBk;
      incBn += incBn;
      incBnF += incBnF;
   #endif
/*
 * In first nbF-wide panel, we copy all of A into workspace (incAk != 0)
 */
   ATL_ammmMK(mminf, nmblks, mbL, nmuL, nbF, nnuF, nkblks, kb0, KB0,
              A, lda, incAk, incAm, B, ldb, incBk, C, ldc, mb2,
              a, inca, incam, b, incb, rC, iC, alpA, alpB, alpC, beta);
   C += ldc*(nbF SHIFT);
   B += incBnF;
/*
 * In all other N-panel computation, reuse previously copied A (incAk=0)!
 */
   if (nnblks)
   {
      incAk = 0;
      for (j=0; j < nnblks; j++, C += incCn, B += incBn)
         ATL_ammmMK(mminf, nmblks, mbL, nmuL, nb, nnu, nkblks, kb0, KB0,
                    A, lda, incAk, incAm, B, ldb, incBk, C, ldc, mb2,
                    a, inca, incam, b, incb, rC, iC, alpA, alpB, alpC, beta);
   }
}
/* Drop the real-only rC/iC/mb2 aliases defined inside ATL_ammmNMK */
#ifndef TCPLX
   #undef rC
   #undef iC
   #undef mb2
#endif

/*
 * Mjoin(PATL,ammmKNMK): C = alpha*op(A)*op(B) + beta*C using the ammm kernel
 * chosen by Mjoin(PATL,GetAmmmInfo), with loop order K-chunk / N / M / K.
 *
 * A single malloc'd workspace holds:
 *    - the copied A for nkblksP K blocks of every M block;
 *    - the copied B for nkblksP K blocks;
 *    - the C block plus a (2*mu)*nu pad.
 * If that workspace exceeds ATL_MaxMalloc, or malloc fails, nkblksP is
 * halved. K is then processed in chunks of nkblksP*kb.  Only the first chunk
 * applies the caller's beta; later chunks accumulate into C with beta=1.
 *
 * RETURNS: 0 on success; 1 if no workspace could be obtained, in which case
 *          C has not been modified.
 */
int Mjoin(PATL,ammmKNMK)
(
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   ATL_CSZT K,
   const SCALAR alpha,
   const TYPE *A,
   ATL_CSZT lda,
   const TYPE *B,
   ATL_CSZT ldb,
   const SCALAR beta,
   TYPE *C,
   ATL_CSZT ldc
)
{
   #ifdef TCPLX
      const TYPE ONE[2]={ATL_rone,ATL_rzero};
      const TYPE *alpA=ONE, *alpB=ONE, *alpC=ONE, *bet=beta;
   #else
      #define ONE ATL_rone
      TYPE alpA=ATL_rone, alpB=ATL_rone, alpC=ATL_rone, bet=beta;
   #endif
   void *vp=NULL;
   TYPE *a, *b, *c;
   ATL_SZT szA, szB, szC, sz, nmblks, nkblks, nkblksP, k;
   int whichAlp, mu, nu, mb, nb, kb, incak, incbk;
   amminfo_t mminf;

/*
 * GetAmmmInfo fills mminf; its return value selects which of alpA/alpB/alpC
 * receives alpha (0: alpA, 1: alpB, anything else: alpC).
 */
   whichAlp = Mjoin(PATL,GetAmmmInfo)(&mminf, TA, TB, M, N, K, alpha, beta);
   if (!whichAlp)
      alpA = alpha;
   else if (whichAlp == 1)
      alpB = alpha;
   else
      alpC = alpha;

   mu = mminf.mu;
   nu = mminf.nu;
   mb = mminf.mb;
   nb = mminf.nb;
   kb = mminf.kb;
   nmblks = (M+mb-1)/mb;
   nkblksP = nkblks = (K+kb-1)/kb;
   incak = mb*kb;
   incbk = kb*nb;
   szC = mb*nb;
/*
 * Try to allocate workspace for all K blocks at once.  On failure, halve the
 * number of K blocks held in workspace.  Give up once a chunk of fewer than
 * 3 blocks has been tried.
 */
   nkblksP <<= 1;        /* pre-doubled: the loop halves before each try */
   do
   {
      nkblksP >>= 1;
      szA = nkblksP*nmblks*incak;
      szB = nkblksP*incbk;
      sz = ATL_MulBySize(szA+szB+szC+(mu+mu)*nu) + 2*ATL_Cachelen;
      if (sz <= ATL_MaxMalloc)
         vp = malloc(sz);
   }
   while (!vp && nkblksP >= 3);
   if (!vp)
      return(1);

   a = ATL_AlignPtr(vp);
   b = a + (szA SHIFT);
   c = b + (szB SHIFT);

   if (nkblksP == nkblks)
      ATL_ammmNMK(&mminf, TA, TB, M, N, K, A, lda, B, ldb, C, ldc, nmblks,
                  nkblks, a, b, c, alpA, alpB, alpC, bet);
   else
   {
      ATL_CSZT KK = nkblksP*kb;
      ATL_SZT incAkp, incBkp;
      incBkp = incAkp = KK SHIFT;
      if (IS_COLMAJ(TA))
         incAkp *= lda;
      if (!IS_COLMAJ(TB))
         incBkp *= ldb;
      for (k=0; k < nkblks; k += nkblksP, A += incAkp, B += incBkp)
      {
         ATL_SZT nk=nkblks-k, kk;
         if (nk > nkblksP)
         {
            kk = KK;
            nk = nkblksP;
         }
         else
            kk = K - k*kb;
         ATL_ammmNMK(&mminf, TA, TB, M, N, kk, A, lda, B, ldb, C, ldc, nmblks,
                     nk, a, b, c, alpA, alpB, alpC, bet);
/*
 *       C now holds beta*C_orig + this chunk's product.  Later K chunks
 *       must add to it, not re-scale it, so use beta=1 from here on.
 */
         bet = ONE;
      }
   }
   free(vp);
   return(0);
}
