Monday, October 13, 2008

Map Composition in Scala (or The Virtues of Laziness)

I have been dabbling with some lazy optimization techniques while designing functional data structures. I am using Scala to implement some of Okasaki's functional data structures .. hence I thought it would be appropriate to think aloud about some of the lazy optimizations that Scala offers.

Scala is a big language with myriads of features and some unfortunate syntactic warts that will definitely rub you the wrong way till you get comfortable with them. Some time back, a popular joke on twitter was related to the number of meanings and interpretations that the character "_" has in Scala. Indeed there are many, and not all of them are intuitive enough .. :-( ..

Now, on to some real Scala snippets ..

Do you know all the differences in interpretation of the following variants of function definition and application ?


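// call by value: x is evaluated once, before foo is entered
def foo(x:Bar) = {
  val a = x
  val b = x
}

// call by name: x is re-evaluated at every reference (twice here)
def foo(x: => Bar) = {
  val a = x
  val b = x
}

// no argument function implemented as thunks
def foo(x:() => Bar) = {
  val a = x     // a is the function itself, nothing is evaluated
  val b = x
  val c = x()   // applying the thunk evaluates it
}

// call by name to call by need: the lazy val evaluates x at most once
def foo(x: => Bar) = {
  lazy val y = x
  //...
}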



This post is not about discussing the details of the above 4 cases .. Scala wiki has an excellent FAQ entry that describes each of them with sufficient rigor and detail .. go get it ..
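For the impatient, here is a small sketch of my own (not from the FAQ) that makes the difference visible by counting evaluations. The names `EvalStrategies`, `expensive`, `count` and the `byXxx` methods are made up for illustration, and the code assumes a current Scala (2.13 or 3) ..

```scala
object EvalStrategies {
  // Hypothetical helper: counts how many times the argument expression runs.
  var evaluations = 0
  def expensive(): Int = { evaluations += 1; 42 }

  // call by value: the argument is evaluated once, before the body runs
  def byValue(x: Int): Int = x + x

  // call by name: the argument is re-evaluated at every reference to x
  def byName(x: => Int): Int = x + x

  // thunk: x is a function value, nothing runs until x() is applied
  def byThunk(x: () => Int): Int = x() + x()

  // call by need: the lazy val caches the first evaluation
  def byNeed(x: => Int): Int = {
    lazy val y = x
    y + y
  }

  // Resets the counter, forces `body`, and reports how often `expensive` ran.
  def count(body: => Int): Int = {
    evaluations = 0
    val _ = body
    evaluations
  }

  def main(args: Array[String]): Unit = {
    println(count(byValue(expensive())))        // 1
    println(count(byName(expensive())))         // 2
    println(count(byThunk(() => expensive())))  // 2
    println(count(byNeed(expensive())))         // 1
  }
}
```

Each method references its argument twice, so the counter shows directly which strategies re-evaluate and which ones cache.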

Lazy evaluation is one of the standard performance optimization techniques when dealing with large functional data structures. In Scala, Seq.Projection[A] is an abstraction that makes operations lazy.


val l = List(...) // a big list
l.map(_ * 4).map(_ + 2).filter(_ % 6 == 0)



In the above snippet, every operation on the list creates a separate intermediate list, which is then fed to the next operation in the chain. Hence if you are dealing with large collections, you pay for it in both memory and time. Go lazy and you have savings both in memory requirements and in elapsed time ..


l.projection.map(_ * 4).map(_ + 2).filter(_ % 6 == 0)



This will return a Stream, which will do a lazy evaluation on demand. I found the term Stream first in SICP, where Abelson and Sussman introduce it as a data structure for delayed evaluation, which enables us to represent very large (even infinite) sequences.

In Scala, a Stream is a lazy list, and follows the semantics of SICP streams, where elements are evaluated only when they are needed.
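To see the "only when needed" part in action, here is a rough sketch. It assumes a current Scala (2.13 or 3), where, as far as I know, `projection` is no longer available and `view` is the closest equivalent. The names `LazyPipeline`, `firstThree` and `mapped` are hypothetical, introduced only for this illustration ..

```scala
object LazyPipeline {
  // Returns the first three results of the pipeline, together with the
  // number of elements the first map actually touched.
  def firstThree(l: List[Int]): (List[Int], Int) = {
    var mapped = 0
    val result = l.view                  // lazy wrapper, no work done yet
      .map { x => mapped += 1; x * 4 }   // counted so we can observe laziness
      .map(_ + 2)
      .filter(_ % 6 == 0)
      .take(3)
      .toList                            // forces only as much as take(3) needs
    (result, mapped)
  }

  def main(args: Array[String]): Unit = {
    val (result, mapped) = firstThree((1 to 1000000).toList)
    println(result)   // List(6, 18, 30)
    println(mapped)   // a handful of elements, nowhere near a million
  }
}
```

No intermediate lists are built, and the pipeline stops pulling elements from the source as soon as three results are available.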

Making your custom collection lazy ..

I was recently working with a custom recursive tree-like data structure in Scala, which, for simplicity, let us assume is a binary tree. And since I was fetching records from a database and then loading up data structures in memory, I was working with a really big collection. Let us see how we can implement Projection on my custom data structure and make things lazy on my own. Scala, unlike Haskell, is not an inherently lazy language, and abstractions like Projection help implement laziness in evaluation. Eric Kidd wrote a great post on Haskell rewrite rules to implement declarative fusion of maps using compiler directives. This post draws some inspiration from it, through the looking glass of Scala, an admittedly more verbose language than Haskell.


trait Tree[+A] {
  def map[B](f: A => B): Tree[B]
}
/**
 * Non empty tree node, with a left subtree and a right subtree
 */
case class Node[+A](data: A, left: Tree[A], right: Tree[A]) extends Tree[A] {
  override def map[B](f: A => B): Tree[B] = Node(f(data), left.map(f), right.map(f))
}
/**
 * Leaf node
 */
case class Leaf[+A](data: A) extends Tree[A] {
  override def map[B](f: A => B): Tree[B] = Leaf(f(data))
}
/**
 * Empty tree object
 */
case object E extends Tree[Nothing] {
  // mapping over the empty tree yields the empty tree, so that Node.map
  // keeps working when a subtree is E
  override def map[B](f: Nothing => B): Tree[B] = E
}



We have a map operation defined for the tree, that uses a recursive implementation to map over all tree nodes. The map operation is a strict/eager one, much like List.map ..


val t = Node(7, Node(8, Leaf(9), Leaf(10)), Node(11, Leaf(12), Leaf(13)))
t.map(_ * 2).map(_ + 1)



will result in a new tree that has both the map operations applied in succession. In the process it will generate intermediate tree structures, one for each of the map operations in the chain. Needless to say, for a large collection, both space and time will hit you.
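One way to see what map fusion buys us is to do it by hand: compose the two functions first and map once. The sketch below is my own illustration, assuming a current Scala (2.13 or 3) and a cut-down tree without the empty case. `MapFusion`, `double`, `inc`, `twoPasses` and `onePass` are hypothetical names ..

```scala
object MapFusion {
  sealed trait Tree[+A] { def map[B](f: A => B): Tree[B] }
  case class Node[+A](data: A, left: Tree[A], right: Tree[A]) extends Tree[A] {
    def map[B](f: A => B): Tree[B] = Node(f(data), left.map(f), right.map(f))
  }
  case class Leaf[+A](data: A) extends Tree[A] {
    def map[B](f: A => B): Tree[B] = Leaf(f(data))
  }

  val t: Tree[Int] =
    Node(7, Node(8, Leaf(9), Leaf(10)), Node(11, Leaf(12), Leaf(13)))

  val double: Int => Int = _ * 2
  val inc: Int => Int = _ + 1

  // two traversals, with a throwaway tree built in between
  val twoPasses: Tree[Int] = t.map(double).map(inc)

  // one traversal, no intermediate tree: map f . map g == map (f . g)
  val onePass: Tree[Int] = t.map(double andThen inc)

  def main(args: Array[String]): Unit =
    println(twoPasses == onePass)   // true
}
```

Haskell's rewrite rules do this composition automatically at compile time; in Scala we either compose by hand as above, or defer the work through a lazy abstraction.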

Getting rid of the intermediate trees ..

Implement laziness .. make evaluations lazy, so that effectively we have one final tree that evaluates its nodes only when asked for. In other words, lift the operation from the collection to an iterator, which gets evaluated only when asked for.

Here is a sample bare-bones, unoptimized iterator implemented via scala.collection.mutable.Stack ..


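class TreeIterator[A](it: Tree[A]) extends Iterator[A] {
  import scala.collection.mutable.Stack
  // pending subtrees; the empty tree E is never pushed
  val st = new Stack[Tree[A]]
  if (it != E) st push it

  override def hasNext = !st.isEmpty
  // pre-order traversal: node first, then left subtree, then right subtree
  override def next: A = st.pop match {
    case Node(d, l, r) =>
      if (r != E) st push r
      if (l != E) st push l
      d
    case Leaf(d) =>
      d
  }
}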



Using this iterator, we define the Projection for Tree with the lazy map implementation and integrate it with the main data structure through a projection method ..

Here is the modified Tree[+A] ..


trait Tree[+A] {
  def elements: Iterator[A] = new TreeIterator(this)
  def map[B](f: A => B): Tree[B]
  def projection : Tree.Projection[A] = new Tree.Projection[A] {
    override def elements = Tree.this.elements
  }
}



and the Projection trait in an accompanying singleton for Tree ..


object Tree {
  trait Projection[+A] extends Tree[A] {
    override def map[B](f: A => B): Projection[B] = new Projection[B] {
      override def elements = Projection.this.elements.map(f)
    }
  }
}



Now I can use my data structure to implement lazy evaluations and fusion of operations ..


val t = Node(7, Node(8, Leaf(9), Leaf(10)), Node(11, Leaf(12), Leaf(13)))
t.projection.map(_ * 2).map(_ + 1)
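If you want to try the idea on a current Scala (2.13 or 3), here is a simplified sketch under my own assumptions. Unlike the code above, the hypothetical `Projection` here is a plain wrapper around an iterator factory rather than a subtype of `Tree`, and the iterator is built from `Iterator.single` and `++` instead of an explicit stack. `LazyTreeSketch`, `demo` and `calls` are names invented for the illustration ..

```scala
object LazyTreeSketch {
  sealed trait Tree[+A] {
    // pre-order iterator; ++ takes its argument by name, so subtrees are
    // visited only when the iteration actually reaches them
    def iterator: Iterator[A] = this match {
      case Leaf(d)       => Iterator.single(d)
      case Node(d, l, r) => Iterator.single(d) ++ l.iterator ++ r.iterator
    }
    def projection: Projection[A] = new Projection(() => iterator)
  }
  case class Node[+A](data: A, left: Tree[A], right: Tree[A]) extends Tree[A]
  case class Leaf[+A](data: A) extends Tree[A]

  // Hypothetical stand-in for the post's Tree.Projection:
  // map only records the function, nothing is evaluated yet.
  final class Projection[+A](elems: () => Iterator[A]) {
    def iterator: Iterator[A] = elems()
    def map[B](f: A => B): Projection[B] = new Projection(() => elems().map(f))
  }

  val t: Tree[Int] =
    Node(7, Node(8, Leaf(9), Leaf(10)), Node(11, Leaf(12), Leaf(13)))

  // Returns (calls before forcing, first two results, calls after forcing).
  def demo(): (Int, List[Int], Int) = {
    var calls = 0
    val p = t.projection.map { x => calls += 1; x * 2 }.map(_ + 1)
    val before = calls                         // still 0: nothing has run
    val firstTwo = p.iterator.take(2).toList   // List(15, 17)
    (before, firstTwo, calls)
  }

  def main(args: Array[String]): Unit = println(demo())
}
```

Both maps are fused into a single pass over the iterator, and only the elements that are actually pulled get transformed.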



Eric Kidd reported on making Haskell maps 225% faster through fusion and rewrite rules. In Scala, implementing laziness through delayed evaluation (or Projection) can also lead to a substantial reduction in memory usage and elapsed time.

1 comment:

Anonymous said...

When I try to traverse a recursive lazy tree,


{lazy val t :Node[Int] = Node(7, Leaf(3), t) ; Unit}


how can I preserve laziness ?