#include "atlas_misc.h"
#include "atlas_amm.h"
#include Mstr(Mjoin(AMM_PRE,_sum.h))

/*
 * Handles one row panel of C by sweeping across the N dimension.
 * The K-loop routine ammmK is invoked once for each of the nfnblks full
 * nb-wide column blocks, and one extra time for the trailing nbL-wide block
 * when nbL is nonzero.  The A increment incAk0 is only handed to the first
 * invocation; every later one gets 0 so that A is reused for the rest of
 * the C row panel.
 */
static void ATL_ammmNK
(
   amminfo_t *mminf,
   const int mb,
   const int nmu,
   ATL_CSZT nfnblks,         /* FLOOR(N/nb) */
   const int nbL,            /* nbL=N%nb */
   const int nnuL,
   ATL_CSZT nfkblks,         /* FLOOR(K/kb) */
   const int kb0,
   const int KB0,
   const TYPE *A,
   const size_t lda,
   const size_t incAk0,      /* 0: no need to copy A, else incK for copying A */
   const TYPE *B,
   const size_t ldb,
   const size_t incBk,       /* 0: no need to copy B, else incK for copying B */
   const size_t incBn,
   TYPE *C,
   const size_t ldc,
   const size_t incCn,
   TYPE *a,
   ATL_CINT inca,      /* size of blocks of A, or 0 to reuse space */
   TYPE *b,
   ATL_CINT incb,      /* size of blocks of B, or 0 to reuse space */
   ATL_CSZT incbn,
   TYPE *rC,
   TYPE *iC,
   const SCALAR alpA,
   const SCALAR alpB,
   const SCALAR alpC,
   const SCALAR beta
)
{
   const int nb=mminf->nb, nu=mminf->nu, kb=mminf->kb;
   const int nnu=(nb+nu-1)/nu;            /* CEIL(nb/nu) */
   ablk2cmat_t blk2c=mminf->Cblk2cm;
   ATL_SZT incAk=incAk0;                  /* zeroed once first block is done */
   ATL_SZT jblk=0;

/*
 * Full nb-wide column blocks
 */
   while (jblk < nfnblks)
   {
      Mjoin(PATL,ammmK)(mminf, mb, nmu, nb, nnu, nfkblks, kb, kb0, KB0,
                        A, lda, incAk, B, ldb, incBk, blk2c, C, ldc,
                        a, inca, b, incb, rC, iC, alpA, alpB, alpC, beta);
      incAk = 0;       /* A is reused for all remaining column blocks */
      B += incBn;
      C += incCn;
      b += incbn;
      jblk++;
   }
/*
 * Trailing partial column block, if there is one
 */
   if (!nbL)
      return;
   Mjoin(PATL,ammmK)(mminf, mb, nmu, nbL, nnuL, nfkblks, kb, kb0, KB0,
                     A, lda, incAk, B, ldb, incBk, blk2c, C, ldc,
                     a, inca, b, incb, rC, iC, alpA, alpB, alpC, beta);
}
/*
 * Drives the computation over the M dimension using workspace a, b, c.
 * M is split into one leading panel of mbF rows (1 <= mbF <= mb) followed by
 * nmblks panels of exactly mb rows; each panel is handed to ATL_ammmNK.
 * The leading panel passes a nonzero B increment so that B is copied into
 * the workspace; all later panels pass 0 for it and reuse that copy.
 * nnblks and nkblks arrive as ceilings (CEIL(N/nb), CEIL(K/kb)) and are
 * converted here to "full block count + remainder" form.
 */
static void ATL_ammmMNK
(
   amminfo_t *mminf,
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   ATL_CSZT K,
   const TYPE *A,
   ATL_CSZT lda,
   const TYPE *B,
   ATL_CSZT ldb,
   TYPE *C,
   ATL_CSZT ldc,
   ATL_SZT nnblks,            /* CEIL(N/nb) */
   ATL_SZT nkblks,            /* CEIL(K/kb) */
   TYPE *a,
   TYPE *b,
   TYPE *c,
   const SCALAR alpA,
   const SCALAR alpB,
   const SCALAR alpC,
   const SCALAR beta
)
{
   const int mb=mminf->mb, nb=mminf->nb, kb=mminf->kb;
   const int mu=mminf->mu, nu=mminf->nu, ku=mminf->ku;
   const int nmu=(mb+mu-1)/mu;
   const int inca=mb*kb, incb=kb*nb;
   ATL_CSZT incbn = nkblks*(incb SHIFT), incCn = nb*(ldc SHIFT);
   ATL_CSZT nmblks = (M-1)/mb;   /* # of full mb panels after the first one */
   ATL_SZT ipan, incAm, incAk, incBk, incBn, incAmF;
   int nbL=0, nnuL=0, mbF, nmuF=nmu, kb0, KB0;
   #ifdef TCPLX
      TYPE *iC=c, *rC=c+mb*nb;
   #else
      TYPE *iC=c, *rC=c;
   #endif

/*
 * If N is not a multiple of nb, peel the last block off as a partial block
 */
   if (nnblks*nb != N)
   {
      nnblks--;
      nbL = N - nnblks*nb;
      nnuL = (nbL+nu-1)/nu;
   }
/*
 * Size of the first M panel, which absorbs any remainder
 */
   mbF = M - nmblks*mb;
   if (mbF != mb)
      nmuF = (mbF+mu-1)/mu;
/*
 * kb0 is what is left of K after nkblks-1 full kb blocks; KB0 is kb0 rounded
 * up to a multiple of ku when the kernel is flagged as K-major
 */
   nkblks--;
   kb0 = K - nkblks*kb;
   KB0 = kb0;
   #if ATL_AMM_MAXKMAJ > 1
      if (kb0 != kb && ATL_AMMFLG_KMAJOR(mminf->flag))
         KB0 = ((kb0+ku-1)/ku)*ku;
   #endif
/*
 * Strides through the original A and B, in units of TYPE
 */
   if (!IS_COLMAJ(TA))
   {
      incAk = kb;
      incAm = mb*lda;
      incAmF = mbF*lda;
   }
   else
   {
      incAk = kb*lda;
      incAm = mb;
      incAmF = mbF;
   }
   if (!IS_COLMAJ(TB))
   {
      incBk = kb*ldb;
      incBn = nb;
   }
   else
   {
      incBk = kb;
      incBn = nb*ldb;
   }
   #ifdef TCPLX
      incAk <<= 1;
      incAm <<= 1;
      incAmF <<= 1;
      incBk <<= 1;
      incBn <<= 1;
   #endif
/*
 * In first mbF-wide panel, we copy all of B into workspace
 */
   ATL_ammmNK(mminf, mbF, nmuF, nnblks, nbL, nnuL, nkblks, kb0, KB0,
              A, lda, incAk, B, ldb, incBk, incBn, C, ldc, incCn,
              a, inca, b, incb, incbn, rC, iC, alpA, alpB, alpC, beta);
   A += incAmF;
   C += (mbF SHIFT);
/*
 * In all other M-panel computation, reuse previously copied B!
 */
   ipan = 0;
   while (ipan < nmblks)
   {
      ATL_ammmNK(mminf, mb, nmu, nnblks, nbL, nnuL, nkblks, kb0, KB0,
                 A, lda, incAk, B, ldb, 0, incBn, C, ldc, incCn,
                 a, inca, b, incb, incbn, rC, iC, alpA, alpB, alpC, beta);
      A += incAm;
      C += (mb SHIFT);
      ipan++;
   }
}

/*
 * Matrix-multiply driver that allocates workspace and calls ATL_ammmMNK.
 * Presumably implements the usual GEMM update C = alpha*op(A)*op(B) + beta*C
 * (confirm against the ammmK kernel contract).
 *
 * alpha is applied to exactly one of A, B or C, chosen by the return value
 * of GetAmmmInfo (0: A, 1: B, anything else: C); the other two scalars stay
 * at one.
 *
 * The routine first tries to allocate workspace covering all K blocks.  If
 * that fails (or exceeds ATL_MaxMalloc), the number of K blocks per pass is
 * halved repeatedly and K is then processed in several passes, each pass
 * accumulating into C.
 *
 * RETURNS: 0 on success, 1 if no workspace could be allocated (C untouched).
 */
int Mjoin(PATL,ammmKMNK)
(
   enum ATLAS_TRANS TA,
   enum ATLAS_TRANS TB,
   ATL_CSZT M,
   ATL_CSZT N,
   ATL_CSZT K,
   const SCALAR alpha,
   const TYPE *A,
   ATL_CSZT lda,
   const TYPE *B,
   ATL_CSZT ldb,
   const SCALAR beta,
   TYPE *C,
   ATL_CSZT ldc
)
{
   #ifdef TCPLX
      const TYPE ONE[2]={ATL_rone,ATL_rzero};
      const TYPE *alpA=ONE, *alpB=ONE, *alpC=ONE;
   #else
      #define ONE ATL_rone
      TYPE alpA=ATL_rone, alpB=ATL_rone, alpC=ATL_rone;
   #endif
   void *vp=NULL;
   TYPE *a, *b, *c;
   ATL_SZT szA, szB, szC, sz, nnblks, nkblks, nkblksP, k;
   int mu, nu, mb, nb, kb, incak, incbk;
   amminfo_t mminf;


/*
 * mu temporarily holds GetAmmmInfo's return, which says where alpha goes
 */
   mu = Mjoin(PATL,GetAmmmInfo)(&mminf, TA, TB, M, N, K, alpha, beta);
   if (!mu)
      alpA = alpha;
   else if (mu == 1)
      alpB = alpha;
   else
      alpC = alpha;

   mu = mminf.mu;
   nu = mminf.nu;
   mb = mminf.mb;
   nb = mminf.nb;
   kb = mminf.kb;
   nnblks = (N+nb-1)/nb;               /* CEIL(N/nb) */
   nkblksP = nkblks = (K+kb-1)/kb;     /* CEIL(K/kb) */
   nkblksP <<= 1;                      /* undone by first halving in loop */
   incak = mb*kb;
   incbk = kb*nb;
   szC = mb*nb;
/*
 * Try full-K workspace first; on failure keep halving the number of K blocks
 * handled per pass (nkblksP) until malloc succeeds or nkblksP drops below 3
 */
   do
   {
      nkblksP >>= 1;
      szA = nkblksP*incak;
      szB = nkblksP*nnblks*incbk;
      sz = ATL_MulBySize(szA+szB+szC+(mu+mu)*nu) + 2*ATL_Cachelen;
      if (sz <= ATL_MaxMalloc)
         vp = malloc(sz);
   }
   while (!vp && nkblksP >= 3);
   if (!vp)
      return(1);

/*
 * Workspace layout: [ B copy | A copy | C block ]
 */
   b = ATL_AlignPtr(vp);
   a = b + (szB SHIFT);
   c = a + (szA SHIFT);

   if (nkblksP == nkblks)   /* all of K fits in workspace: single pass */
      ATL_ammmMNK(&mminf, TA, TB, M, N, K, A, lda, B, ldb, C, ldc, nnblks,
                  nkblks, a, b, c, alpA, alpB, alpC, beta);
   else                     /* K must be processed in chunks of nkblksP blks */
   {
      ATL_CSZT KK = nkblksP*kb;        /* K extent of a full chunk */
      ATL_SZT incAkp, incBkp;          /* A/B stride between chunks */
      incBkp = incAkp = KK SHIFT;
      if (IS_COLMAJ(TA))
         incAkp *= lda;
      if (!IS_COLMAJ(TB))
         incBkp *= ldb;
      for (k=0; k < nkblks; k += nkblksP, A += incAkp, B += incBkp)
      {
         ATL_SZT nk=nkblks-k, kk;
         if (nk > nkblksP)
         {
            kk = KK;
            nk = nkblksP;
         }
         else               /* last chunk gets whatever remains of K */
            kk = K - k*kb;
/*
 *       Only the first chunk may scale C by the caller's beta; every later
 *       chunk must accumulate into the partial result (beta = 1), otherwise
 *       C would be rescaled by beta once per chunk.
 */
         ATL_ammmMNK(&mminf, TA, TB, M, N, kk, A, lda, B, ldb, C, ldc, nnblks,
                     nk, a, b, c, alpA, alpB, alpC, k ? ONE : beta);
      }
   }
   free(vp);
   return(0);
}
