# Haskell GADTs in Scala

This is an updated version of an earlier post. Owing to a comment by Jed Wesley-Smith I restructured this post somewhat to introduce two techniques for programming with GADTs in Scala. Thanks also go to Tony Morris.

We’ll start with a fairly canonical example of why GADTs are useful in Haskell.

```
{-# LANGUAGE GADTs #-}
module Exp where

data Exp a where
  LitInt  :: Int -> Exp Int
  LitBool :: Bool -> Exp Bool
  Add     :: Exp Int -> Exp Int -> Exp Int
  Mul     :: Exp Int -> Exp Int -> Exp Int
  Cond    :: Exp Bool -> Exp a -> Exp a -> Exp a
  EqE     :: Eq a => Exp a -> Exp a -> Exp Bool

eval :: Exp a -> a
eval e = case e of
  LitInt i       -> i
  LitBool b      -> b
  Add e e'       -> eval e + eval e'
  Mul e e'       -> eval e * eval e'
  Cond b thn els -> if eval b then eval thn else eval els
  EqE e e'       -> eval e == eval e'
```

Here we have defined a data structure that represents the abstract syntax tree (AST) of a very simple arithmetic language. Notice that it ensures terms are well-typed. For instance something like the following just doesn’t type check.

``LitInt 1 `Add` LitBool True -- this expression does not type check``

I have also provided a function `eval` that evaluates terms in this language.

In Scala it is quite possible to define data structures which have the same properties as a GADT declaration in Haskell. You can do this with *case classes* as follows.

```
abstract class Exp[A]
case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```
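As a quick check that these case classes give the same guarantee as the Haskell declaration, here is a minimal sketch. The wrapper object `TypedTerms` and the value names `sum` and `choice` are assumptions made for illustration, and only a subset of the constructors is repeated to keep it short.

```scala
object TypedTerms {
  abstract class Exp[A]
  case class LitInt(i: Int) extends Exp[Int]
  case class LitBool(b: Boolean) extends Exp[Boolean]
  case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
  case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]

  // Well-typed: both operands of Add are Exp[Int].
  val sum: Exp[Int] = Add(LitInt(1), LitInt(2))

  // Well-typed: A is fixed to Int by the two branches.
  val choice: Exp[Int] = Cond(LitBool(true), sum, LitInt(0))

  // Ill-typed: uncommenting the next line should be rejected by the compiler,
  // because LitBool(true) is an Exp[Boolean], not an Exp[Int].
  // val bad = Add(LitInt(1), LitBool(true))
}
```

The commented-out line is the Scala counterpart of the Haskell expression that does not type check.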

But how do we implement `eval`? You might think that the following code would work. I mean, it looks like the Haskell version, right?

```
abstract class Exp[A] {
  def eval = this match {
    case LitInt(i) => i
    case LitBool(b) => b
    case Add(e1, e2) => e1.eval + e2.eval
    case Mul(e1, e2) => e1.eval * e2.eval
    case Cond(b, thn, els) => if (b.eval) { thn.eval } else { els.eval }
    case Eq(e1, e2) => e1.eval == e2.eval
  }
}
case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```

Unfortunately for us, this doesn’t work. The Scala compiler is unable to instantiate the type `Exp[A]` to more specific ones (such as `LitInt`, which extends `Exp[Int]`):

```
3: constructor cannot be instantiated to expected type;
found : FailedExp.LitInt
required: FailedExp.Exp[A]
case LitInt(i) => i
^
```

There are two solutions to this problem.

# Solution 1: The object-oriented way

You must write `eval` the object-oriented way. The definition of `eval` gets spread over each of the sub-classes of `Exp[A]`.

```
abstract class Exp[A] {
  def eval: A
}
case class LitInt(i: Int) extends Exp[Int] {
  def eval = i
}
case class LitBool(b: Boolean) extends Exp[Boolean] {
  def eval = b
}
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval + e2.eval
}
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
  def eval = e1.eval * e2.eval
}
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A] {
  def eval = if (b.eval) { thn.eval } else { els.eval }
}
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean] {
  def eval = e1.eval == e2.eval
}
```
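To see the object-oriented version in action, here is a small usage sketch. The wrapper object `OoEval` and the names `term` and `result` are assumptions for illustration, and only the constructors needed by the example term are repeated.

```scala
object OoEval {
  abstract class Exp[A] { def eval: A }
  case class LitInt(i: Int) extends Exp[Int] { def eval = i }
  case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int] {
    def eval = e1.eval + e2.eval
  }
  case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A] {
    def eval = if (b.eval) thn.eval else els.eval
  }
  case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean] {
    def eval = e1.eval == e2.eval
  }

  // Encodes: if (1 + 2 == 3) 10 else 20
  val term: Exp[Int] =
    Cond(Eq(Add(LitInt(1), LitInt(2)), LitInt(3)), LitInt(10), LitInt(20))

  // Each node dispatches to its own eval, so the result is statically an Int.
  val result: Int = term.eval // 10
}
```

No pattern matching is involved here: every sub-class already knows its own type argument, so each `eval` body type checks on its own.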

# Solution 2: The functional Haskell-like way

Personally I don’t like the OO style as much as the Haskell-like style. However, it turns out that you can program in that style by using a companion object.

```
object Exp {
  def evalAny[A](e: Exp[A]): A = e match {
    case LitInt(i) => i
    case LitBool(b) => b
    case Add(e1, e2) => e1.eval + e2.eval
    case Mul(e1, e2) => e1.eval * e2.eval
    case Cond(b, thn, els) => if (b.eval) { thn.eval } else { els.eval }
    case Eq(e1, e2) => e1.eval == e2.eval
  }
}
abstract class Exp[A] {
  def eval: A = Exp.evalAny(this)
}
case class LitInt(i: Int) extends Exp[Int]
case class LitBool(b: Boolean) extends Exp[Boolean]
case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Mul(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]
```

Ah, much better. But why does this work when the first attempt doesn’t? The problem with the first attempt is that the constructors are not polymorphic. In Haskell-speak the type of `LitInt` is:

`LitInt :: Int -> Exp Int`

not

`LitInt :: Int -> Exp a`

The second solution is subtly different. Method `evalAny` is polymorphic, but its type parameter is instantiated to match the value it is called on. For instance, `evalAny` called on `LitInt(42)` equates the type variable `A` with `Int`. The compiler can then correctly deduce that it takes a value of type `Exp[Int]` and produces a value of type `Int`.
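The instantiation described above can be observed at the call site. The following is a minimal sketch, assuming the compiler refines `A` in each case as the post describes. The wrapper object `GadtEval` and the value names are hypothetical, the recursion calls `evalAny` directly instead of going through `eval`, and `Mul` is left out for brevity.

```scala
object GadtEval {
  abstract class Exp[A]
  case class LitInt(i: Int) extends Exp[Int]
  case class LitBool(b: Boolean) extends Exp[Boolean]
  case class Add(e1: Exp[Int], e2: Exp[Int]) extends Exp[Int]
  case class Cond[A](b: Exp[Boolean], thn: Exp[A], els: Exp[A]) extends Exp[A]
  case class Eq[A](e1: Exp[A], e2: Exp[A]) extends Exp[Boolean]

  // A belongs to the method, so each case can equate it with a concrete type.
  def evalAny[A](e: Exp[A]): A = e match {
    case LitInt(i)         => i
    case LitBool(b)        => b
    case Add(e1, e2)       => evalAny(e1) + evalAny(e2)
    case Cond(b, thn, els) => if (evalAny(b)) evalAny(thn) else evalAny(els)
    case Eq(e1, e2)        => evalAny(e1) == evalAny(e2)
  }

  val intTerm: Exp[Int] = LitInt(42)
  val eqTerm: Exp[Boolean] = Eq(Add(LitInt(1), LitInt(1)), LitInt(2))

  val n: Int = evalAny(intTerm)        // A is instantiated to Int here
  val same: Boolean = evalAny(eqTerm)  // A is instantiated to Boolean here
}
```

The type annotations on `n` and `same` are only there to make the instantiation visible; the same calls type check without them.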