Lambdalog

Sean Seefried's programming blog

22 Nov 2011

Haskell GADTs in Scala

This is an updated version of an earlier post. Owing to a comment by Jed Wesley-Smith, I have restructured this post to introduce two techniques for programming with GADTs in Scala. Thanks also go to Tony Morris.

First we’ll start with a fairly canonical example of why GADTs are useful in Haskell.

{-# LANGUAGE GADTs #-}
module Exp where

data Exp a where
  LitInt  :: Int -> Exp Int
  LitBool :: Bool -> Exp Bool
  Add     :: Exp Int -> Exp Int -> Exp Int
  Mul     :: Exp Int -> Exp Int -> Exp Int
  Cond    :: Exp Bool -> Exp a -> Exp a -> Exp a
  EqE     :: Eq a => Exp a -> Exp a -> Exp Bool

eval :: Exp a -> a
eval e = case e of
  LitInt i       -> i
  LitBool b      -> b
  Add e e'       -> eval e + eval e'
  Mul e e'       -> eval e * eval e'
  Cond b thn els -> if eval b then eval thn else eval els
  EqE e e'       -> eval e == eval e'

Here we have defined a data structure that represents the abstract syntax tree (AST) of a very simple arithmetic language. Notice that it ensures terms are well-typed. For instance, something like the following just doesn't type check.

LitInt 1 `Add` LitBool True -- this expression does not type check

I have also provided a function eval that evaluates terms in this language.
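To make the behaviour of eval concrete, here is a small self-contained sketch that repeats the declarations above (without the module header) and adds a few sample terms. The names arith, choice and isEqual are hypothetical examples, not part of the original module. Note that eval returns an Int for the two Exp Int terms and a Bool for the Exp Bool term: within each case alternative, GHC refines the type variable a to the result index of the matched constructor.

```haskell
{-# LANGUAGE GADTs #-}

-- The Exp GADT and eval, as in the module above.
data Exp a where
  LitInt  :: Int -> Exp Int
  LitBool :: Bool -> Exp Bool
  Add     :: Exp Int -> Exp Int -> Exp Int
  Mul     :: Exp Int -> Exp Int -> Exp Int
  Cond    :: Exp Bool -> Exp a -> Exp a -> Exp a
  EqE     :: Eq a => Exp a -> Exp a -> Exp Bool

eval :: Exp a -> a
eval e = case e of
  LitInt i       -> i      -- in this branch a is known to be Int
  LitBool b      -> b      -- in this branch a is known to be Bool
  Add e e'       -> eval e + eval e'
  Mul e e'       -> eval e * eval e'
  Cond b thn els -> if eval b then eval thn else eval els
  EqE e e'       -> eval e == eval e'

-- Hypothetical sample terms.
-- (2 + 3) * 4 is an Exp Int, so eval yields an Int.
arith :: Exp Int
arith = Mul (Add (LitInt 2) (LitInt 3)) (LitInt 4)

-- if arith == 20 then 1 else 0
choice :: Exp Int
choice = Cond (EqE arith (LitInt 20)) (LitInt 1) (LitInt 0)

-- True == False is an Exp Bool, so eval yields a Bool.
isEqual :: Exp Bool
isEqual = EqE (LitBool True) (LitBool False)
```

Loaded into GHCi, eval arith gives 20, eval choice gives 1 and eval isEqual gives False.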

In Scala it is quite possible to define data structures which have the same properties as a GADT declaration in Haskell. You can do this with case classes as follows.

abstract class Exp[A] 

case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]

But how do we implement eval? You might think that the following code would work. I mean, it looks like the Haskell version, right?

abstract class Exp[A] {
  def eval = this match {
    case LitInt(i)       => i
    case LitBool(b)      => b
    case Add(e1, e2)     => e1.eval + e2.eval
    case Mul(e1, e2)     => e1.eval * e2.eval
    case Cond(b,thn,els) => if ( b.eval ) { thn.eval } else { els.eval }
    case Eq(e1,e2)       => e1.eval == e2.eval
  }
}

case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]

Unfortunately for us, this doesn't work. The Scala compiler is unable to instantiate the type Exp[A] to a more specific one (such as the type of LitInt, which extends Exp[Int]):

3: constructor cannot be instantiated to expected type;
  found   : FailedExp.LitInt
  required: FailedExp.Exp[A]
    case LitInt(i)       => i
        ^

There are two solutions to this problem.

Solution 1: The object-oriented way

The first solution is to write eval the object-oriented way, in which case its definition is spread over each of the subclasses of Exp[A].

abstract class Exp[A] {
  def eval: A
}

case class LitInt(i: Int) extends Exp[Int] {
  def eval = i
}

case class LitBool(b: Boolean) extends Exp[Boolean] {
  def eval = b
}

case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval + e2.eval
}
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval * e2.eval
}
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A] {
  def eval = if ( b.eval ) { thn.eval } else { els.eval }
}
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean] {
  def eval = e1.eval == e2.eval
}

Solution 2: The functional Haskell-like way

Personally, I don't like the OO style as much as the Haskell-like style. Fortunately, it turns out that you can program in that style by moving the pattern match into a polymorphic method, evalAny, on the companion object of Exp.

object Exp {
  def evalAny[A](e: Exp[A]): A = e match {
    case LitInt(i)         => i
    case LitBool(b)        => b
    case Add(e1, e2)       => e1.eval + e2.eval
    case Mul(e1, e2)       => e1.eval * e2.eval
    case Cond(b, thn, els) => if (b.eval) { thn.eval } else { els.eval }
    case Eq(e1, e2)        => e1.eval == e2.eval
  }
}

abstract class Exp[A] {
  def eval: A = Exp.evalAny(this)
}

case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]

Ah, much better. But why does this work when the first attempt doesn't? The problem is that the constructors are not polymorphic. In Haskell-speak, the type of LitInt is:

LitInt :: Int -> Exp Int

not

LitInt :: Int -> Exp a

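The second solution is subtly different. Method evalAny is polymorphic in its own type parameter A, and within each case of the match the compiler can equate A with the type index of the matched constructor. For instance, in the LitInt case (as when evalAny is called on LitInt(42)) type variable A is equated with Int, so the compiler can correctly deduce that evalAny does indeed take a value of type Exp[Int] and produce a value of type Int. In the failed attempt, by contrast, A was the type parameter of class Exp[A] itself, and the compiler did not refine it when matching on LitInt.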
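Haskell offers one more way to see why a non-polymorphic result type matters. A constructor whose result type is Exp Int can be written equivalently as a polymorphic constructor that carries a type-equality constraint, for example LitIntEq :: (a ~ Int) => Int -> ExpEq a. Matching on such a constructor brings a ~ Int into scope for that branch only, which is, loosely, the per-case refinement described above for evalAny. The sketch below assumes GHC with the GADTs and TypeOperators extensions; the type ExpEq and its constructors are hypothetical and trimmed to three cases.

```haskell
{-# LANGUAGE GADTs, TypeOperators #-}

-- Hypothetical, trimmed variant of Exp: each refined result type is
-- replaced by a polymorphic result type plus an equality constraint.
data ExpEq a where
  LitIntEq  :: (a ~ Int)  => Int -> ExpEq a
  LitBoolEq :: (a ~ Bool) => Bool -> ExpEq a
  AddEq     :: (a ~ Int)  => ExpEq Int -> ExpEq Int -> ExpEq a

-- Same shape as eval: each branch may use the equality its constructor carries.
evalEq :: ExpEq a -> a
evalEq e = case e of
  LitIntEq i  -> i                    -- a ~ Int is in scope here
  LitBoolEq b -> b                    -- a ~ Bool is in scope here
  AddEq x y   -> evalEq x + evalEq y  -- an Int, which is an a here
```

Loaded into GHCi, evalEq (AddEq (LitIntEq 2) (LitIntEq 40)) gives 42 and evalEq (LitBoolEq True) gives True.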

Tagged as: Haskell, Scala, GADTs.

16 May 2011

Reifying Type Classes with GADTs

This year I’ve been contributing to Accelerate, an embedded language for regular, multi-dimensional array computations targeting high-performance back-ends such as NVIDIA GPUs. Even a cursory glance at the code base shows that it uses a number of advanced Haskell extensions, including Generalized Algebraic Data Types (GADTs) and type families.

Earlier this year I learned a technique which seems to be part of the Haskell folklore but is not necessarily well known: reifying the types belonging to a type class as values of a GADT. A typical use case for this technique involves a type class defined in a library. Perhaps you don’t have access to the source code (or, for reasons of modularity or API stability, you don’t want to touch it) but you nevertheless want to add a method to the type class. For argument’s sake let’s call this type class C.

The standard solution would be to declare a new type class, let’s call it D, with a super-class constraint on class C (e.g. class C a => D a). You would then have to write one instance of class D for each existing instance of class C. Not only could this be a lot of work, but you could also run into the same problem of needing to extend class D further down the track (perhaps necessitating the declaration of class E).
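A minimal sketch of this standard solution, using hypothetical stand-ins: C plays the role of the library class (with an invented method size), and D is the new class with C as its superclass (with an invented method describe). Even with only two instances of C, each one has to be mirrored by a hand-written instance of D.

```haskell
-- Hypothetical library class C, which we cannot (or prefer not to) edit.
class C a where
  size :: a -> Int

instance C Int where
  size _ = 1

instance C a => C [a] where
  size xs = sum (map size xs)

-- The new class D, with C as a superclass, carries the extra method.
class C a => D a where
  describe :: a -> String

-- One instance of D is needed for each existing instance of C.
instance D Int where
  describe n = "Int of size " ++ show (size n)

instance D a => D [a] where
  describe xs = "list of size " ++ show (size xs)
```

Loaded into GHCi, describe (3 :: Int) gives "Int of size 1" and describe [1, 2, 3 :: Int] gives "list of size 3".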

What if there were a way to allow class C to be arbitrarily extensible? This being a blog post on the topic, it should come as no surprise that there is, with one caveat: the class must be closed. That is, the instances that have been written are complete and will not need to be changed in the future.

I’ll demonstrate how to reify types into values of a GADT using an example. Say we have a class, Counter, with a single method inc to increment a value.

{-# LANGUAGE GADTs #-}

class Counter a where
  inc :: a -> a

Here are all the instances we will ever write for this class.

import Data.Char (chr, ord)

class Counter a where
  inc :: a -> a

instance Counter Int where
  inc a = a + 1

instance Counter Char where
  inc a = chr (ord a + 1)

instance Counter a => Counter [a] where
  inc as = map inc as

instance (Counter a, Counter b) => Counter (a,b) where
  inc (a,b) = (inc a, inc b)

Although the instances are fixed, we wish to allow others to effectively add new methods to the class in the future. We can do this by declaring a GADT that reifies the instance types. By convention we name this data structure by appending the character R to the class name.

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b)

Please note the following properties of this declaration:

  • There is one constructor for each instance of class Counter.

  • Every constraint on the Counter class in an instance declaration becomes a parameter of the corresponding GADT constructor. e.g. instance (Counter a, Counter b) => Counter (a,b) becomes CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b).

We now need to add a new method signature to class Counter and provide an implementation for it in each of the instances. By convention this method is called counterR; that is, the name of the class with the first letter lower-cased and suffixed with the character ‘R’.

Our module so far looks like this:

{-# LANGUAGE GADTs, FlexibleInstances, UndecidableInstances #-}
module Counter where

import Data.Char

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b)

class Counter a where
  inc      :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a    = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a    = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc as   = map inc as
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a,b) where
  inc (a,b) = (inc a, inc b)
  counterR  = CounterRpair counterR counterR

Note that the definition of counterR for each instance follows a very simple pattern: it is just the appropriate constructor applied to the appropriate number of calls to counterR.

It’s important to be clear that values of CounterR a don’t represent values of type a, they represent the type a itself.

For instance, the type (([Int], Int), (Int, Char)) is reified as the following value of type

CounterR (([Int], Int), (Int, Char))

namely

CounterRpair
  (CounterRpair (CounterRlist CounterRint) CounterRint)
  (CounterRpair CounterRint CounterRchar)
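To check that counterR really builds this value, here is a self-contained sketch that repeats the Counter module (without the module header) and adds a hypothetical function showR, which renders a type representation as Haskell source. Like the dec function that follows, showR is an ordinary function defined by cases on CounterR. Asking for counterR at the type (([Int], Int), (Int, Char)) lets instance resolution assemble the representation.

```haskell
{-# LANGUAGE GADTs #-}

import Data.Char (chr, ord)

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a, b)

class Counter a where
  inc      :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a    = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a    = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc xs   = map inc xs
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a, b) where
  inc (a, b) = (inc a, inc b)
  counterR   = CounterRpair counterR counterR

-- Hypothetical helper: render a type representation as source text.
showR :: CounterR a -> String
showR CounterRint        = "CounterRint"
showR CounterRchar       = "CounterRchar"
showR (CounterRlist r)   = "(CounterRlist " ++ showR r ++ ")"
showR (CounterRpair r s) = "(CounterRpair " ++ showR r ++ " " ++ showR s ++ ")"

-- The representation is assembled by instance resolution from the type alone.
example :: CounterR (([Int], Int), (Int, Char))
example = counterR
```

Loaded into GHCi, showR example evaluates to "(CounterRpair (CounterRpair (CounterRlist CounterRint) CounterRint) (CounterRpair CounterRint CounterRchar))", which is the value above wrapped in one extra pair of parentheses.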

We can now effectively write new methods on any type that is a member of class Counter. I’ll show you with an example. Say we now want to write a function that decrements instead of incrementing. We can write it as follows, but please note that the function takes both the representation of the type and the value as arguments.

dec :: CounterR a -> a -> a
dec CounterRint a = a - 1
dec CounterRchar a = chr (ord a - 1)
dec (CounterRlist incRa) as = map (dec incRa) as
dec (CounterRpair incRa incRb) (a,b) = (dec incRa a, dec incRb b)

There are only a few things to note:

  • If dec were a method of class Counter then you would need to define it in each instance. Here you provide each implementation as a case of the dec function.
  • Since dec requires a value of type CounterR, each recursive call to dec requires that you pass an appropriate value of type CounterR. Fortunately, the appropriate value is precisely one of those pattern matched on the left-hand side.
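Because counterR is itself a method of Counter, the class can also supply the top-level representation for us. Below is a self-contained sketch that repeats the module (without the module header), adds dec, and defines a hypothetical wrapper decC :: Counter a => a -> a, which is not part of the original post, so that callers can use dec much like an ordinary class method.

```haskell
{-# LANGUAGE GADTs #-}

import Data.Char (chr, ord)

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a, b)

class Counter a where
  inc      :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a    = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a    = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc xs   = map inc xs
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a, b) where
  inc (a, b) = (inc a, inc b)
  counterR   = CounterRpair counterR counterR

-- dec as in the post: a new "method" defined by cases on the representation.
dec :: CounterR a -> a -> a
dec CounterRint        a      = a - 1
dec CounterRchar       a      = chr (ord a - 1)
dec (CounterRlist r)   xs     = map (dec r) xs
dec (CounterRpair r s) (a, b) = (dec r a, dec s b)

-- Hypothetical wrapper: the Counter instance supplies the representation.
decC :: Counter a => a -> a
decC = dec counterR
```

For example, decC (([1, 2], 3), (5, 'b')) at type (([Int], Int), (Int, Char)) evaluates to (([0, 1], 2), (4, 'a')), and decC (inc 'x') gives back 'x'.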

As I noted at the beginning, this technique is not new, just relatively unknown. A paper by Oliveira and Sulzmann suggests unifying the two concepts (type classes and GADTs). Also, GADTs have been used in JHC as an implementation technique for type classes, instead of the usual dictionary-passing implementation.

Tagged as: GADTs, Haskell, type classes.