# Haskell GADTs in Scala

This is an updated version of an earlier post. Owing to a comment by Jed Wesley-Smith, I restructured this post somewhat to introduce two techniques for programming with GADTs in Scala. Thanks also go to Tony Morris.

First we’ll start with a fairly canonical example of why GADTs are useful in Haskell.

```haskell
{-# LANGUAGE GADTs #-}

module Exp where

data Exp a where
  LitInt  :: Int -> Exp Int
  LitBool :: Bool -> Exp Bool
  Add     :: Exp Int -> Exp Int -> Exp Int
  Mul     :: Exp Int -> Exp Int -> Exp Int
  Cond    :: Exp Bool -> Exp a -> Exp a -> Exp a
  EqE     :: Eq a => Exp a -> Exp a -> Exp Bool

eval :: Exp a -> a
eval e = case e of
  LitInt i       -> i
  LitBool b      -> b
  Add e e'       -> eval e + eval e'
  Mul e e'       -> eval e * eval e'
  Cond b thn els -> if eval b then eval thn else eval els
  EqE e e'       -> eval e == eval e'
```

Here we have defined a data structure that represents the abstract syntax tree (AST) of a very simple arithmetic language. Notice that it ensures terms are well-typed. For instance something like the following just doesn’t type check.

```haskell
LitInt 1 `Add` LitBool True -- this expression does not type check
```

I have also provided a function `eval` that evaluates terms in this language.

In Scala it is quite possible to define data structures which have the same properties as a GADT declaration in Haskell. You can do this with *case classes* as follows.

```scala
abstract class Exp[A]

case class LitInt(i: Int) extends Exp[Int]

case class LitBool(b: Boolean) extends Exp[Boolean]

case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]

case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```
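As a rough sketch of how these case classes rule out ill-typed terms, the snippet below repeats the declarations and builds a few expressions. The value names (`sum`, `isThree`, `pick`) are made up for illustration, and the commented-out line is the Scala counterpart of the Haskell expression that fails to type check.

```scala
abstract class Exp[A]
case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]

// Well-typed terms: the type argument of Exp tracks what each term denotes.
val sum: Exp[Int] = Add(LitInt(1), LitInt(2))
val isThree: Exp[Boolean] = Eq(sum, LitInt(3))
val pick: Exp[Int] = Cond(isThree, Mul(sum, LitInt(2)), LitInt(0))

// Rejected by the compiler: Add expects Exp[Int], but LitBool is an Exp[Boolean].
// val bad = Add(LitInt(1), LitBool(true))
```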

But how do we implement `eval`? You might think that the following code would work. I mean, it looks like the Haskell version, right?

```scala
abstract class Exp[A] {
  def eval = this match {
    case LitInt(i) => i
    case LitBool(b) => b
    case Add(e1, e2) => e1.eval + e2.eval
    case Mul(e1, e2) => e1.eval * e2.eval
    case Cond(b,thn,els) => if ( b.eval ) { thn.eval } else { els.eval }
    case Eq(e1,e2) => e1.eval == e2.eval
  }
}

case class LitInt(i: Int) extends Exp[Int]

case class LitBool(b: Boolean) extends Exp[Boolean]

case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]

case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```

Unfortunately for us, this doesn’t work. The Scala compiler is unable to instantiate the type `Exp[A]` to more specific ones (such as `LitInt`, which extends `Exp[Int]`):

```
3: constructor cannot be instantiated to expected type;
found : FailedExp.LitInt
required: FailedExp.Exp[A]
case LitInt(i) => i
^
```

There are two solutions to this problem.

# Solution 1: The object-oriented way

The first is to write `eval` the object-oriented way. The definition of `eval` gets spread over each of the subclasses of `Exp[A]`.

```scala
abstract class Exp[A] {
  def eval: A
}

case class LitInt(i: Int) extends Exp[Int] {
  def eval = i
}

case class LitBool(b: Boolean) extends Exp[Boolean] {
  def eval = b
}

case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval + e2.eval
}

case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval * e2.eval
}

case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A] {
  def eval = if ( b.eval ) { thn.eval } else { els.eval }
}

case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean] {
  def eval = e1.eval == e2.eval
}
```
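For a quick check of the object-oriented version, here is a condensed sketch of the same hierarchy with the method bodies written on one line, followed by a sample term. The names `program` and `result` are hypothetical; the point is only that `eval` on an `Exp[Int]` is statically known to return an `Int`.

```scala
abstract class Exp[A] { def eval: A }
case class LitInt(i: Int) extends Exp[Int] { def eval = i }
case class LitBool(b: Boolean) extends Exp[Boolean] { def eval = b }
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] { def eval = e1.eval + e2.eval }
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] { def eval = e1.eval * e2.eval }
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A] {
  def eval = if (b.eval) thn.eval else els.eval
}
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean] { def eval = e1.eval == e2.eval }

// if (1 + 2 == 3) 6 * 7 else 0
val program: Exp[Int] =
  Cond(Eq(Add(LitInt(1), LitInt(2)), LitInt(3)), Mul(LitInt(6), LitInt(7)), LitInt(0))

// No cast is needed: eval on an Exp[Int] has static type Int.
val result: Int = program.eval
```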

# Solution 2: The functional Haskell-like way

Personally I don’t like the OO style as much as the Haskell-like style. However, it turns out that you can program in that style by using a companion object.

```scala
object Exp {
  def evalAny[A](e: Exp[A]): A = e match {
    case LitInt(i) => i
    case LitBool(b) => b
    case Add(e1, e2) => e1.eval + e2.eval
    case Mul(e1, e2) => e1.eval * e2.eval
    case Cond(b, thn, els) => if (b.eval) { thn.eval } else { els.eval }
    case Eq(e1, e2) => e1.eval == e2.eval
  }
}

abstract class Exp[A] {
  def eval: A = Exp.evalAny(this)
}

case class LitInt(i: Int) extends Exp[Int]

case class LitBool(b: Boolean) extends Exp[Boolean]

case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]

case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]

case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```

Ah, much better. But why does this work when the previous style doesn’t? The problem is that the constructors are not polymorphic. In Haskell-speak the type is:

`LitInt :: Int -> Exp Int`

not

`LitInt :: Int -> Exp a`

The second solution is subtly different. Method `evalAny` is polymorphic, but its type is instantiated to that of the value of whatever it is called on. For instance, `evalAny` when called on `LitInt(42)` equates type variable `A` with `Int`. It can then correctly deduce that it does indeed take a value of `Exp[Int]` and produce a value of `Int`.
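To make that instantiation visible, here is a minimal sketch, trimmed to four constructors. As a simplification it calls `evalAny` recursively instead of going through an `eval` method, so it is a variation on the code above rather than a copy of it. The type ascriptions on the hypothetical values `n`, `flag` and `chosen` show which type `A` is equated with at each call site.

```scala
abstract class Exp[A]
case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]

object Exp {
  // A is a type parameter of the method, so each case can fix it separately.
  def evalAny[A](e: Exp[A]): A = e match {
    case LitInt(i)         => i                           // A is Int in this case
    case LitBool(b)        => b                           // A is Boolean in this case
    case Add(e1, e2)       => evalAny(e1) + evalAny(e2)   // A is Int in this case
    case Cond(b, thn, els) => if (evalAny(b)) evalAny(thn) else evalAny(els)
  }
}

val n: Int = Exp.evalAny(Add(LitInt(40), LitInt(2)))        // A = Int
val flag: Boolean = Exp.evalAny(LitBool(true))              // A = Boolean
val chosen: Int = Exp.evalAny(Cond(LitBool(flag), LitInt(1), LitInt(0)))
```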