/*
 * Automatically Tuned Linear Algebra Software v3.11.38
 * Copyright (C) 2015 R. Clint Whaley
 */
#include "atlas_misc.h"
#include "atlas_mmtesttime.h"

static int Mylcm(const int M, const int N)
/*
 * Returns least common multiple (LCM) of two positive integers M & N by
 * computing the greatest common divisor (GCD) with the binary (shift and
 * subtract) algorithm and using the property that M*N = GCD*LCM.
 * M is divided by the GCD before the multiply by N, so the computation only
 * overflows when the LCM itself does not fit in an int.
 * If M == N, M is returned.  The LCM is undefined for non-positive inputs:
 * in that case the smaller argument is returned (what the M*N/max formula
 * always produced), except that 0 is returned when the larger argument is 0,
 * which would otherwise be a division by zero.
 */
{
   register int tmp, max, min, gcd=0;  /* gcd counts shared factors of 2 */

   if (M == N)
      return(M);
   if (M > N) { max = M; min = N; }
   else { max = N; min = M; }
   if (min <= 0)  /* undefined for non-positive numbers */
      return(max ? min : 0);
   do  /* while (min) */
   {
      if ( !(min & 1) ) /* min is even */
      {
         if ( !(max & 1) ) /* max is also even */
         {
/*
 *          Both even: strip shared factors of two, remembering how many
 */
            do
            {
               min >>= 1;
               max >>= 1;
               gcd++;
               if (min & 1) goto MinIsOdd;
            }
            while ( !(max & 1) );
         }
/*
 *       Remaining factors of two in min are not shared with max
 */
         do min >>=1 ; while ( !(min & 1) );
      }
/*
 *    Once min is odd, halve max until it too is odd.  Then, use
 *    property that gcd(max, min) = gcd(min, (max-min)/2)
 *    for odd max & min
 */
MinIsOdd:
      if (min == 1)  /* odd part of GCD is 1, so GCD = 2^gcd */
         return( (M / (1<<gcd)) * N );
      do  /* while (max >= min) */
      {
         max -= (max & 1) ? min : 0;
         max >>= 1;
      }
      while (max >= min);
      tmp = max;
      max = min;
      min = tmp;
   }
   while(tmp);
/*
 * min reached 0, so max holds odd part of GCD; M is an exact multiple of GCD
 */
   return( (M / (max<<gcd)) * N );
}

int NumberBetaFails(FILE *fperr, char pre, int nb, ATL_mmnode_t *p)
/*
 * Runs MMKernelFailsTest on kernel p three times, with its fifth argument
 * (presumably the beta case -- confirm against MMKernelFailsTest) set to
 * -1, 0 and 1.  The M and N block sizes are nb rounded up to multiples of
 * p->mu and p->nu; the K block size is nb rounded up to a multiple of p->ku
 * and then clamped to [p->kbmin, p->kbmax] (kbmax of 0 means no upper bound).
 * For each failing case, if fperr is non-NULL, a description and the string
 * from MMGetTestString are written to it.
 * RETURNS: number of failing cases (0..3)
 */
{
   int M, N, K, icase, nbad=0;

   M = ((nb + p->mu - 1) / p->mu) * p->mu;
   K = ((nb + p->ku - 1) / p->ku) * p->ku;
   N = ((nb + p->nu - 1) / p->nu) * p->nu;
   if (K < p->kbmin)
      K = p->kbmin;
   if (p->kbmax && K > p->kbmax)
      K = p->kbmax;

   for (icase = -1; icase <= 1; icase++)
   {
      char *tststr;

      if (!MMKernelFailsTest(pre, M, N, K, icase, p))
         continue;
      nbad++;
      if (!fperr)
         continue;
      fprintf(fperr, "FAIL: B=(%d,%d,%d), rout='%s', genstr='%s'\n",
              M, N, K, p->rout?p->rout:"NULL",
              p->genstr?p->genstr:"NULL");
      tststr = MMGetTestString(pre, M, N, K, icase, p);
      fprintf(fperr,"   '%s'\n", tststr);
      free(tststr);
   }
   return(nbad);
}

/*
 * Times an M-vectorized generated kernel built twice via MMGetNodeGEN, once
 * with its second argument 0 (reported as BCAST) and once with 1 (reported
 * as SPLAT), using nu=VL and an mu derived from nreg.  If bit 0 of flg is
 * set, both kernels are also tested with NumberBetaFails.
 * RETURNS: 0 if bcast slower than ld/splat combo, 1 otherwise (always 1
 *          for VL < 2, where nothing is timed)
 */
int UseBcast(int flg, int verb, char pre, int kb, int nreg, int VL)
{
   const int nu=VL, DOTEST=(flg&1);
   int mu, mb, nb, bcastWins;
   ATL_mmnode_t *bcNode, *splatNode;
   double bcMf, splatMf;

   if (VL < 2)
      return(1);
/*
 * Choose number of M-dim vector registers from the register budget
 */
   mu = (nreg-nu-1) / (nu+1);
   if (mu < 2)
      mu = (nreg - 2) / (nu+1);
   if (!mu)
      mu = 1;
/*
 * NOTE(review): mb is rounded to a multiple of the register count mu, while
 * the kernels below are generated with mu*VL -- confirm this is intended.
 */
   mb = ((kb+mu-1)/mu)*mu;
   nb = ((kb+nu-1)/nu)*nu;
   bcNode    = MMGetNodeGEN(pre, 0, nb, mu*VL, nu, 1, VL, 0, NULL);
   splatNode = MMGetNodeGEN(pre, 1, nb, mu*VL, nu, 1, VL, 0, NULL);
   printf("TIMING BCAST VS SPLAT MVEC WITH: B=(%d,%d,%d) U=(%d,%d,1)\n",
          mb, nb, kb, mu, nu);
   bcMf = TimeMMKernel(verb, 1, bcNode, pre, mb, nb, kb, 1, 0, -1);
   printf("   BCAST = %.0f MFLOP\n", bcMf);
   splatMf = TimeMMKernel(verb, 1, splatNode, pre, mb, nb, kb, 1, 0, -1);
   printf("   SPLAT = %.0f MFLOP\n", splatMf);
   if (DOTEST)
   {
      printf("   TESTING . . .");
      assert(!NumberBetaFails(stderr, pre, nb, bcNode));
      printf("  . . . ");
      assert(!NumberBetaFails(stderr, pre, nb, splatNode));
      printf("   PASS!\n");
   }
   KillMMNode(bcNode);
   KillMMNode(splatNode);

   bcastWins = !(splatMf > bcMf);
   if (bcastWins)
      printf("VBCAST PROVIDES %.4f SPEEDUP\n", bcMf/splatMf);
   else
      printf("VLD/VSPLAT PROVIDES %.4f SPEEDUP\n", splatMf/bcMf);
   return(bcastWins);
}

ATL_mmnode_t *FullSrchMUNU(int flg, int verb, char pre, int nreg, int nb,
                           int VL, int KVEC)
/*
 * Times every MUxNU register blocking (1 <= MU,NU <= nreg) that passes the
 * register-count check below, and returns a newly generated node describing
 * the fastest one.  KVEC nonzero selects K-dimension vectorization (ku=VL),
 * otherwise the M dimension is vectorized (mu = MU*VL).  If bit 0 of flg is
 * set, each timed kernel is also tested, asserting on failure.
 * The result is stored in a summary file named from the vectorization type,
 * VL and nreg under "res"; if that file can already be read, its contents
 * are passed through TimeNegMMKernels, written back and returned instead of
 * searching.
 * RETURNS: node (or list read from file) that the caller owns
 */
{
   char fn[32];
   ATL_mmnode_t *mmp;
   double mf, mfB=0.0;
   const int CHK=(flg&1), ku = (KVEC) ? VL : 1;
   int n, i, j, mbB, nbB, kbB, muB=1, nuB=1;
   char ch;

   assert(VL < 1000 && nreg < 1000);  /* don't overflow fn len */
/*
 * File name encodes vectorization: U (VL < 2), K (K-vec) or M (M-vec)
 */
   if (VL < 2)
      ch = 'U';
   else
      ch = (KVEC) ? 'K':'M';
   sprintf(fn, "gAMMUR_%c%d_%d.sum", ch, VL, nreg);
   mmp = ReadMMFileWithPath(pre, "res", fn);
   if (mmp)  /* previous results available: reuse rather than search */
   {
      MMFillInGenStrings(pre, mmp);
      TimeNegMMKernels(0, verb, 0, mmp, pre, 1, 0, -1);
      WriteMMFileWithPath(pre, "res", fn, mmp);
      return(mmp);
   }
   assert(nb%VL == 0);
/*
 * Scratch node whose mu/nu/genstr are overwritten for every timed case
 */
   mmp = MMGetNodeGEN(pre, 0, nb, 1, 1, ku, 1, KVEC,
                      DupString("ATL_Xamm_munu.c"));
   mmp->rout[4] = pre;  /* replace the X in the routine name with the prefix */
   mmp->mbB = mmp->nbB = mmp->kbB = nb;
   mmp->vlen = VL;
   if (KVEC)
      mmp->flag |= (1<<MMF_KVEC);
   else if (!UseBcast(flg, verb, pre, nb, nreg, VL))
      mmp->flag |= (1<<MMF_NOBCAST);
/*
 * Try all MU/NU unrollings
 */
   printf("Full search on MUxNU for nb=%d, NREG=%d, VLEN=%d, KVEC=%d\n",
          nb, nreg, VL, KVEC);
   for (i=1; i <= nreg; i++)
   {
      for (j=1; j <= nreg; j++)
      {
         int ONETIME=0;
         int mbu, nbu, mu, nu;
/*
 *       Skip cases needing more than nreg registers (i*j + min(i,j) + 1)
 */
         if (i*j+Mmin(i,j)+1 > nreg)
            continue;
         if (KVEC)  /* vec on K needs mu*nu mult of VLEN */
         {
/*
 *          With VL >= nreg only the single case mu=VL, nu=1 is timed; the
 *          loop indices are overwritten and the search is left afterwards
 */
            if (VL >= nreg)
            {
               ONETIME=1;
               i = VL;
               j = 1;
            }
            else if ((i*j)%VL)
               continue;
            mu = i;
         }
         else /* vect on M dim need mu multiple of VLEN */
            mu = i*VL;
         nu = j;
         mmp->mu = mu;
         mmp->nu = nu;
         if (mmp->genstr)
           free(mmp->genstr);
/*
 *       Block sizes: nb rounded down to a multiple of the unrolling, or the
 *       unrolling itself when nb is smaller than it
 */
         mbu = (nb >= mu) ? (nb/mu)*mu : mu;
         nbu = (nb >= nu) ? (nb/nu)*nu : nu;
         mmp->genstr = MMGetGenString(pre, mmp);
         mf = TimeMMKernel(verb, 0, mmp, pre, mbu, nbu, nb, 1, 0, -1);
         printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", i, j, mf); /* i, not mu */
         if (CHK)
         {
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
         }
         if (mf > mfB)  /* new best case */
         {
            mbB = mbu;
            nbB = nbu;
            kbB = nb;
            muB = mu;
            nuB = nu;
            mfB = mf;
         }
         if (ONETIME)
            goto DONE;
      }
   }
DONE:
   assert(mfB > 0.0);  /* at least one case must have been timed */
/*
 * Replace scratch node with a fresh node for the winner, keeping the
 * NOBCAST setting chosen above
 */
   i = FLAG_IS_SET(mmp->flag, MMF_NOBCAST);
   KillMMNode(mmp);
   mmp = MMGetNodeGEN(pre, i, nb, muB, nuB, ku, VL, KVEC, NULL);
   mmp->mbB = mbB;
   mmp->nbB = nbB;
   mmp->kbB = kbB;
   mmp->mflop[0] = mfB;
   printf("BEST FULL-SEARCH CASE IS B=(%d,%d,%d), U=(%d,%d) MFLOP=%.2f\n\n",
          mbB, nbB, kbB, muB, nuB, mfB);
   WriteMMFileWithPath(pre, "res", fn, mmp);
   return(mmp);
}

ATL_mmnode_t *SrchNU(int flg, int verb, char pre, int nreg, int nb, int VL,
                     int I)
/*
 * M-vectorized search over NU with MU fixed at I*VLEN.  It allows us to find
 * a case that can handle smaller blocks when VLEN is long.
 * Every NU in [1,nreg] passing the register-count check is timed (and also
 * tested if bit 0 of flg is set).  Results are kept in a summary file under
 * "res"; if that file can already be read, its contents are passed through
 * TimeNegMMKernels, written back and returned without searching.
 * RETURNS: node for the fastest NU, or NULL if no NU was legal
 */
{
   char fn[32];
   ATL_mmnode_t *mmp;
   double mf, bestMf=0.0;
   const int TESTKERNS=(flg&1), mu = I*VL;
   int nu, mbu, bestNU=1, bestNB, noBcast;

   sprintf(fn, "gAMMUR_MU%d_M%d_%d.sum", I, VL, nreg);
   mmp = ReadMMFileWithPath(pre, "res", fn);
   if (mmp)  /* reuse earlier results */
   {
      MMFillInGenStrings(pre, mmp);
      TimeNegMMKernels(0, verb, 0, mmp, pre, 1, 0, -1);
      WriteMMFileWithPath(pre, "res", fn, mmp);
      return(mmp);
   }
/*
 * Scratch node: mu stays fixed, nu & genstr are rewritten for each case
 */
   mmp = MMGetNodeGEN(pre, 0, nb, 1, 1, 1, 1, 0, DupString("ATL_Xamm_munu.c"));
   mmp->rout[4] = pre;
   mmp->mbB = mmp->nbB = mmp->kbB = nb;
   mmp->vlen = VL;
   mmp->mu = mu;
   mbu = (nb >= mu) ? (nb/mu)*mu : mu;
   if (!UseBcast(flg, verb, pre, nb, nreg, VL))
      mmp->flag |= (1<<MMF_NOBCAST);
/*
 * Try every NU unrolling that fits in the registers
 */
   printf("Searching M-VEC MU=%d xNU case for mb=%d, kb=%d, NREG=%d, VLEN=%d\n",
          I, mbu, nb, nreg, VL);
   for (nu=1; nu <= nreg; nu++)
   {
      int nbu;

      if (I*nu+Mmin(I,nu)+1 <= nreg)
      {
         mmp->nu = nu;
         if (mmp->genstr)
            free(mmp->genstr);
         nbu = (nb >= nu) ? (nb/nu)*nu : nu;
         mmp->genstr = MMGetGenString(pre, mmp);
         mf = TimeMMKernel(verb, 0, mmp, pre, mbu, nbu, nb, 1, 0, -1);
         printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", I, nu, mf);
         if (TESTKERNS)
         {
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
         }
         if (mf > bestMf)
         {
            bestNB = nbu;
            bestNU = nu;
            bestMf = mf;
         }
      }
   }
   noBcast = FLAG_IS_SET(mmp->flag, MMF_NOBCAST);
   KillMMNode(mmp);
   if (bestMf == 0.0)  /* no legal kern found! */
   {
      printf("NO LEGAL KERNS FOR MU=%d!\n", I);
      return(NULL);
   }
/*
 * Build the node actually returned; M & K blocks are the same for all cases
 */
   mmp = MMGetNodeGEN(pre, noBcast, nb, mu, bestNU, 1, VL, 0, NULL);
   mmp->mbB = mbu;
   mmp->nbB = bestNB;
   mmp->kbB = nb;
   mmp->mflop[0] = bestMf;
   assert(bestMf > 0.0);
   printf("BEST MU=%d CASE IS B=(%d,%d,%d) U=(%d,%d), MFLOP=%.2f\n\n",
          I, mbu, bestNB, nb, mu, bestNU, bestMf);
   WriteMMFileWithPath(pre, "res", fn, mmp);
   return(mmp);
}

ATL_mmnode_t *SrchMUNUp2(int flg, int verb, char pre, int nreg, int nb,
                         int VL, int KVEC)
/*
 * Times MUxNU register blockings where both MU and NU are powers of two
 * (1, 2, 4, ... <= nreg) and pass the register-count check below, and
 * returns a newly generated node for the fastest one.  KVEC nonzero selects
 * K-dimension vectorization (ku=VL), otherwise mu = MU*VL.  If bit 0 of flg
 * is set, each timed kernel is also tested, asserting on failure.
 * Results are kept in a summary file under "res"; if that file can already
 * be read, its contents are passed through TimeNegMMKernels, written back
 * and returned without searching.
 * NOTE(review): unlike FullSrchMUNU there is no assert bounding VL & nreg
 * before the sprintf into the 32-byte fn -- confirm callers keep them small.
 */
{
   char fn[32];
   ATL_mmnode_t *mmp;
   double mf, mfB=0.0;
   const int CHK=(flg&1), ku = (KVEC) ? VL : 1;
   int n, i, j, mbB, nbB, kbB, muB=1, nuB=1;
   char ch;

/*
 * File name encodes vectorization: U (VL < 2), K (K-vec) or M (M-vec)
 */
   if (VL < 2)
      ch = 'U';
   else
      ch = (KVEC) ? 'K':'M';
   sprintf(fn, "gAMMURP2_%c%d_%d.sum", ch, VL, nreg);
   mmp = ReadMMFileWithPath(pre, "res", fn);
   if (mmp)  /* previous results available: reuse rather than search */
   {
      MMFillInGenStrings(pre, mmp);
      TimeNegMMKernels(0, verb, 0, mmp, pre, 1, 0, -1);
      WriteMMFileWithPath(pre, "res", fn, mmp);
      return(mmp);
   }
/*
 * Scratch node whose mu/nu/genstr are overwritten for every timed case
 */
   mmp = MMGetNodeGEN(pre, 0, nb, 1, 1, ku, 1, KVEC,
                      DupString("ATL_Xamm_munu.c"));
   mmp->rout[4] = pre;  /* replace the X in the routine name with the prefix */
   mmp->mbB = mmp->nbB = mmp->kbB = nb;
   mmp->vlen = VL;
   if (KVEC)
      mmp->flag |= (1<<MMF_KVEC);
   else if (!UseBcast(flg, verb, pre, nb, nreg, VL))
      mmp->flag |= (1<<MMF_NOBCAST);
/*
 * Try all powers of 2 MU/NU unrollings
 */
   printf("Searching PWR-2 MUxNU cases for nb=%d, NREG=%d, VLEN=%d, KVEC=%d\n",
          nb, nreg, VL, KVEC);
/*
 * Long VLEN with small NREG can be impossible to do with mu*nu%VLEN==0,
 * so force at least one case even if it overruns registers
 */
   for (i=1; i <= nreg; i += i)  /* i doubles each trip */
   {
      for (j=1; j <= nreg; j += j)  /* j doubles each trip */
      {
         int ONETIME=0;
         int mbu, nbu, mu, nu;
/*
 *       Skip cases needing more than nreg registers (i*j + min(i,j) + 1)
 */
         if (i*j+Mmin(i,j)+1 > nreg)
            continue;
         if (KVEC)  /* vec on K needs mu*nu mult of VLEN */
         {
/*
 *          With VL >= nreg only the single case mu=VL, nu=1 is timed; the
 *          loop indices are overwritten and the search is left afterwards
 */
            if (VL >= nreg)
            {
               ONETIME = 1;
               i=VL;
               j=1;
            }
            else if ((i*j)%VL)
               continue;
            mu = i;
         }
         else /* vect on M dim need mu multiple of VLEN */
            mu = i*VL;
         nu = j;
         mmp->mu = mu;
         mmp->nu = nu;
         if (mmp->genstr)
           free(mmp->genstr);
/*
 *       Block sizes: nb rounded down to a multiple of the unrolling, or the
 *       unrolling itself when nb is smaller than it
 */
         mbu = (nb >= mu) ? (nb/mu)*mu : mu;
         nbu = (nb >= nu) ? (nb/nu)*nu : nu;
         mmp->genstr = MMGetGenString(pre, mmp);
         mf = TimeMMKernel(verb, 0, mmp, pre, mbu, nbu, nb, 1, 0, -1);
         printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", i, j, mf); /* i, not mu */
         if (CHK)
         {
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
         }
         if (mf > mfB)  /* new best case */
         {
            mbB = mbu;
            nbB = nbu;
            kbB = nb;
            muB = mu;
            nuB = nu;
            mfB = mf;
         }
         if (ONETIME)
            goto DONE;
      }
   }
DONE:
   assert(mfB != 0.0);  /* at least one case must have been timed */
/*
 * Replace scratch node with a fresh node for the winner, keeping the
 * NOBCAST setting chosen above
 */
   i = FLAG_IS_SET(mmp->flag, MMF_NOBCAST);
   KillMMNode(mmp);
   mmp = MMGetNodeGEN(pre, i, nb, muB, nuB, ku, VL, KVEC, NULL);
   mmp->mbB = mbB;
   mmp->nbB = nbB;
   mmp->kbB = kbB;
   mmp->mflop[0] = mfB;
   printf("BEST POW2 CASE IS B=(%d,%d,%d) U=(%d,%d), MFLOP=%.2f\n\n",
          mbB, nbB, kbB, muB, nuB, mfB);
   WriteMMFileWithPath(pre, "res", fn, mmp);
   return(mmp);
}

ATL_mmnode_t *SrchMUNU(int flg, int verb, char pre, int nreg, int nb,
                       int VL, int KVEC)
/*
 * Default MUxNU register-blocking search.  If bit 1 of flg is set, the work
 * is handed to FullSrchMUNU.  Otherwise the power-of-2 search (SrchMUNUp2)
 * is run, then near-square blockings using 4..nreg registers are timed, and,
 * when DO1D is set, 1-D blockings (1xn and nx1) as well.  The faster of the
 * power-of-2 result and the best case found here is written to the summary
 * file under "res" and returned.  If that file can already be read, its
 * contents are passed through TimeNegMMKernels, written back and returned
 * without searching.
 * KVEC nonzero selects K-dimension vectorization (ku=VL), otherwise the M
 * dimension is vectorized.  If bit 0 of flg is set, every timed kernel is
 * also tested, asserting on failure.
 * RETURNS: node (or list read from file) that the caller owns
 */
{
   ATL_mmnode_t *mmp, *mmp2;
   char fn[32];
   double mf, mfB=0.0;
   const int CHK=(flg&1), ku = (KVEC) ? VL : 1;
   #if (defined(ATL_GAS_x8664) || defined(ATL_GAS_x8632)) && !defined(ATL_AVX)
      int DO1D=1;
   #else
      int DO1D=(nreg < 9 || nreg < VL);
   #endif
   int n, i, j, mbB, nbB, kbB, muB=1, nuB=1;
   char ch;

   if (flg&2)
      return(FullSrchMUNU(flg, verb, pre, nreg, nb, VL, KVEC));
   mmp2 = SrchMUNUp2(flg, verb, pre, nreg, nb, VL, KVEC);
   assert(VL < 1000 && nreg < 1000);  /* don't overflow fn len */
   if (VL < 2)
      ch = 'U';
   else
      ch = (KVEC) ? 'K':'M';
   sprintf(fn, "gAMMUR_%c%d_%d.sum", ch, VL, nreg);
   mmp = ReadMMFileWithPath(pre, "res", fn);
   if (mmp)  /* previous results available: pow2 result is not needed */
   {
      KillAllMMNodes(mmp2);
      MMFillInGenStrings(pre, mmp);
      TimeNegMMKernels(0, verb, 0, mmp, pre, 1, 0, -1);
      WriteMMFileWithPath(pre, "res", fn, mmp);
      return(mmp);
   }
/*
 * Scratch node whose mu/nu/genstr are overwritten for every timed case
 */
   mmp = MMGetNodeGEN(pre, 0, nb, 1, 1, ku, 1, KVEC,
                      DupString("ATL_Xamm_munu.c"));
   mmp->rout[4] = pre;
   mmp->mbB = mmp->nbB = mmp->kbB = nb;
   mmp->vlen = VL;
   if (KVEC)
      mmp->flag |= (1<<MMF_KVEC);
   else if (!UseBcast(flg, verb, pre, nb, nreg, VL))
      mmp->flag |= (1<<MMF_NOBCAST);
/*
 * Try all near-square register blocking cases
 */
   printf("Finding best MUxNU case for nb=%d, NREG=%d, VLEN=%d, KVEC=%d\n",
          nb, nreg, VL, KVEC);
   for (n=4; n <= nreg; n++)
   {
      int ONETIME=0;
      int mbu, nbu, mu, nu;
      for (j=1; j*j < n; j++);  /* j = smallest int with j*j >= n */
      i = n / j;
      if (nb%i || nb%j)
         continue;
      if (KVEC)
      {
         if (VL >= nreg)  /* only time the single case mu=VL, nu=1 */
         {
            ONETIME = 1;
            i = VL;
            j = 1;
         }
         else if ((i*j)%VL)
            continue;
         mu = i;
         nu = j;
      }
      else /* vectorized along M dimension */
      {
         mu = i * VL;
         nu = j;
      }
      mmp->mu = mu;
      mmp->nu = nu;
      if (mmp->genstr)
        free(mmp->genstr);
      mbu = (nb >= mu) ? (nb/mu)*mu : mu;
      nbu = (nb >= nu) ? (nb/nu)*nu : nu;
      mmp->genstr = MMGetGenString(pre, mmp);
      mf = TimeMMKernel(verb, 0, mmp, pre, mbu, nbu, nb, 1, 0, -1);
      printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", i, j, mf);
      if (CHK)
      {
         assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
         assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
         assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
      }
      if (mf > mfB)
      {
         mbB = mbu;
         nbB = nbu;
         kbB = nb;
         muB = mu;
         nuB = nu;
         mfB = mf;
      }
      if (ONETIME)
         break;
   }
/*
 * For non-AVX x86, try 1-D cases since they are 2-operand assemblies; always
 * try 1-D for low registers
 */
   if (DO1D)
   {
      printf("BEST NEAR-SQUARE CASE IS MU=%d, NU=%d, MFLOP=%.2f\n\n",
             muB, nuB, mfB);
      printf("Finding best 1-D outer loop unrolling for nb=%d\n", nb);
      for (n=2; n <= nreg; n++)
      {
         int mbu, nbu, mu, nu;
/*
 *       1 x n case
 */
         i = 1; j = n;
         if (nb % n)
            continue;
         mu = mmp->mu = i*VL;
         nu = mmp->nu = j;
         if (mmp->genstr)
           free(mmp->genstr);
         mmp->genstr = MMGetGenString(pre, mmp);
         mbu = (nb >= mu) ? (nb/mu)*mu : mu;
         nbu = (nb >= nu) ? (nb/nu)*nu : nu;
         mf = TimeMMKernel(verb, 0, mmp, pre, mbu, nbu, nb, 1, 0, -1);
         printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", i, j, mf);
         if (CHK)
         {
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
         }
/*
 *       Record the unrolling & blocking that were actually timed (mu, not
 *       the register count i), so the returned node matches the timing
 */
         if (mf > mfB)
         {
            mbB = mbu;
            nbB = nbu;
            kbB = nb;
            muB = mu;
            nuB = nu;
            mfB = mf;
         }
/*
 *       n x 1 case
 *       NOTE(review): this timing passes 1 as TimeMMKernel's 2nd argument
 *       where the other search timings pass 0 -- confirm this is intended.
 */
         i = n; j = 1;
         mu = mmp->mu = i * VL;
         nu = mmp->nu = j;
         mbu = (nb >= mu) ? (nb/mu)*mu : mu;
         nbu = (nb >= nu) ? (nb/nu)*nu : nu;
         if (mmp->genstr)
           free(mmp->genstr);
         mmp->genstr = MMGetGenString(pre, mmp);
         mf = TimeMMKernel(verb, 1, mmp, pre, mbu, nbu, nb, 1, 0, -1);
         printf("   MU=%2d, NU=%2d, MFLOP=%.2f\n", i, j, mf);
         if (CHK)
         {
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 1, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, 0, mmp));
            assert(!MMKernelFailsTest(pre, mbu, nbu, nb, -1, mmp));
         }
         if (mf > mfB)
         {
            mbB = mbu;
            nbB = nbu;
            kbB = nb;
            muB = mu;
            nuB = nu;
            mfB = mf;
         }
      }
   }

   assert(mfB > 0.0);  /* at least one case must have been timed */
/*
 * Replace scratch node with a fresh node for the winner, keeping the
 * NOBCAST setting chosen above
 */
   i = FLAG_IS_SET(mmp->flag, MMF_NOBCAST);
   KillMMNode(mmp);
   mmp = MMGetNodeGEN(pre, i, nb, muB, nuB, ku, VL, KVEC, NULL);
   mmp->mbB = mbB;
   mmp->nbB = nbB;
   mmp->kbB = kbB;
   mmp->mflop[0] = mfB;
/*
 * Keep whichever of this search & the power-of-2 search is faster
 */
   if (mmp->mflop[0] < mmp2->mflop[0])
   {
      printf("Taking pow2 srch (%d,%d:%.0f) over square (%d,%d:%.0f)\n",
             mmp2->mu, mmp2->nu, mmp2->mflop[0],
             mmp->mu, mmp->nu, mmp->mflop[0]);
      KillMMNode(mmp);
      mmp = mmp2;
   }
   else
      KillAllMMNodes(mmp2);
   WriteMMFileWithPath(pre, "res", fn, mmp);
   printf("BEST CASE IS B=(%d,%d,%d), U=(%d,%d), MFLOP=%.2f\n\n",
          mmp->mbB, mmp->nbB, mmp->kbB, mmp->mu, mmp->nu, mmp->mflop[0]);
   return(mmp);
}

#if 0
/* this idea doesn't really work */
int FindKvecXover(int flg, int verb, char pre, int nreg, int VL, int nb)
/*
 * On some machines, vectorizing the M dim will win for small problems,
 * due to not needing to do the summation at the end of the K-loop.  However,
 * once this cost is dominated by the K-loop, K dim vectorization can
 * start to win, possibly by reducing the C write traffic, as well as
 * avoiding vec bcast inside the K-loop.
 *
 * For finding this crossover, we use best pwr2 M/K, since we can be sure
 * they can always use a pwr2 block factor for direct comparison.  If pwr2
 * cases are much different than normal, this may cause problems!
 * IDEA: Read in normal MUNU results, and don't use this test if gap is wide.
 *
 * RETURNS: 1 if K-vec wins at the first block size tried, 0 if it never
 *          wins, else the first block size at which K-vec is >2% faster
 * NOTE(review): this routine is compiled out (#if 0).  b doubles from b0,
 * which need not be a power of two, so the loop can end with b > 512 and the
 * "b == 512" test below would then miss the never-wins case.  The two nodes
 * from SrchMUNUp2 are also never freed.
 */
{
   ATL_mmnode_t *mmM, *mmK;
   double mfM, mfK;
   int b, b0;

   mmM = SrchMUNUp2(flg, verb, pre, nreg, nb, VL, 0);
   mmK = SrchMUNUp2(flg, verb, pre, nreg, nb, VL, 1);
   printf("FINDING NB CROSSOVER FOR M- AND K-VECTORIZATION:");
/*
 * Start at the largest unrolling of either kernel, but not below 16
 */
   b = Mmax(mmM->mu, mmK->mu);
   b = Mmax(b, mmM->nu);
   b = Mmax(b, mmK->nu);
   b = Mmax(b,16);
   b0 = b;
   while (b < 512)
   {
      mfM = TimeMMKernel(verb, 0, mmM, pre, b, b, b, 1, 0, -1);
      mfK = TimeMMKernel(verb, 0, mmK, pre, b, b, b, 1, 0, -1);
      printf("   B=%d, mflopM=%.0f, mflopK=%.0f\n", b, mfM, mfK);
      if (mfK > mfM*1.02)  /* K-vec must win by more than 2% */
         break;
      b += b;
   }
   if (b == b0)
   {
      printf("K-VECTORIZATION BETTER FROM FIRST BLOCK!\n");
      b = 1;            /* Kvec always better */
   }
   else if (b == 512)
   {
      printf("M-VECTORIZATION ALWAYS BETTER\n");
      b = 0;            /* Kvec never better */
   }
   else
      printf("K-VEC BEGINS WINNING AROUND %d\n", b);
   return(b);
}
#endif

void FindDefMUNU(int flg, int verb, char pre, int nb, int *NREG, int *VLEN)
/*
 * Determines the number of vector registers and the vector length to use
 * for the later searches, running the MU/NU search along the way.
 * On entry *NREG and *VLEN are the requested values; values < 1 are replaced
 * by GetNumVecRegs / GetNativeVLEN, and if those also fail, by a guess.
 * When the register count was only guessed, the search is repeated with a
 * doubled count (up to four times) for as long as the result both uses more
 * registers and is at least 3% faster.
 * On exit *NREG and *VLEN hold the values settled on.  All kernel nodes
 * created here are freed before returning.
 */
{
   ATL_mmnode_t *mp, *mmM, *mmK;
   int nreg=(*NREG), VL=(*VLEN), chkNR=0, chkVL=0;

   if (nreg < 1)
      nreg = GetNumVecRegs(pre);
   if (nreg < 1)  /* still unknown: start from a guessed lower bound */
   {
      #ifdef ATL_GAS_x8632
         nreg = 8;
      #else
         nreg = 16;
      #endif
      chkNR = 1;
   }
   if (VL < 1)
      VL = GetNativeVLEN(pre);
   if (!VL)  /* unknown: guess 4 for single precision, 2 for double */
   {
      VL = (pre == 'c' || pre == 's') ? 4:2;
      chkVL = 1;
   }
/*
 * Always do full search for low number of registers, where this is only search
 */
   if (!chkNR && nreg > 0 && nreg < 20)
      flg |= 2;
   mmM = SrchMUNU(flg, verb, pre, nreg, nb, VL, 0);
   mmK = SrchMUNU(flg, verb, pre, nreg, nb, VL, 1);
   printf("MVEC: B=(%d,%d,%d) mu=%d, nu=%d, MFLOP=%.0f\n",
          mmM->mbB,  mmM->nbB,  mmM->kbB, mmM->mu, mmM->nu, mmM->mflop[0]);
   printf("KVEC: B=(%d,%d,%d) mu=%d, nu=%d, MFLOP=%.0f\n",
          mmK->mbB,  mmK->nbB,  mmK->kbB, mmK->mu, mmK->nu, mmK->mflop[0]);
/*
 * After this, fastest code in mmM, slowest mmK
 */
   if (mmK->mflop[0] > mmM->mflop[0])
   {
      mp = mmK;
      mmK = mmM;
      mmM = mp;
   }
   KillAllMMNodes(mmK);
/*
 * If we only guessed a lower bound on # regs, try some searches with
 * increasing regs
 */
   if (chkNR)
   {
      const int KVEC = FLAG_IS_SET(mmM->flag, MMF_KVEC);
      int i, nr = nreg+nreg;
      printf("NREG=%d, U=(%d,%d): MFLOP=%.0f\n",
             nreg, mmM->mu, mmM->nu, mmM->mflop[0]);
      for (i=0; i < 4; i++)  /* sanity check for stopping */
      {
         mp = SrchMUNU(flg, verb, pre, nr, nb, VL, KVEC);
         printf("NREG=%d, U=(%d,%d): MFLOP=%.0f\n",
                nr, mp->mu, mp->nu, mp->mflop[0]);
/*
 *       Stop if the extra registers were not used, or perf is not better;
 *       the rejected candidate must be freed since nothing else holds it
 */
         if (mp->mu*mp->nu + mp->mu + 1 <= (nr>>1) ||
             mp->mflop[0] < 1.03*mmM->mflop[0])
         {
            KillAllMMNodes(mp);
            break;
         }
         KillAllMMNodes(mmM);
         mmM = mp;
         nreg = nr;
         nr += nr;
      }
   }
/*
 * Now that we are confident in our NREG, see if we need to confirm VLEN
 */
   if (chkVL)
   {
      /* VLEN confirmation is not implemented */
   }
   *NREG = nreg;
   *VLEN = VL;
#if 0
/*
 * Now see if K-vec has a crossover with M-vec
 */
   FindKvecXover(flg, verb, pre, nreg, VL, nb);
#endif
   KillAllMMNodes(mmM);
}

void PrintUsage(char *name, int ierr, char *flag)
/*
 * Prints an optional error line followed by the usage summary to stderr,
 * then terminates the program; it never returns.
 * ierr > 0: flag is reported as bad argument number ierr (flag may be NULL,
 *           meaning the arguments ran out);
 * ierr < 0: flag is printed as an error message;
 * ierr = 0: only the usage summary is printed.
 * The exit status is ierr, or -1 when ierr is 0.
 */
{
   FILE *out = stderr;

   if (ierr < 0)
      fprintf(out, "ERROR: %s\n", flag);
   else if (ierr > 0)
      fprintf(out, "Bad argument #%d: '%s'\n", ierr,
              flag?flag:"OUT-OF_ARGUMENTS");

   fprintf(out, "USAGE: %s [flags:\n", name);
   fprintf(out, "   -p [s,d,c,z]: set type/precision prefix (d) \n");
   fprintf(out, "   -r <nreg> : set max # of registers to try\n");
   fprintf(out, "   -V <vlen> : force vector length\n");
   fprintf(out, "   -b <nb>   : set initial block factor to try\n");
   fprintf(out, "   -v <verb> : set verbosity (1)\n");
   fprintf(out, "   -T 1      : test all legal kerns, and exit\n");
   fprintf(out,
           "   -f <flg>  : bitvec for srch control, add vals you want set:\n");
   fprintf(out, "        1: test all generated kernels\n");
   fprintf(out, "        2: do full MUxNU search\n");
   fprintf(out, "        4: print # of regs to res/<pre>nreg\n");
   fprintf(out, "      DEFAULT: all bits unset\n");

   if (!ierr)
      ierr = -1;
   exit(ierr);
}

void GetFlags(int nargs, char **args, int *FLG, int *VERB, char *PRE, int *NREG,
              int *VLEN, int *NB, int *TEST)
/*
 * Parses the command line (see PrintUsage for the flags) into the output
 * arguments.  Defaults: pre='d', flg=0, verb=0, NB=120, VLEN=0, TEST=0; if
 * no nonzero register count is given, GetNumVecRegs(pre) supplies it.
 * Every flag takes exactly one value.  A malformed flag, a missing value or
 * an unknown precision prefix is reported through PrintUsage, which exits.
 */
{
   int i, flg=0, nreg=0;
   char pre = 'd';

   *VERB = 0;
   *NB = 120;
   *VLEN = 0;
   *TEST = 0;
   for (i=1; i < nargs; i++)
   {
      if (args[i][0] != '-')
         PrintUsage(args[0], i, args[i]);

      switch(args[i][1])
      {
      case 'p':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
/*
 *       tolower is only defined for values representable as unsigned char
 */
         pre = tolower((unsigned char)args[i][0]);
/*
 *       Reject bad prefix as a usage error, so the check survives NDEBUG
 */
         if (pre != 's' && pre != 'd' && pre != 'z' && pre != 'c')
            PrintUsage(args[0], i, args[i]);
         break;
      case 'T':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         *TEST = atoi(args[i]);
         break;
      case 'V':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         *VLEN = atoi(args[i]);
         break;
      case 'f':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         flg = atoi(args[i]);
         break;
      case 'r':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         nreg = atoi(args[i]);
         break;
      case 'v':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         *VERB = atoi(args[i]);
         break;
      case 'b':
         if (++i >= nargs)
            PrintUsage(args[0], i-1, NULL);
         *NB = atoi(args[i]);
         break;
      default:
         PrintUsage(args[0], i, args[i]);
      }
   }
   if (!nreg)  /* not given (or given as 0): ask the system */
      nreg = GetNumVecRegs(pre);
   *PRE = pre;
   *FLG = flg;
   *NREG = nreg;
}

/*
 * Using best discovered kernel, figure out the largest NB < 512 that
 * gets good performance.
 * Square blocks NB = inc, 2*inc, ... < 512 are timed, where inc is the LCM
 * of the kernel's mu, nu & ku, doubled until it is at least 12.  The search
 * stops once more than 4 timings have failed to beat the best so far.
 * The best block and its MFLOP are stored in mp->[m,n,k]bB & mp->mflop[0].
 * RETURNS: best NB + inc - 1
 */
int GetMaxNB(int flag, int verb, char pre, ATL_mmnode_t *mp)
{
   int inc, i, bB=0, badrow=0;
   double mf, mfB=0.0;

   printf("FINDING RANGE OF NB FOR GENKERN MU=%d, NU=%d, %cVEC=%d:\n",
          mp->mu, mp->nu, FLAG_IS_SET(mp->flag, MMF_KVEC)?'K':'M', mp->vlen);
   inc = Mylcm(mp->mu, mp->nu);
   inc = Mylcm(inc,mp->ku);
/*
 * A non-positive unrolling gives a non-positive LCM, which the doubling
 * loop below could never raise to 12; fall back to a stride of 1
 */
   if (inc < 1)
      inc = 1;
   while (inc < 12)
      inc += inc;

   for (i=inc; i < 512; i += inc)
   {
       mf = TimeMMKernel(verb, 0, mp, pre, i, i, i, 1, 0, -1);
       printf("   NB=%d, mf=%.0f\n", i, mf);
       if (mf > mfB)
       {
          bB=i;
          mfB = mf;
       }
/*
 *     NOTE(review): badrow is never reset after an improvement, so it counts
 *     all non-improving sizes, not consecutive ones -- confirm intent.
 */
       else badrow++;
       if (badrow > 4)
          break;
   }
   mp->mbB = mp->kbB = mp->nbB = bB;
   mp->mflop[0] = mfB;
   printf("BEST SQUARE NB=%d (%.0f)\n", bB, mfB);
   return(bB+inc-1);
}

static void WriteNregFile(char pre, int nreg)
/*
 * Writes nreg as one decimal line to the file res/<pre>nreg, asserting that
 * the file could be opened.
 */
{
   char fn[12];  /* "res/" + prefix char + "nreg" + NUL = 10 chars */
   FILE *fp;

   sprintf(fn, "res/%cnreg", pre);
   fp = fopen(fn, "w");
   assert(fp);
   fprintf(fp, "%d\n", nreg);
   fclose(fp);
}

void FindInfo(int flag, int verb, char pre, int NB, int *NREG, int *VLEN)
/*
 * Produces the summary files gmvAMMUR.sum (M-vectorized kernels) and
 * gkvAMMUR.sum (K-vectorized kernels) under "res", and returns the register
 * count & vector length used in *NREG & *VLEN.  Complex prefixes are mapped
 * to their real counterparts (z->d, c->s).
 * If both files can already be read, they are passed through
 * TimeNegMMKernels and written back, and *VLEN / *NREG are taken from the
 * first M-vec node (vlen & ivar).  Otherwise FindDefMUNU settles nreg & vlen
 * and the searches are run to build both files.
 * If bit 2 of flag is set, the register count is also written to
 * res/<pre>nreg.
 */
{
   const int WNR=(flag&4);
   int nreg=(*NREG), vlen=(*VLEN);
   ATL_mmnode_t *mp, *mpN;

   if (pre == 'z')
      pre = 'd';
   else if (pre == 'c')
      pre = 's';

   mp = ReadMMFileWithPath(pre, "res", "gmvAMMUR.sum");
   mpN = ReadMMFileWithPath(pre, "res", "gkvAMMUR.sum");
   if (mp && mpN)  /* both outputs already exist */
   {
      *VLEN = mp->vlen;
      *NREG = mp->ivar;
      MMFillInGenStrings(pre, mp);
      MMFillInGenStrings(pre, mpN);
      TimeNegMMKernels(0, verb, 0, mp, pre, 1, 0, -1);
      TimeNegMMKernels(0, verb, 0, mpN, pre, 1, 0, -1);
      WriteMMFileWithPath(pre, "res", "gmvAMMUR.sum", mp);
      WriteMMFileWithPath(pre, "res", "gkvAMMUR.sum", mpN);
      KillAllMMNodes(mp);
      KillAllMMNodes(mpN);
      if (WNR)
         WriteNregFile(pre, *NREG);
      return;
   }
/*
 * At most one of the two was read; discard it and redo both
 */
   if (mp)
      KillAllMMNodes(mp);
   if (mpN)
      KillAllMMNodes(mpN);
   FindDefMUNU(flag, verb, pre, NB, &nreg, &vlen);
   printf("\nNREG=%d, VLEN=%d\n", nreg, vlen);
/*
 * With nreg & vlen set, create output with best K- & M- vectorized code
 * in standard names with nreg in mp->ivar
 */
   mp = SrchMUNU(flag, verb, pre, nreg, NB, vlen, 0);
   mpN = SrchMUNUp2(flag, verb, pre, nreg, NB, vlen, 0);
/*
 * Chain the fixed-MU searches (MU=1,2,3) behind the pow2 result, stopping
 * at the first one that finds no legal kernel
 */
   mpN->next = SrchNU(flag, verb, pre, nreg, NB, vlen, 1);
   if (mpN->next)
   {
      mpN->next->next = SrchNU(flag, verb, pre, nreg, NB, vlen, 2);
      if (mpN->next->next)
          mpN->next->next->next = SrchNU(flag, verb, pre, nreg, NB, vlen, 3);
   }
   mp = AddUniqueMMKernCompList(mp, mpN);
   KillAllMMNodes(mpN);
   mp = ReverseMMQ(mp);
   mp->ivar = nreg;
   WriteMMFileWithPath(pre, "res", "gmvAMMUR.sum", mp);
   KillAllMMNodes(mp);

/*
 * NOTE(review): here ivar is set before the list is merged & reversed, while
 * the M-vec list above sets it on the head after reversal -- confirm which
 * node of the K-vec list is meant to carry nreg.
 */
   mp = SrchMUNU(flag, verb, pre, nreg, NB, vlen, 1);
   mp->ivar = nreg;
   mpN = SrchMUNUp2(flag, verb, pre, nreg, NB, vlen, 1);
   mp = AddUniqueMMKernCompList(mp, mpN);
   KillAllMMNodes(mpN);
   mp = ReverseMMQ(mp);
   WriteMMFileWithPath(pre, "res", "gkvAMMUR.sum", mp);
   KillAllMMNodes(mp);
   if (WNR)
      WriteNregFile(pre, nreg);
   *VLEN = vlen;
   *NREG = nreg;
}
ATL_mmnode_t *GetBestKernVT(char pre, char vt)
/*
 * Reads the summary file for vectorization type vt ('K' selects
 * gkvAMMUR.sum, anything else gmvAMMUR.sum) from "res" and keeps only its
 * first node; any following nodes are freed.  Asserts the file was readable.
 * RETURNS: the single remaining node, owned by the caller
 */
{
   char *fnam = (vt == 'K') ? "gkvAMMUR.sum" : "gmvAMMUR.sum";
   ATL_mmnode_t *head;

   head = ReadMMFileWithPath(pre, "res", fnam);
   assert(head);
   if (!head->next)
      return(head);
   KillAllMMNodes(head->next);
   head->next = NULL;
   return(head);
}

ATL_mmnode_t *GetBestKern(char pre)
/*
 * Fetches the first M-vectorized and the first K-vectorized kernel with
 * GetBestKernVT and returns whichever has the larger mflop[0]; the M-vec
 * kernel is kept on a tie.  The other one is freed.
 */
{
   ATL_mmnode_t *mvec, *kvec;

   mvec = GetBestKernVT(pre, 'M');
   kvec = GetBestKernVT(pre, 'K');
   if (kvec->mflop[0] > mvec->mflop[0])
   {
      KillAllMMNodes(mvec);
      return(kvec);
   }
   KillAllMMNodes(kvec);
   return(mvec);
}


void DoSquare(int flag, int verb, char pre, int nreg, int VL)
/*
 * Fetches the best kernel found so far and runs the square-NB range search
 * (GetMaxNB) on it.  The value returned by GetMaxNB is not used yet, and
 * nreg & VL are unused.  The kernel node is freed before returning.
 */
{
   ATL_mmnode_t *mp;

   mp = GetBestKern(pre);
   (void) GetMaxNB(flag, verb, pre, mp);
   KillAllMMNodes(mp);  /* node is owned here: do not leak it */
}

int CountFails(int flg, int verb, char pre, int NB, int nreg, int VL)
/*
 * Tests generated kernels for all register blockings 1 <= i,j < nreg with
 * i*j+1 <= nreg, using NumberBetaFails (three cases per kernel):
 *   - M-vectorized with MMGetNodeGEN's 2nd argument 0, always;
 *   - M-vectorized with that argument 1 ("w/o bcast"), when j%VL == 0;
 *   - K-vectorized, when (i*j)%VL == 0.
 * One row per kernel is printed to stdout; failure details and the final
 * summary go to res/FAIL.OUT (the summary is printed to stdout too).
 * NOTE(review): the BC column prints 0 in every row, including the
 * no-bcast rows -- confirm what that column is meant to show.
 * RETURNS: total number of failing cases
 */
{
   const char *rowfmt="%8d %4d %3d %3d %2d   %c %2d  %5d\n";
   int i, j, ncase=0, nbad=0;
   FILE *fperr;

   fperr = fopen("res/FAIL.OUT", "w");
   assert(fperr);

   assert(VL);
   assert(nreg);
   assert(NB > 0);
   printf("     NUM   B   MU  NU VL VEC BC  NPASS\n");
   printf("======== ==== === === == === ==  =====\n");
   for (i=1; i < nreg; i++)
   {
/*
 *    i*j+1 grows with j, so stop the row at the first j that does not fit
 */
      for (j=1; j < nreg && i*j+1 <= nreg; j++)
      {
         ATL_mmnode_t *kp;
         int nf;

         kp = MMGetNodeGEN(pre, 0, NB, i*VL, j, 1, VL, 0, NULL);
         nf = NumberBetaFails(fperr, pre, NB, kp);
         printf(rowfmt, ncase, NB, kp->mu, kp->nu, VL, 'M', 0, 3-nf);
         KillMMNode(kp);
         nbad += nf;
         ncase += 3;
         if (!(j%VL))  /* try m-vec w/o bcast */
         {
            kp = MMGetNodeGEN(pre, 1, NB, i*VL, j, 1, VL, 0, NULL);
            nf = NumberBetaFails(fperr, pre, NB, kp);
            printf(rowfmt, ncase, NB, kp->mu, kp->nu, VL, 'M', 0, 3-nf);
            KillMMNode(kp);
            nbad += nf;
            ncase += 3;
         }
         if (!((i*j)%VL)) /* try k-vec wt bcast */
         {
            kp = MMGetNodeGEN(pre, 0, NB, i, j, VL, VL, 1, NULL);
            nf = NumberBetaFails(fperr, pre, NB, kp);
            printf(rowfmt, ncase, NB, kp->mu, kp->nu, VL, 'K', 0, 3-nf);
            KillMMNode(kp);
            nbad += nf;
            ncase += 3;
         }
      }
   }
   if (nbad)
   {
      printf("FAILED %d OF %d TESTS!!\n\n", nbad, ncase);
      fprintf(fperr, "FAILED %d OF %d TESTS!!\n\n", nbad, ncase);
   }
   else
   {
      printf("ALL %d TESTS PASS!\n", ncase);
      fprintf(fperr, "ALL %d TESTS PASS!\n", ncase);
   }
   fclose(fperr);
   return(nbad);
}

int FindBlockingRegions(int flag, int verb, char pre)
/*
 * This routine is meant to find the best prefetch kernel for all problem
 * sizes using timings of the fastest unprefetched kernel found in the MU/NU
 * search.  We can prefetch or not each of the three operands, and for each
 * we can also use ATL_pfl1 (pref to L1) or ATL_pfl2.  L2 may really be the
 * last-lvl cache.
 * We assume there are five operand size ranges of interest:
 * 1. nb <= sqrt(L1sz/5) : can fit 5 blks of A/B in L1.  May be worthwhile
 *    to prefetch next block of A & B to L1 cache
 * 2. nb <= sqrt(L1sz/4) : fit 4 blks, try fetching A or B to L1, other to L2
 * 3. nb <= sqrt(L2sz/6) : pref next A&B blocks to L2
 * 4. nb <= sqrt(L2sz/5) : pref only one of A/B to L2
 * 5. else only do inter-block prefetch
 * The search is not implemented: no timing is done and the arguments are
 * not used.
 * RETURNS: maximum NB providing speedup (presently always 0)
 */
{
   return(0);
}
int main(int nargs, char **args)
/*
 * Parses the command line, then either tests all legal generated kernels
 * (-T nonzero; the exit status is the number of failures) or runs the
 * register-blocking search (exit status 0).
 */
{
   int flg, verb, nreg, VLEN, NB, TEST;
   char pre;

   GetFlags(nargs, args, &flg, &verb, &pre, &nreg, &VLEN, &NB, &TEST);
   if (!TEST)
   {
      FindInfo(flg, verb, pre, NB, &nreg, &VLEN);
      return(0);
   }
   return(CountFails(flg, verb, pre, NB, nreg, VLEN));
}
