16 May 2011

# Reifying Type Classes with GADTs

This year I’ve been contributing to Accelerate, an embedded language for regular, multi-dimensional array computations targeting high-performance back-ends such as NVIDIA GPUs. Even a cursory glance at the code base shows that it uses a number of advanced Haskell extensions, including Generalized Algebraic Data Types (GADTs) and type families.

Earlier this year I learned a technique which seems to be part of the Haskell folklore but is not necessarily well known: reifying the types belonging to a type class into values of a GADT. A typical use case is a type class defined in a library. Either you don’t have access to the source code, or for reasons of modularity or API stability you don’t want to touch it, but you nevertheless want to add a method to the type class. For argument’s sake, let’s call this type class C.

The standard solution would be to declare a new type class, let’s call it D, with a superclass constraint on class C (e.g. class C a => D a). You would then have to write one instance of class D for each existing instance of class C. Not only could this be a lot of work, but you could also run into the same problem of needing to extend class D further down the track (perhaps necessitating the declaration of class E).
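To make the cost of the superclass approach concrete, here is a minimal sketch. The method names next and prev and the two instances are assumptions chosen for illustration; only the class names C and D come from the text above.

```haskell
import Data.Char (chr, ord)

-- Hypothetical library class standing in for "C"; assume we cannot edit it.
class C a where
  next :: a -> a

instance C Int where
  next a = a + 1

instance C Char where
  next a = chr (ord a + 1)

-- Our extension class "D": one new method, with C as a superclass.
class C a => D a where
  prev :: a -> a

-- One D instance has to be written for every existing C instance.
instance D Int where
  prev a = a - 1

instance D Char where
  prev a = chr (ord a - 1)
```

Every further method added later would need yet another class and another full set of instances.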

What if there were a way to allow class C to be arbitrarily extensible? This being a blog post on the topic, it should come as no surprise that there is, with one caveat: the class must be closed. That is, the instances that have been written are complete and will not need to be changed in the future.

I’ll demonstrate how to reify types into values of a GADT using an example. Say we have a class, Counter, with a single method inc to increment a value.

{-# LANGUAGE GADTs #-}

class Counter a where
  inc :: a -> a

Here are all the instances we will ever write for this class.

instance Counter Int where
  inc a = a + 1

instance Counter Char where
  inc a = chr (ord a + 1)

instance Counter a => Counter [a] where
  inc as = map inc as

instance (Counter a, Counter b) => Counter (a,b) where
  inc (a,b) = (inc a, inc b)

Although the instances are fixed, we wish to allow others to effectively add new methods to the class in the future. We can do this by declaring a GADT that reifies the instance types. By convention we name this data structure by appending the character R to the class name.

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b)

Please note the following properties of this declaration:

• There is one constructor for each instance of class Counter.

• Every constraint on the Counter class in an instance declaration becomes a parameter of the corresponding GADT constructor. For example, instance (Counter a, Counter b) => Counter (a,b) becomes CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b).

We now need to add a new method signature to class Counter and provide an implementation for it in each of the instances. By convention this method is called counterR; that is, the name of the class with the first letter lower-cased and suffixed with character ‘R’.

Our module so far looks like this:

{-# LANGUAGE GADTs, FlexibleInstances, UndecidableInstances #-}
module Counter where

import Data.Char

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a,b)

class Counter a where
  inc :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc as = map inc as
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a,b) where
  inc (a,b) = (inc a, inc b)
  counterR = CounterRpair counterR counterR

Note that the definition of counterR for each instance follows a very simple pattern: it is the appropriate constructor applied to one call to counterR for each constraint in the instance head.

It’s important to be clear that values of CounterR a don’t represent values of type a; they represent the type a itself.

For instance, the type (([Int], Int), (Int, Char)) is represented by a value of type

CounterR (([Int], Int), (Int, Char))

namely

CounterRpair
  (CounterRpair (CounterRlist CounterRint) CounterRint)
  (CounterRpair CounterRint CounterRchar)
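One way to check this correspondence is to ask the class for the representation at that type and render it back as text. The sketch below repeats the module so that it stands alone; describe and example are hypothetical names introduced only for this illustration.

```haskell
{-# LANGUAGE GADTs #-}
import Data.Char (chr, ord)

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a, b)

class Counter a where
  inc      :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a    = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a    = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc as   = map inc as
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a, b) where
  inc (a, b) = (inc a, inc b)
  counterR   = CounterRpair counterR counterR

-- Hypothetical helper: print the type that a representation stands for.
describe :: CounterR a -> String
describe CounterRint          = "Int"
describe CounterRchar         = "Char"
describe (CounterRlist r)     = "[" ++ describe r ++ "]"
describe (CounterRpair ra rb) = "(" ++ describe ra ++ "," ++ describe rb ++ ")"

-- The type annotation alone selects which representation counterR builds.
example :: String
example = describe (counterR :: CounterR (([Int], Int), (Int, Char)))
-- example == "(([Int],Int),(Int,Char))"
```

Note that describe is itself a small function added from outside the class, written purely by pattern matching on the representation.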

We can now write new methods on any type that is a member of class Counter. I’ll show you with an example. Say we now want to write a function that decrements instead of incrementing. We can write this as follows, but please note that the function takes both the representation of the type and the value as arguments.

dec :: CounterR a -> a -> a
dec CounterRint a                    = a - 1
dec CounterRchar a                   = chr (ord a - 1)
dec (CounterRlist incRa) as          = map (dec incRa) as
dec (CounterRpair incRa incRb) (a,b) = (dec incRa a, dec incRb b)

There are only a few things to note:

• If dec were a method of class Counter then you would need to define it in each instance. Here you provide each implementation as a case of the dec function.
• Since dec requires a value of type CounterR, each recursive call to dec requires that you pass an appropriate value of type CounterR. Fortunately, the appropriate value is precisely one of those pattern matched on the left-hand side.
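Callers need not build the representation by hand, because counterR can supply it. The sketch below repeats the module so that it stands alone and adds a wrapper, here called decC (a hypothetical name), that behaves like an ordinary class method.

```haskell
{-# LANGUAGE GADTs #-}
import Data.Char (chr, ord)

data CounterR a where
  CounterRint  :: CounterR Int
  CounterRchar :: CounterR Char
  CounterRlist :: CounterR a -> CounterR [a]
  CounterRpair :: CounterR a -> CounterR b -> CounterR (a, b)

class Counter a where
  inc      :: a -> a
  counterR :: CounterR a

instance Counter Int where
  inc a    = a + 1
  counterR = CounterRint

instance Counter Char where
  inc a    = chr (ord a + 1)
  counterR = CounterRchar

instance Counter a => Counter [a] where
  inc as   = map inc as
  counterR = CounterRlist counterR

instance (Counter a, Counter b) => Counter (a, b) where
  inc (a, b) = (inc a, inc b)
  counterR   = CounterRpair counterR counterR

-- The "new method": one equation per constructor, i.e. per instance.
dec :: CounterR a -> a -> a
dec CounterRint          a      = a - 1
dec CounterRchar         a      = chr (ord a - 1)
dec (CounterRlist r)     as     = map (dec r) as
dec (CounterRpair ra rb) (a, b) = (dec ra a, dec rb b)

-- Hypothetical wrapper: the Counter constraint supplies the representation,
-- so callers use decC exactly as they would use inc.
decC :: Counter a => a -> a
decC = dec counterR
```

With this in place, decC (5 :: Int) evaluates to 4 and decC "bcd" evaluates to "abc", and neither the class nor its instances had to change to add the new operation.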

As I noted at the beginning, this technique is not new, just relatively unknown. A paper by Oliveira and Sulzmann suggests unifying the two concepts (type classes and GADTs). Also, GADTs have been used as an implementation technique for type classes in JHC, instead of the usual dictionary-passing implementation.