Monday, July 11, 2011

Datatype generic programming in Scala - Fixing on Cata

The paper Functional Programming with Overloading and Higher-Order Polymorphism by Mark P Jones discusses recursion and fixpoints in the section "Recursion schemes: Functional programming with bananas and lenses", with all examples in Haskell. The paper is an excellent read for anyone interested in the overall landscape of functional programming concepts. In case you haven’t read it, stop reading this post and have a look at it.

The section on recursion in the original paper discusses type level fixpoints and how they can be used to define generic abstractions like catamorphism and anamorphism. These are used extensively in datatype generic programming, where you can define generic combinators parameterized by the shape of the data. A very common example of a combinator parameterized by the shape or type constructor is fold, which works with many recursive structures like List, Tree etc.

I have been trying to use Scala to do the same examples that the paper models in the section on recursion. I was curious to find out if Scala's type system is expressive enough to implement higher order abstractions like the type level fixpoint combinator that the paper implements using Haskell. Oliveira and Gibbons also made similar studies on the suitability of Scala for datatype generic programming and have translated Gibbons' ORIGAMI for a subset of the GoF patterns in Scala.

In order to learn type level fixpoints, I started with value level fixpoints using untyped lambda calculus. Once you study the properties of a fixed point for functions, you can see how they map to type level fixpoints. The former takes functions and maps across types, while the latter takes type constructors and maps across kinds.

I start with an introduction and some proofs for value level fixpoints. Most of it is pretty basic stuff - still I wanted to say these things as it may prove useful to someone learning my way. In case you are familiar with this, feel free to skip to the next section.

Value level fixpoint

Lambda calculus helps us split a recursive function definition into 2 parts - a non-recursive computable definition that abstracts the recursive call away to an additional parameter and a fixpoint operator that encodes the recursion part. This way we take the recursion off the main function that we model. Here’s the usual factorial that goes through this process and lands up as a fixed point computation ..

FACT = λn. IF (= n 0) 1 (* n (FACT (- n 1)))    // λ calculus
FACT = (λfac. (λn. IF (= n 0) 1 (* n (fac (- n 1))))) FACT    // beta abstraction

Note here we have isolated the main factorial function as an abstraction that does not have any recursion. Call that abstraction H, so that the above equation becomes ..

FACT = H FACT

and this encodes the recursion part. H is a function which when applied to FACT gives back FACT. We call FACT a fixpoint of H.

What then is a fixpoint combinator ?

A fixpoint combinator is a function Y which takes another function as argument and generates its fixpoint. e.g. in the case above Y when applied on H will give us the fixpoint FACT.

Y H = FACT        // from definition of Y    #1
FACT = H FACT     // from above              #2
Y H = H FACT      // applying #2 in #1       #3
Y H = H (Y H)     // applying #1 in #3       #4

Note we started with FACT as the model of our factorial. It’s an interesting exercise to see that FACT 4 indeed results in 24 following the above formula.
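
Here's a minimal Scala sketch of the same idea, in case it helps to see it run. Note the assumptions: fix below uses Scala's own recursion rather than the untyped lambda term for Y, and the names ValueFix, fix and fact are made up for this illustration. What it keeps from the derivation is the property Y H = H (Y H) and a non-recursive H.

```scala
object ValueFix {
  // fix(h) behaves like h(fix(h)). Wrapping the call in a lambda delays it,
  // which a strict language needs for the unfolding to terminate.
  def fix[A, B](h: (A => B) => (A => B)): A => B =
    (a: A) => h(fix(h))(a)

  // H from the text: factorial with the recursive call abstracted into `fac`.
  val H: (Int => Int) => (Int => Int) =
    fac => n => if (n == 0) 1 else n * fac(n - 1)

  // FACT = Y H
  val fact: Int => Int = fix(H)
}
```

With these definitions ValueFix.fact(4) evaluates to 24, and H itself never refers to fact.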

So Y is the magic that helps us express any recursive function as a non-recursive computation. In fact Y can be expressed as a lambda expression without using recursion ..

Y = (λh. (λx. h(x x)) (λx. h(x x)))

Let’s see what Y H evaluates to ..

Y H = (λh. (λx. h(x x)) (λx. h(x x))) H    // definition of Y
   -> (λx. H(x x)) (λx. H(x x))            // beta reduction
   -> H ((λx. H(x x)) (λx. H(x x)))        // beta reduction
   -> H (Y H)

So that’s cool .. we have successfully derived the Y combinator as the lambda expression.

In the above we did the derivation of the fixed point combinator for any recursive function, which is a value. The subject of today’s post is the fixed point combinator for types. We will see how the above maps to the same model when applied at the type level.

Y for types

In this post I will model the type level fixpoint combinator in Scala showing that Scala’s type system has the power to express all the data type generic abstractions that Mark does using Haskell.

Here’s what we saw as fixed point for values (functions) ..

Y H = H (Y H)

The fixed point combinator for types is usually called Mu .. In Haskell we will say ..

data Mu f = In (f (Mu f))

Note the correspondence .. while Y takes functions and maps across types, Mu takes type constructors and maps across kinds ..

GHCi> :t Y
Y :: (a -> a) -> a
GHCi> :k Mu
Mu :: (* -> *) -> *

Now in Scala we model Mu as a case class that takes a type constructor as parameter. Like the paper we only consider unary type constructors - it’s not that difficult to generalize it to higher arities ..

case class Mu[F[_]](out: F[Mu[F]])

Just as the value level fixpoint Y lets us isolate a recursive function into a non recursive computation, Mu lets us do the same thing for types. Consider the following data type declaration for modeling Natural Numbers ..

trait Nat
case object Zero extends Nat
case class Succ(n: Nat) extends Nat

.. a recursive type .. let’s break the recursion by introducing an additional type parameter ..

trait NatF[+S]
case class Succ[+S](s: S) extends NatF[S]
case object Zero extends NatF[Nothing]

type Nat = Mu[NatF]

.. and modeling the actual type Nat as a fixpoint of NatF.

Mu is actually a functor fixpoint, which means that it works on functors (more specifically covariant functors). In our case we need to define a functor for NatF. No problem .. scalaz is there .. and here’s the functor instance for NatF ..

implicit object functorNatF extends Functor[NatF] {
  // apply f only at the recursive position of a single layer
  def fmap[A, B](r: NatF[A], f: A => B) = r match {
    case Zero => Zero
    case Succ(s) => Succ(f(s))
  }
}

And we define a couple of convenience functions for the zero natural number and the successor function ..

def zero: Nat = Mu[NatF](Zero)
def succ: Nat => Nat = (s: Nat) => Mu[NatF](Succ(s))
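
To make the layering concrete, here's a small self-contained sketch that builds the number two and counts it back down by hand, peeling one out per layer. The wrapper object NatByHand and the helper toIntByHand are hypothetical names for this illustration. The explicit recursion in toIntByHand is exactly what cata will abstract away later.

```scala
object NatByHand {
  // Type level fixpoint, as defined in the post.
  case class Mu[F[_]](out: F[Mu[F]])

  sealed trait NatF[+S]
  case object Zero extends NatF[Nothing]
  case class Succ[+S](s: S) extends NatF[S]

  type Nat = Mu[NatF]

  def zero: Nat = Mu[NatF](Zero)
  def succ: Nat => Nat = (s: Nat) => Mu[NatF](Succ(s))

  // Hypothetical helper: explicit recursion, unwrapping one `out` per layer.
  def toIntByHand(n: Nat): Int = n.out match {
    case Zero    => 0
    case Succ(p) => 1 + toIntByHand(p)
  }

  // Mu(Succ(Mu(Succ(Mu(Zero)))))
  val two: Nat = succ(succ(zero))
}
```

Here NatByHand.toIntByHand(NatByHand.two) gives 2.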

What we did with Mu

So far we defined a datatype as a fixpoint of a functor. Instead of making the datatype definition recursive we abstracted the recursion within the fixpoint combinator. And we did all of this as a general strategy that can be applied to many other recursive datatypes.

Let’s apply the same pattern to a List data type .. ok we use an IntList, a list of integers, following the paper.

trait IntListF[+S]
case object Nil extends IntListF[Nothing]
case class Cons[+S](x: Int, xs: S) extends IntListF[S]

// the integer list as a fixpoint of IntListF
type IntList = Mu[IntListF]

// convenience functions for the constructors
def nil = Mu[IntListF](Nil)
def cons = (x: Int) => (xs: IntList) => Mu[IntListF](Cons(x, xs))

.. and the functor instance ..

implicit object functorIntListF extends Functor[IntListF] {
  // the Int element is left alone, f is applied only to the tail position
  def fmap[A, B](r: IntListF[A], f: A => B) = r match {
    case Nil => Nil
    case Cons(n, x) => Cons(n, f(x))
  }
}

Note the similarity in structure for both the datatypes Nat and IntList and how the type constructors in both the cases IntListF and NatF determine the shape of the computation for the respective datatypes. This means that the datatype definition is parameterized by a type constructor that determines the shape of the data that will be modeled by the datatype.
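
A quick way to see what "shape" means is to apply fmap to a single layer whose recursive slot holds something other than a list. The sketch below is self-contained under a few assumptions: the Functor trait is a minimal stand-in with the same fmap signature used in this post, not the scalaz one, and ShapeDemo, layer and mapped are made-up names.

```scala
object ShapeDemo {
  // Minimal stand-in for the Functor type class (an assumption, not scalaz).
  trait Functor[F[_]] {
    def fmap[A, B](r: F[A], f: A => B): F[B]
  }

  sealed trait IntListF[+S]
  case object Nil extends IntListF[Nothing]
  case class Cons[+S](x: Int, xs: S) extends IntListF[S]

  object functorIntListF extends Functor[IntListF] {
    def fmap[A, B](r: IntListF[A], f: A => B): IntListF[B] = r match {
      case Nil        => Nil
      case Cons(n, x) => Cons(n, f(x))
    }
  }

  // One layer whose recursive slot holds a String instead of a list.
  val layer: IntListF[String] = Cons(1, "ab")

  // fmap touches only the slot: the element 1 stays, "ab" becomes its length.
  val mapped: IntListF[Int] =
    functorIntListF.fmap(layer, (s: String) => s.length)
}
```

ShapeDemo.mapped is Cons(1, 2). The functor knows where the recursive position is and nothing else, which is all cata needs from it.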

And now for the cata

Now with the above abstraction of Mu in place, we can define a combinator that can be shown to be more generalized than a fold .. it’s catamorphism, which we define as ..

def cata[A, F[_]](f: F[A] => A)(t: Mu[F])(implicit fc: Functor[F]): A = {
  // fold the children of the top layer first, then apply f to that layer
  f(fc.fmap(t.out, cata[A, F](f)))
}

For recursive datatypes the signature of fold varies with the datatype on which it's applied. cata is similar to a fold, just more generic: it is capable of defining all of the operations on the datatype, and the above definition works for any type constructor for which there's a matching functor instance.
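
For readers who want to try this without scalaz, here's a self-contained sketch of cata at work on Nat. The assumptions: Functor is a minimal stand-in with the fmap signature used in this post, the recursive call is written as an explicit lambda to keep type inference simple, and CataDemo, two and asInt are illustrative names.

```scala
object CataDemo {
  case class Mu[F[_]](out: F[Mu[F]])

  // Minimal stand-in for the Functor type class (an assumption, not scalaz).
  trait Functor[F[_]] {
    def fmap[A, B](r: F[A], f: A => B): F[B]
  }

  sealed trait NatF[+S]
  case object Zero extends NatF[Nothing]
  case class Succ[+S](s: S) extends NatF[S]
  type Nat = Mu[NatF]

  implicit object functorNatF extends Functor[NatF] {
    def fmap[A, B](r: NatF[A], f: A => B): NatF[B] = r match {
      case Zero    => Zero
      case Succ(s) => Succ(f(s))
    }
  }

  // Fold the children of the top layer first, then apply f to that layer.
  def cata[A, F[_]](f: F[A] => A)(t: Mu[F])(implicit fc: Functor[F]): A =
    f(fc.fmap(t.out, (m: Mu[F]) => cata[A, F](f)(m)))

  val two: Nat = Mu[NatF](Succ(Mu[NatF](Succ(Mu[NatF](Zero)))))

  // The function passed in only describes what to do with a single layer.
  val asInt: Int = cata[Int, NatF] {
    case Zero    => 0
    case Succ(n) => 1 + n
  }(two)
}
```

Unfolding by hand: the outer Succ gives 1 + (result for the inner Succ), which gives 1 + (result for Zero), which is 0, so CataDemo.asInt is 2.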

In a blog post Tony Morris discusses a catamorphism on an Option data type in Scala. He defines a cata which is specific for that data type. The above definition is at a higher level of abstraction and is parameterized on the shape of the data that it takes. The following snippets use the same cata to define functions for Nat as well as IntList. Datatype generic abstractions FTW.

Have a look at the following functions on Nat, all defined in terms of the cata combinator ..

def fromNat = cata[Int, NatF] {
  case Zero => 0
  case Succ(n) => 1 + n
} _ 

scala> fromNat(succ(succ(zero)))
res14: Int = 2

def addNat(m: Nat, n: Nat) = cata[Nat, NatF] {
  case Zero => m
  case Succ(x) => succ(x)
} (n)

scala> fromNat(addNat(succ(zero), succ(zero)))
res15: Int = 2
scala> fromNat(addNat(succ(zero), succ(succ(zero))))
res16: Int = 3

def mulNat(m: Nat, n: Nat) = cata[Nat, NatF] {
  case Zero => zero
  case Succ(x) => addNat(m, x)
} (n)

scala> fromNat(mulNat(succ(succ(succ(zero))), succ(succ(zero))))
res1: Int = 6
scala> fromNat(mulNat(succ(succ(succ(zero))), zero))
res2: Int = 0

def expNat(m: Nat, n: Nat) = cata[Nat, NatF] {
  case Zero => succ(zero)
  case Succ(x) => mulNat(m, x)
} (n)

scala> fromNat(expNat(succ(succ(succ(zero))), zero))
res0: Int = 1
scala> fromNat(expNat(succ(succ(succ(zero))), succ(succ(zero))))
res1: Int = 9

.. and some more on IntList using the same cata ..

def sumList = cata[Int, IntListF] {
  case Nil => 0
  case Cons(x, n) => x + n
} _

scala> sumList(cons(1)(cons(2)(cons(3)(nil))))
res1: Int = 6

def len = cata[Nat, IntListF] {
  case Nil => zero
  case Cons(x, xs) => succ(xs)
} _

scala> fromNat(len(cons(1)(cons(2)(cons(3)(nil)))))
res1: Int = 3
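
As one more data point, the same cata can also produce results that are neither Int nor Nat. The sketch below is self-contained and makes the same assumptions as before: a minimal stand-in Functor rather than the scalaz one, and an explicit lambda for the recursive call. The functions toScalaList and productList and the value sample are hypothetical additions, not part of the paper's examples.

```scala
object ListCataDemo {
  case class Mu[F[_]](out: F[Mu[F]])

  // Minimal stand-in for the Functor type class (an assumption, not scalaz).
  trait Functor[F[_]] {
    def fmap[A, B](r: F[A], f: A => B): F[B]
  }

  sealed trait IntListF[+S]
  case object Nil extends IntListF[Nothing]
  case class Cons[+S](x: Int, xs: S) extends IntListF[S]
  type IntList = Mu[IntListF]

  implicit object functorIntListF extends Functor[IntListF] {
    def fmap[A, B](r: IntListF[A], f: A => B): IntListF[B] = r match {
      case Nil        => Nil
      case Cons(n, x) => Cons(n, f(x))
    }
  }

  def cata[A, F[_]](f: F[A] => A)(t: Mu[F])(implicit fc: Functor[F]): A =
    f(fc.fmap(t.out, (m: Mu[F]) => cata[A, F](f)(m)))

  def nil: IntList = Mu[IntListF](Nil)
  def cons = (x: Int) => (xs: IntList) => Mu[IntListF](Cons(x, xs))

  // Hypothetical: rebuild a standard Scala List, one layer at a time.
  def toScalaList(l: IntList): List[Int] = cata[List[Int], IntListF] {
    case Nil         => List.empty[Int]
    case Cons(x, xs) => x :: xs
  }(l)

  // Hypothetical: multiply the elements, with 1 for the empty list.
  def productList(l: IntList): Int = cata[Int, IntListF] {
    case Nil        => 1
    case Cons(x, p) => x * p
  }(l)

  val sample: IntList = cons(1)(cons(2)(cons(3)(nil)))
}
```

ListCataDemo.toScalaList(ListCataDemo.sample) gives List(1, 2, 3) and ListCataDemo.productList(ListCataDemo.sample) gives 6. Only the per-layer function changes between the two .. the traversal is the same cata throughout.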