(* scheduler.sml
 * 2004 Matthew Fluet (mfluet@acm.org)
 *  Ported to MLton threads.
 *)

(* scheduler.sml
 *
 * COPYRIGHT (c) 1995 AT&T Bell Laboratories.
 * COPYRIGHT (c) 1989-1991 John H. Reppy
 *
 * This module implements the scheduling queues and preemption
 * mechanisms.
 *)

structure Scheduler : SCHEDULER =
   struct
      structure Assert = LocalAssert(val assert = false)
      structure GlobalDebug = Debug
      structure Debug = LocalDebug(val debug = false)

      open Critical

      structure Q = ImpQueue
      structure MT = MLton.Thread

      structure TID = ThreadID
      structure SH = SchedulerHooks
      structure G = StableGraph

      open StableGraph
      type thread_id = ThreadID.thread_id
      datatype thread = datatype RepTypes.thread
      datatype rdy_thread = datatype RepTypes.rdy_thread

      (* Turn a suspended thread into a ready thread (via MT.prepare):
       * prep resumes it with (), prepVal resumes it with v, and prepFn
       * first prepends f (via MT.prepend) and then resumes with ().
       *)
      fun prep (THRD (tid, t)) = RTHRD (tid, MT.prepare (t, ()))
      fun prepVal (THRD (tid, t), v) = RTHRD (tid, MT.prepare (t, v))
      fun prepFn (THRD (tid, t), f) = RTHRD (tid, MT.prepare (MT.prepend (t, f), ()))

      (* the dummy thread Id; this is used when an ID is needed to get
       * the types right
       *)
      val dummyTid = TID.bogus "dummy"
      (* the error thread.  This thread is used to trap attempts to run CML
       * without proper initialization (i.e., via RunCML).  It is enqueued
       * by reset when reset is called with running = false.
       *)
      val errorTid = TID.bogus "error"
      fun errorThrd () : unit thread =
         THRD (errorTid, MT.new (fn () =>
               (GlobalDebug.sayDebug
                ([fn () => "CML"], fn () => "**** Use RunCML.doit to run CML ****")
                ; raise Fail "CML not initialized")))

      (* The id of the currently running thread.  The ref is private to
       * this local block; setCurThreadId asserts that it runs atomically.
       *)
      local
         val curTid : thread_id ref = ref dummyTid
      in
         fun getThreadId (THRD (tid, _)) = tid
         fun getCurThreadId () = !curTid
         fun setCurThreadId tid =
            (Assert.assertAtomic' ("Scheduler.setCurThreadId", NONE)
             ; curTid := tid)
      end
      fun tidMsg () = TID.tidToString (getCurThreadId ())
      fun debug msg = Debug.sayDebug ([atomicMsg, tidMsg], msg)
      fun debug' msg = debug (fn () => msg)

      (* The thread ready queues:
       * rdyQ1 is the primary queue and rdyQ2 is the secondary queue.
       * Every queue operation below asserts that it runs atomically.
       *)
      val rdyQ1 : rdy_thread Q.t = Q.new ()
      and rdyQ2 : rdy_thread Q.t = Q.new ()

      (* enqueue a thread in the primary queue *)
      fun enque1 thrd =
         (Assert.assertAtomic' ("Scheduler.enque1", NONE)
          ; Q.enque (rdyQ1, thrd))
      (* enqueue a thread in the secondary queue *)
      fun enque2 thrd =
         (Assert.assertAtomic' ("Scheduler.enque2", NONE)
          ; Q.enque (rdyQ2, thrd))
      (* dequeue a thread from the secondary queue *)
      fun deque2 () =
         (Assert.assertAtomic' ("Scheduler.deque2", NONE)
          ; Q.deque rdyQ2)
      (* dequeue a thread from the primary queue, falling back to the
       * secondary queue when the primary queue is empty
       *)
      fun deque1 () =
         (Assert.assertAtomic' ("Scheduler.deque1", NONE)
          ; case Q.deque rdyQ1 of
               NONE => deque2 ()
             | found as SOME _ => found)
      (* promote a thread from the secondary queue to the primary queue *)
      fun promote () =
         (Assert.assertAtomic' ("Scheduler.promote", NONE)
          ; case deque2 () of
               NONE => ()
             | SOME thrd => enque1 thrd)

      (* Choose the next thread to run: the first thread available from
       * rdyQ1 then rdyQ2, or, when both are empty, the thread produced by
       * the current pause hook.
       *)
      fun next () =
         (Assert.assertAtomic' ("Scheduler.next", NONE)
          ; case deque1 () of
               NONE => !SH.pauseHook ()
             | SOME thrd => thrd)
      (* make a thread ready by appending it to the primary queue *)
      fun ready thrd =
         (Assert.assertAtomic' ("Scheduler.ready", NONE)
          ; enque1 thrd)

      local
         (* Switch away from the current thread (must already be atomic).
          * The current thread's id is marked with TID.mark, then f receives
          * the current thread and returns the ready thread to run next;
          * that thread's id becomes the current id.
          *)
         fun atomicSwitchAux msg f =
            (Assert.assertAtomic (fn () => "Scheduler." ^ msg, NONE)
             ; MT.atomicSwitch (fn t =>
                                let
                                   val tid = getCurThreadId ()
                                   val () = TID.mark tid
                                   val RTHRD (tid', t') = f (THRD (tid, t))
                                   val () = setCurThreadId tid'
                                in
                                   t'
                                end))
      in
         fun atomicSwitch f =
            atomicSwitchAux "atomicSwitch" f
         fun switch f =
            (atomicBegin (); atomicSwitch f)
         (* f is given the current thread; the next thread comes from next () *)
         fun atomicSwitchToNext f =
            atomicSwitchAux "atomicSwitchToNext" (fn thrd => (f thrd; next ()))
         fun switchToNext f =
            (atomicBegin (); atomicSwitchToNext f)
         (* The current thread is put on the ready queue before f () is
          * asked for the thread to switch to.
          *)
         fun atomicReadyAndSwitch f =
            atomicSwitchAux "atomicReadyAndSwitch" (fn thrd => (ready (prep thrd); f ()))
            (* Disabled variant that readies a G.TCopy copy of the thread:
             *   (fn (THRD (tid, t)) =>
             *      let val (t'', t') = G.TCopy (t, tid)
             *      in (ready (prepFn (THRD (tid, t''), fn () => G.schedThread (t', tid))); f ())
             *      end)
             *)
         fun readyAndSwitch f =
            (atomicBegin (); atomicReadyAndSwitch f)
         (* As atomicReadyAndSwitch, but the next thread comes from next () *)
         fun atomicReadyAndSwitchToNext f =
            atomicSwitchAux "atomicReadyAndSwitchToNext" (fn thrd => (ready (prep thrd); f (); next ()))
            (* Disabled variant that readies a G.TCopy copy of the thread:
             *   (fn (THRD (tid, t)) =>
             *      let val (t'', t') = G.TCopy (t, tid)
             *      in (ready (prepFn (THRD (tid, t''), fn () => G.schedThread (t', tid))); f (); next ())
             *      end)
             *)
         fun readyAndSwitchToNext f =
            (atomicBegin (); atomicReadyAndSwitchToNext f)
      end

      (* Rebuild the ready queues after stabilization.
       *
       * Every thread currently on rdyQ1 and rdyQ2, plus each (tid, t) pair
       * in `threads` (prepared to resume with ()), is put into a table
       * keyed by thread id; duplicate ids are resolved by HashTable.put.
       * The ids in threads2Kill are then removed, and every remaining
       * thread is enqueued on rdyQ1 in the order HashTable.getValues
       * yields them.
       *
       * The work runs between MT.atomicBegin and MT.atomicEnd; if an
       * exception escapes, the atomic section is still closed before the
       * exception is re-raised.
       *
       * (An earlier version recorded threads in a fixed 1050-slot array
       * indexed by TID.toNum instead of a hash table.)
       *)
      fun stabilizeQs (threads, threads2Kill) =
         let
            fun rebuild () =
               let
                  (* drain queue q, consing its elements onto acc *)
                  fun drain (q, acc) =
                     case Q.deque q of
                        NONE => acc
                      | SOME thrd => drain (q, thrd :: acc)
                  val fromQ1 = drain (rdyQ1, [])
                  (* all queued threads in FIFO order, rdyQ1 before rdyQ2 *)
                  val queued = List.rev (drain (rdyQ2, fromQ1))

                  (* roughly twice the number of entries the table may
                   * receive: the queued threads plus `threads`
                   *)
                  val tableSize = 2 * (List.length queued + List.length threads) + 1
                  val _ = G.debug ("API - scheduler using HashTable of size=" ^ Int.toString tableSize ^ "\n")
                  val threadTable = HashTable.new (tableSize, TID.hashTid, TID.sameTid)

                  fun addToRestoreTable (thrd as RTHRD (tid, _)) =
                     HashTable.put threadTable (tid, thrd)

                  (* same insertion order as before: rdyQ1, rdyQ2, threads *)
                  val () = List.app addToRestoreTable queued
                  val () =
                     List.app
                        (fn (tid, t) => addToRestoreTable (RTHRD (tid, MT.prepare (t, ()))))
                        threads

                  val _ = G.debug ("API - about to remove threads2kill\n")
                  val _ = G.debug ("API - threads2Kill len=" ^ Int.toString (List.length threads2Kill) ^ "\n")
                  val _ = G.debug ("API - threadTable size before=" ^ Int.toString (HashTable.size threadTable) ^ "\n")
                  val () = List.app (fn tid => HashTable.remove threadTable tid) threads2Kill
                  val _ = G.debug ("API - threadTable size after =" ^ Int.toString (HashTable.size threadTable) ^ "\n")

                  val _ = G.debug ("API - about to enqueue new threads\n")
                  val () =
                     List.app
                        (fn (thrd as RTHRD (tid, _)) =>
                            (G.debug ("API - enqueue " ^ TID.tidToString tid ^ "\n")
                             ; enque1 thrd))
                        (HashTable.getValues threadTable)
                  val _ = G.debug ("API - finished scheduling\n")
               in
                  ()
               end
         in
            MT.atomicBegin ()
            ; (rebuild () handle e => (MT.atomicEnd (); raise e))
            ; MT.atomicEnd ()
         end

      (* Create a thread with a fresh id that runs f applied to that id.
       * The spawn is recorded with G.spawnThread (parent = current thread).
       *)
      fun new (f : thread_id -> ('a -> unit)) : 'a thread =
         let
            val () = Assert.assertAtomic' ("Scheduler.new", NONE)
            val tid = TID.new ()
            val _ = G.spawnThread (getCurThreadId (), tid)
            val t = MT.new (f tid)
         in
            THRD (tid, t)
         end
      (* Same as new except creates a bidi edge (via G.spawnThread2) *)
      fun new2 f =
         let
            val () = Assert.assertAtomic' ("Scheduler.new2", NONE)
            val tid = TID.new ()
            val _ = G.spawnThread2 (getCurThreadId (), tid)
            val t = MT.new (f tid)
         in
            THRD (tid, t)
         end

      (* Prepend f to a suspended thread (via MT.prepend), keeping its id *)
      fun prepend (thrd : 'a thread, f : 'b -> 'a) : 'b thread =
         let
            val () = Assert.assertAtomic' ("Scheduler.prepend", NONE)
            val THRD (tid, t) = thrd
            val t = MT.prepend (t, f)
         in
            THRD (tid, t)
         end

      (* Pair runnable t with the current thread id, let f choose the ready
       * thread to run, record its id as current, and return its runnable.
       *)
      fun unwrap (f : rdy_thread -> rdy_thread) (t: MT.Runnable.t) : MT.Runnable.t =
         let
            val () = Assert.assertAtomic' ("Scheduler.unwrap", NONE)
            val tid = getCurThreadId ()
            val RTHRD (tid', t') = f (RTHRD (tid, t))
            val () = setCurThreadId tid'
         in
            t'
         end

      (* reset various pieces of state: the current id, both ready queues,
       * and, when not running, enqueue the error thread
       *)
      fun reset running =
         (atomicBegin ()
          ; setCurThreadId dummyTid
          ; Q.reset rdyQ1; Q.reset rdyQ2
          ; if not running then ready (prep (errorThrd ())) else ()
          ; atomicEnd ())

      (* What to do at a preemption (with the current thread).  A thread
       * whose id is marked (atomicSwitchAux marks ids on switch) is
       * unmarked, one thread is promoted from rdyQ2, and it goes on rdyQ1;
       * an unmarked thread goes on rdyQ2.
       *)
      fun preempt (thrd as RTHRD (tid, _)) =
         let
            val () = Assert.assertAtomic' ("Scheduler.preempt", NONE)
            val () = debug' "Scheduler.preempt" (* Atomic 1 *)
            val () = Assert.assertAtomic' ("Scheduler.preempt", SOME 1)
         in
            if TID.isMarked tid
               then (TID.unmark tid
                     ; promote ()
                     ; enque1 thrd)
            else enque2 thrd
         end

      val _ = reset false
   end


