(*---RCS--- $Log:	_optimisecexp.sml,v $
Revision 1.1  92/09/17  14:17:24  birkedal
Edinburgh Version 11Sep92
*)

(*$OptimiseCExp: CEXP PRIMS FINMAP OPTIMISE_CEXP*)
functor OptimiseCExp(structure CExp: CEXP

		     structure Prims: PRIMS
		       sharing type CExp.prim = Prims.prim

		     structure FinMap: FINMAP
		       sharing type CExp.map = FinMap.map
		    ): OPTIMISE_CEXP =
  struct
    open CExp

   (* General utilities. *)
   
   (* count - count occurrences of lvar (non-binding occurrence) in a CExp. *)

    (* maybe lv1 lv2: 1 when the two lvars are the same, otherwise 0;
       the contribution of a single position to an occurrence count. *)
    fun maybe lv1 lv2 = if lv1 <> lv2 then 0 else 1

    (* sum: total of a list of integer counts. *)
    fun sum counts = ListUtil.fold (op +) 0 counts

    (* count(lv, cexp): number of non-binding occurrences of the lvar lv
       in cexp. FIX-bound function names and formals are not counted;
       every other lvar position is. The result is used to decide
       whether a function is used exactly once, so it must not
       under-count. *)
    fun count(lv, cexp) =
      let
       (* countInSimple: occurrences of lv in a simple value, either a
	  variable or the lvars of a vector. Shared by SIMPLE nodes and
	  by the actual arguments of an APP so both are counted alike. *)
	fun countInSimple(VAR lv') = maybe lv lv'
	  | countInSimple(VECTOR lvs) = sum(map (maybe lv) lvs)
	  | countInSimple _ = 0

       (* countInMap: occurrences across every arm of a selection map. *)
	fun countInMap selections =
	  FinMap.fold (fn (arm, n) => count(lv, arm) + n) 0 selections

       (* countInSwitch: occurrences in the scrutinee, in the selection
	  arms and in the optional wildcard arm. These helpers are
	  let-bound, not part of a `fun ... and ...' group with count,
	  so they are polymorphic in the switch's key type and can serve
	  SWITCH_I, SWITCH_S and SWITCH_R alike. *)
	fun countInSwitch(SWITCH{arg, selections, wildcard}) =
	  maybe lv arg
	  + countInMap selections
	  + (case wildcard of SOME arm => count(lv, arm) | NONE => 0)
      in
	case cexp
	  of FIX(bindings, scope) =>
	       sum(map (fn {body, ...} => count(lv, body)) bindings)
	       + count(lv, scope)

	   | APP{f, actuals} =>
	       maybe lv f + sum(map countInSimple actuals)

	   | PRIM_APP{arg, cont, ...} => maybe lv arg + maybe lv cont
	   | SELECT(_, lv') => maybe lv lv'
	   | SWITCH_I sw => countInSwitch sw
	   | SWITCH_S sw => countInSwitch sw
	   | SWITCH_R sw => countInSwitch sw
	   | SIMPLE simple => countInSimple simple
      end


   (* We do the following optimisations:

	o Exbedding of nested fixes of the form

		FIX(bind1, FIX(bind2, ...))
		=> FIX(bind1 @ bind2, ...)

	o Remove null FIX bindings (generated as a result of the
	  other optimisations):

		FIX(nil, scope) => scope

	o Beta-reduce once-called functions, as long as they occur
	  in an applied context (otherwise we'd have to build anonymous
	  lambdas, and we'd be back where we started):

		FIX(f(x) = e1, ...f(y)...) => e1[x/y]

	o Eta-reduction:

		FIX(f(x) = g(x), e1) => e1[f/g]

      Each of the optimisation functions operates on the node it's
      passed only, and doesn't bother recursing down the tree; a
      general tree-walker does that. The type of each optimiser is
      (string->unit) -> CExp -> CExp, where the first argument is a
      `tick' function to mark a successful optimisation (and perhaps
      print some diagnostic); this is used to determine the fix-point
      in the rewritings.
    *)

    (* opt_exbed: a FIX whose scope is itself a FIX becomes a single FIX
       binding both groups, FIX(b1, FIX(b2, e)) => FIX(b1 @ b2, e), and
       ticks "exbed". Any other node is returned untouched. *)
    fun opt_exbed tick (FIX(outer, FIX(inner, scope))) =
	  (tick "exbed"; FIX(outer @ inner, scope))
      | opt_exbed _ cexp = cexp

    (* opt_empty: a FIX with no bindings is replaced by its scope, and
       "empty" is ticked. Any other node is returned untouched. *)
    fun opt_empty tick (FIX(nil, scope)) = (tick "empty"; scope)
      | opt_empty _ other = other

    (* opt_beta: intended to beta-reduce FIX-bound functions that are
       called exactly once, from an application context, and are not
       recursive.
       NOTE(review): this is an unfinished draft and does not compile as
       written; see the NOTE(review) comments below. It is nevertheless
       composed into `opt' by optimise. *)
    fun opt_beta tick cexp =	(* Beta-reduce once-called functions which
				   occur in an application context. *)
      let
       (* replace - replace each lvar `a' with the corresponding `x',
		    but only look for each lvar once. *)
       (* NOTE(review): `as' is a reserved word in Standard ML, so the
	  pattern `a :: as' below will not parse; the tail needs another
	  name. `walk' is not defined in this file; presumably it is a
	  CExp -> CExp tree mapper like `pass'. TODO confirm. *)

	fun replace(a :: as, x :: xs, cexp) =
	      replace(
		as, xs,
		walk (fn it as PRIM_APP{cont, prim, arg} =>
			   if cont=a
			   then PRIM_APP{cont=x, prim=prim, arg=arg}
			   (* NOTE(review): this test compares the
			      continuation with the argument. Since the
			      branch replaces arg by x, presumably `arg=a'
			      was intended. *)
			   else if cont=arg
			   then PRIM_APP{cont=cont, prim=prim, arg=x}
			   else it

		       | it as SELECT(int, lv) =>
			   if lv=a then SELECT(int, x) else it

		       | it as SIMPLE(VAR lv) =>
			   if lv=a then SIMPLE(VAR x) else it

		       (* NOTE(review): VECTOR is used under SIMPLE in
			  count, so it presumably builds a simple value,
			  not a CExp; this clause then does not type-check
			  with the others, and SIMPLE(VECTOR lvs) was
			  probably meant. This fn also has no catch-all
			  clause, so if walk applies it to every node,
			  FIX, APP and SWITCH_* nodes would raise Match. *)
		       | VECTOR lvs =>
			   VECTOR(replaceL(a, x, lvs))
		     ) cexp
	      )

	  | replace(nil, nil, cexp) = cexp
	  (* Lists of different lengths: internal error. *)
	  | replace _ = Crash.impossible "OptimiseCExp.replace"

       (* replaceL(a, x, lvs): replace the first `a' in lvs by `x'.
	  NOTE(review): there is no nil clause, so this raises Match when
	  `a' does not occur in the list. *)
	and replaceL(a, x, lv :: lvs) =
	  if lv=a then x :: lvs else lv :: replaceL(a, x, lvs)


       (* scan(prev, bindings, scope): walk the FIX bindings, collecting
	  in prev those that are kept. prev is built with `it :: prev', so
	  the kept bindings come back in reverse order; the result pair
	  (bindings, scope) is passed straight to FIX below.
	  NOTE(review): FIX bindings are records {f, formals, body}
	  elsewhere in this file (onBindings, pass), but this pattern and
	  the one in otherN use the fields actuals and scope; presumably
	  formals and body were intended. *)
	fun scan(prev, (it as {f, actuals, scope=scope'}) :: rest, scope) =
	     let
	      (* count occurrences of f in RHS's of the other bindings in
		 this FIX, and in the scope of the FIX. But, if it occurs
		 in scope' then it's a recursive call, and we can't do
		 anything useful. *)

	       (* NOTE(review): `plus' is not defined in this file;
		  presumably `+' was intended. *)
	       val otherN =
		 sum(map (fn {scope, ...} => count(f, scope)) (prev @ rest))
		 plus count(f, scope)

	       val recN = count(f, scope')
	     in
	       (* NOTE(review): count returns a single int, but this case
		  matches records {calls, other}; count would need to
		  separate call occurrences of f from other uses. The
		  first arm is pseudo-code for the rewrite: substitute
		  into prev, rest and scope, drop this binding, and
		  tick "beta". *)
	       case (recN, otherN)
		 of ({calls=0, other=0}, {calls=1, other=0}) =>
		   ...replace in prev @ scan(replace in rest, replace in scope)
		   tick "beta"

		  | _ =>	(* Nope, don't try to replace this one. *)
		      scan(it :: prev, rest, scope)
	     end

	  | scan(prev, nil, scope) = (prev, scope)


       (* betaReduce(f, args, body, scope): intended to replace each call
	  APP{f, actuals} in scope by body with args replaced by actuals,
	  assuming walk visits every node.
	  NOTE(review): not called anywhere in this block yet. *)
	fun betaReduce(f, args, body, scope) =
	  walk (fn (it as APP{f=f', actuals}) =>
		     if f=f' then replace(args, actuals, body) else it
		 | x => x
	       ) scope
      in
	case cexp
	  of FIX(bindings, scope) =>
	       FIX(scan(nil, bindings, scope))
	   | _ => cexp
      end

    (* opt_eta: eta-reduction, FIX(f(xs) = g(xs), e) => e[f/g].
       Not implemented yet: every node is returned unchanged and the
       tick function is never called. *)
    fun opt_eta _ cexp = cexp


   (* General tree-walking stuff. *)

   (* onSwitch: apply a CExp->CExp to a switch construct. *)

    fun onSwitch opt (SWITCH{arg, selections, wildcard}) =
      let
       (* The selection arms are rewritten before the wildcard arm, so
	  any side effects of opt happen in that order. *)
	val selections' = FinMap.composemap opt selections
	val wildcard' =
	  case wildcard
	    of NONE => NONE
	     | SOME cexp => SOME(opt cexp)
      in
	SWITCH{arg=arg, selections=selections', wildcard=wildcard'}
      end

   (* onBindings: apply a CExp->CExp over FIX bindings. *)

    fun onBindings opt bindings =
      let
       (* Rebuild one binding with opt applied to its body only; the
	  function name and formals are kept as they are. *)
	fun rewrite {f, formals, body} = {f=f, formals=formals, body=opt body}
      in
	map rewrite bindings
      end

   (* pass - apply a CExp->CExp to every CExp in a tree, from the
      bottom up. The sub-CExps of a node (FIX bindings and scope,
      switch arms and wildcard) are rewritten first, recursively, and
      then f is applied to the rebuilt node. APP, PRIM_APP, SELECT and
      SIMPLE have no sub-CExps and are handed to f directly. *)

    fun pass (f: CExp->CExp) cexp =
      let
       (* Recurse with pass f, not plain f, so the whole tree is visited
	  rather than just the node's immediate children. *)
	val recur = pass f
      in
	f(case cexp
	    of FIX(bindings, scope) =>
		 FIX(onBindings recur bindings, recur scope)

	     | APP _       => cexp
	     | PRIM_APP _  => cexp
	     | SELECT _    => cexp
	     | SWITCH_I sw => SWITCH_I(onSwitch recur sw)
	     | SWITCH_S sw => SWITCH_S(onSwitch recur sw)
	     | SWITCH_R sw => SWITCH_R(onSwitch recur sw)
	     | SIMPLE _    => cexp
	 )
      end

   (* optimise: rewrite a tree to a fix-point. Each round prints go:
      and runs pass with the composition of all the optimisers. Any
      optimiser that fires calls tick, which prints its name and
      schedules another round. Stops after the first round in which
      nothing fired. *)

    fun optimise cexp =
      let
	val changed = ref true
	fun tick what = (output(std_out, "   " ^ what ^ "\n"); changed := true)
	val opt = opt_exbed tick o opt_empty tick o opt_beta tick o opt_eta tick

	fun loop e =
	  if not(!changed) then e
	  else (output(std_out, "go:\n");
		changed := false;
		loop(pass opt e))
      in
	loop cexp
      end
  end;
