Tuesday, January 31, 2012

Encoding algebraic data types in C#

Algebraic data types (ADTs) are generally not directly expressible in C#, but they're far too useful a tool to leave unused. ADTs are a very precise modeling tool that helps make illegal states unrepresentable. The Base Class Library (BCL) already includes a Tuple type representing anonymous product types, and C# also has anonymous types to represent labeled product types (isn't that confusing?). And of course you could even consider simple classes or structs to be product types.
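As a quick illustration of the two built-in product types mentioned above, here's a small sketch. The ProductTypes class and its method names are made up for this example:

```csharp
using System;

// Hypothetical helper class, only here to show the two kinds of product types.
public static class ProductTypes {
    // Tuple<int, string> is an anonymous product type: its components are
    // only reachable by position, as Item1 and Item2.
    public static Tuple<int, string> Positional() {
        return Tuple.Create(42, "answer");
    }

    // An anonymous type is a labeled product type: the compiler generates a
    // class with read-only Name and Age properties.
    public static string Labeled() {
        var person = new { Name = "Alice", Age = 30 };
        return person.Name + " is " + person.Age;
    }
}
```

Since an anonymous type's name can't be written in C# source, the labeled version has to stay local to the method that creates it.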

But the BCL doesn't include any anonymous sum type. We can use F#'s Choice type in C#. For example, this discriminated union type in F# (borrowed from MSDN):

type Shape =
  // The value here is the radius.
| Circle of float
  // The values here are the height and width.
| Rectangle of double * double

Could be represented in C# as FSharpChoice<double, Tuple<double, double>> (F#'s float is the 64-bit System.Double, i.e. C#'s double, not C#'s float). But this obviously loses the labels (Circle, Rectangle). These "labels" are usually called "constructors".

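A reasonable starting point for encoding ADTs in C# is to see what the F# compiler generates for the discriminated union above, reverse-engineering it with ILSpy (this might hurt a bit, but don't be scared!). Keep in mind that this is decompiler output rather than compilable C#: names such as Circle@DebugTypeProxy are generated by the compiler and aren't valid C# identifiers.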

using Microsoft.FSharp.Core;
using System;
using System.Collections;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
[DebuggerDisplay("{__DebugDisplay(),nq}"), CompilationMapping(SourceConstructFlags.SumType)]
[Serializable]
[StructLayout(LayoutKind.Auto, CharSet = CharSet.Auto)]
public abstract class Shape : IEquatable<Program.Shape>, IStructuralEquatable, IComparable<Program.Shape>, IComparable, IStructuralComparable
{
    public static class Tags
    {
        public const int Circle = 0;
        public const int Rectangle = 1;
    }
    [DebuggerTypeProxy(typeof(Program.Shape.Circle@DebugTypeProxy)), DebuggerDisplay("{__DebugDisplay(),nq}")]
    [Serializable]
    public class Circle : Program.Shape
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item;
        [CompilationMapping(SourceConstructFlags.Field, 0, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        internal Circle(double item)
        {
            this.item = item;
        }
    }
    [DebuggerTypeProxy(typeof(Program.Shape.Rectangle@DebugTypeProxy)), DebuggerDisplay("{__DebugDisplay(),nq}")]
    [Serializable]
    public class Rectangle : Program.Shape
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item1;
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item2;
        [CompilationMapping(SourceConstructFlags.Field, 1, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item1
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item1;
            }
        }
        [CompilationMapping(SourceConstructFlags.Field, 1, 1), CompilerGenerated, DebuggerNonUserCode]
        public double Item2
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item2;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        internal Rectangle(double item1, double item2)
        {
            this.item1 = item1;
            this.item2 = item2;
        }
    }
    internal class Circle@DebugTypeProxy
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal Program.Shape.Circle _obj;
        [CompilationMapping(SourceConstructFlags.Field, 0, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        public Circle@DebugTypeProxy(Program.Shape.Circle obj)
        {
            this._obj = obj;
        }
    }
    internal class Rectangle@DebugTypeProxy
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal Program.Shape.Rectangle _obj;
        [CompilationMapping(SourceConstructFlags.Field, 1, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item1
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item1;
            }
        }
        [CompilationMapping(SourceConstructFlags.Field, 1, 1), CompilerGenerated, DebuggerNonUserCode]
        public double Item2
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item2;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        public Rectangle@DebugTypeProxy(Program.Shape.Rectangle obj)
        {
            this._obj = obj;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public int Tag
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public bool IsRectangle
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return this is Program.Shape.Rectangle;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public bool IsCircle
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return this is Program.Shape.Circle;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode]
    internal Shape()
    {
    }
    [CompilationMapping(SourceConstructFlags.UnionCase, 1)]
    public static Program.Shape NewRectangle(double item1, double item2)
    {
        return new Program.Shape.Rectangle(item1, item2);
    }
    [CompilationMapping(SourceConstructFlags.UnionCase, 0)]
    public static Program.Shape NewCircle(double item)
    {
        return new Program.Shape.Circle(item);
    }
    [CompilerGenerated, DebuggerNonUserCode]
    internal object __DebugDisplay()
    {
        return ExtraTopLevelOperators.PrintFormatToString<FSharpFunc<Program.Shape, string>>(new PrintfFormat<FSharpFunc<Program.Shape, string>, Unit, string, string, string>("%+0.8A")).Invoke(this);
    }
    [CompilerGenerated]
    public sealed override int CompareTo(Program.Shape obj)
    {
        if (this != null)
        {
            if (obj == null)
            {
                return 1;
            }
            int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
            int num2 = (!(obj is Program.Shape.Rectangle)) ? 0 : 1;
            if (num != num2)
            {
                return num - num2;
            }
            if (this is Program.Shape.Circle)
            {
                Program.Shape.Circle circle = (Program.Shape.Circle)this;
                Program.Shape.Circle circle2 = (Program.Shape.Circle)obj;
                IComparer genericComparer = LanguagePrimitives.GenericComparer;
                double item = circle.item;
                double item2 = circle2.item;
                if (item < item2)
                {
                    return -1;
                }
                if (item > item2)
                {
                    return 1;
                }
                if (item == item2)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer, item, item2);
            }
            else
            {
                Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
                Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)obj;
                IComparer genericComparer2 = LanguagePrimitives.GenericComparer;
                double item3 = rectangle.item1;
                double item4 = rectangle2.item1;
                int num3 = (item3 >= item4) ? ((item3 <= item4) ? ((item3 != item4) ? LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer2, item3, item4) : 0) : 1) : -1;
                if (num3 < 0)
                {
                    return num3;
                }
                if (num3 > 0)
                {
                    return num3;
                }
                IComparer genericComparer3 = LanguagePrimitives.GenericComparer;
                double item5 = rectangle.item2;
                double item6 = rectangle2.item2;
                if (item5 < item6)
                {
                    return -1;
                }
                if (item5 > item6)
                {
                    return 1;
                }
                if (item5 == item6)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer3, item5, item6);
            }
        }
        else
        {
            if (obj != null)
            {
                return -1;
            }
            return 0;
        }
    }
    [CompilerGenerated]
    public sealed override int CompareTo(object obj)
    {
        return this.CompareTo((Program.Shape)obj);
    }
    [CompilerGenerated]
    public sealed override int CompareTo(object obj, IComparer comp)
    {
        Program.Shape shape = (Program.Shape)obj;
        if (this != null)
        {
            if ((Program.Shape)obj == null)
            {
                return 1;
            }
            int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
            Program.Shape shape2 = shape;
            int num2 = (!(shape2 is Program.Shape.Rectangle)) ? 0 : 1;
            if (num != num2)
            {
                return num - num2;
            }
            if (this is Program.Shape.Circle)
            {
                Program.Shape.Circle circle = (Program.Shape.Circle)this;
                Program.Shape.Circle circle2 = (Program.Shape.Circle)shape;
                double item = circle.item;
                double item2 = circle2.item;
                if (item < item2)
                {
                    return -1;
                }
                if (item > item2)
                {
                    return 1;
                }
                if (item == item2)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item, item2);
            }
            else
            {
                Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
                Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)shape;
                double item3 = rectangle.item1;
                double item4 = rectangle2.item1;
                int num3 = (item3 >= item4) ? ((item3 <= item4) ? ((item3 != item4) ? LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item3, item4) : 0) : 1) : -1;
                if (num3 < 0)
                {
                    return num3;
                }
                if (num3 > 0)
                {
                    return num3;
                }
                double item5 = rectangle.item2;
                double item6 = rectangle2.item2;
                if (item5 < item6)
                {
                    return -1;
                }
                if (item5 > item6)
                {
                    return 1;
                }
                if (item5 == item6)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item5, item6);
            }
        }
        else
        {
            if ((Program.Shape)obj != null)
            {
                return -1;
            }
            return 0;
        }
    }
    [CompilerGenerated]
    public sealed override int GetHashCode(IEqualityComparer comp)
    {
        if (this == null)
        {
            return 0;
        }
        int num;
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            num = 0;
            return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, circle.item) + ((num << 6) + (num >> 2)));
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        num = 1;
        num = -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, rectangle.item2) + ((num << 6) + (num >> 2)));
        return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, rectangle.item1) + ((num << 6) + (num >> 2)));
    }
    [CompilerGenerated]
    public sealed override int GetHashCode()
    {
        return this.GetHashCode(LanguagePrimitives.GenericEqualityComparer);
    }
    [CompilerGenerated]
    public sealed override bool Equals(object obj, IEqualityComparer comp)
    {
        if (this == null)
        {
            return obj == null;
        }
        Program.Shape shape = obj as Program.Shape;
        if (shape == null)
        {
            return false;
        }
        Program.Shape shape2 = shape;
        int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        Program.Shape shape3 = shape2;
        int num2 = (!(shape3 is Program.Shape.Rectangle)) ? 0 : 1;
        if (num != num2)
        {
            return false;
        }
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            Program.Shape.Circle circle2 = (Program.Shape.Circle)shape2;
            return circle.item == circle2.item;
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)shape2;
        return rectangle.item1 == rectangle2.item1 && rectangle.item2 == rectangle2.item2;
    }
    [CompilerGenerated]
    public sealed override bool Equals(Program.Shape obj)
    {
        if (this == null)
        {
            return obj == null;
        }
        if (obj == null)
        {
            return false;
        }
        int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        int num2 = (!(obj is Program.Shape.Rectangle)) ? 0 : 1;
        if (num != num2)
        {
            return false;
        }
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            Program.Shape.Circle circle2 = (Program.Shape.Circle)obj;
            double item = circle.item;
            double item2 = circle2.item;
            return (item != item && item2 != item2) || item == item2;
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)obj;
        double item3 = rectangle.item1;
        double item4 = rectangle2.item1;
        if ((item3 != item3 && item4 != item4) || item3 == item4)
        {
            double item5 = rectangle.item2;
            double item6 = rectangle2.item2;
            return (item5 != item5 && item6 != item6) || item5 == item6;
        }
        return false;
    }
    [CompilerGenerated]
    public sealed override bool Equals(object obj)
    {
        Program.Shape shape = obj as Program.Shape;
        return shape != null && this.Equals(shape);
    }
}

Whew! That's a lot of code! Let's break down some of this code:

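  • The DebuggerDisplay, DebuggerTypeProxy, DebuggerBrowsable and DebuggerNonUserCode attributes, along with the Circle@DebugTypeProxy and Rectangle@DebugTypeProxy classes, improve the debugging experience.
  • Discriminated union types are marked as Serializable.
  • Discriminated union types implement structural equality and comparison (IEquatable, IComparable, IStructuralEquatable, IStructuralComparable, Equals(), GetHashCode(), CompareTo()).
  • Tags (simple integer constants, also exposed through the Tag property) are used to optimize the implementation of equality and comparison: values of different cases are compared by tag before any fields are examined.
  • The case constructors are internal; values are created through the static NewCircle and NewRectangle methods, and the IsCircle and IsRectangle properties tell which case a value is.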
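To give a rough idea of what the equality part of that generated code amounts to, here's a hand-written sketch for a single case. EquatableCircle is a hypothetical standalone class (not F# output), and it only covers IEquatable<T>, Equals and GetHashCode, leaving out comparison and the IStructural* interfaces:

```csharp
using System;

// Hypothetical standalone class mirroring the generated Circle case.
public sealed class EquatableCircle : IEquatable<EquatableCircle> {
    public readonly double Radius;

    public EquatableCircle(double radius) {
        Radius = radius;
    }

    // double.Equals treats two NaNs as equal, like the generated
    // Equals(Shape) above: (item != item && item2 != item2) || item == item2.
    public bool Equals(EquatableCircle other) {
        return other != null && Radius.Equals(other.Radius);
    }

    // Delegates to the strongly typed overload; null or other types yield false.
    public override bool Equals(object obj) {
        return Equals(obj as EquatableCircle);
    }

    // Equal values must produce equal hash codes.
    public override int GetHashCode() {
        return Radius.GetHashCode();
    }
}
```

Multiply this by every case, then add CompareTo and the structural interfaces, and the size of the decompiled output above starts to make sense.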

It seems impractical to write this much code in C# every time we want an ADT. However, underneath all the attributes and noise, the gist of it is quite simple: a class hierarchy starting with an abstract class, plus a concrete subclass for each case:

abstract class Shape {
    class Circle: Shape {
        public readonly float Radius;

        public Circle(float radius) {
            Radius = radius;
        }
    }

    class Rectangle: Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }
}

Shape is abstract because the only valid cases are Circle and Rectangle. Instantiating a Shape doesn't make sense!

Now, there's a detail we have to take care of: ADTs are closed, which means that we can't add new shapes to the type without changing the Shape type itself. This runs counter to general OOP practice: if we wanted a Square, we could just create a new subclass of Shape. That, however, complicates things, since any code could add a case that existing functions don't handle, which makes writing total functions over Shapes harder (impossible?). For more information about this, and a comparison of OO (classes/subtyping) vs. FP (closed ADTs), see this Stack Overflow question. OCaml supports "open ADTs" (actually called "open variants" or "polymorphic variants"), which are powerful but have their drawbacks too.
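To make the problem concrete, here's a sketch of an open hierarchy. All the Open* names are hypothetical: the Area function compiles without complaint, yet it stops being total as soon as someone adds OpenSquare.

```csharp
using System;

// Hypothetical "open" hierarchy: an abstract class without an explicit
// constructor gets a protected one, so any code anywhere can add a case.
public abstract class OpenShape {
}

public class OpenCircle : OpenShape {
    public readonly double Radius;
    public OpenCircle(double radius) { Radius = radius; }
}

public class OpenRectangle : OpenShape {
    public readonly double Height;
    public readonly double Width;
    public OpenRectangle(double height, double width) { Height = height; Width = width; }
}

// Added later, possibly in another assembly. Nothing forces Area to handle it.
public class OpenSquare : OpenShape {
    public readonly double Side;
    public OpenSquare(double side) { Side = side; }
}

public static class OpenShapes {
    // Written before OpenSquare existed: it still compiles, but it is no
    // longer total, so unknown cases can only be rejected at runtime.
    public static double Area(OpenShape shape) {
        var c = shape as OpenCircle;
        if (c != null) return Math.PI * c.Radius * c.Radius;
        var r = shape as OpenRectangle;
        if (r != null) return r.Height * r.Width;
        throw new ArgumentException("Unhandled shape: " + shape.GetType().Name);
    }
}
```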

But I digress. The goal is to prevent further subclasses of Shape, as well as subclasses of Circle and Rectangle. We can do this by making the constructor of Shape private and sealing Circle and Rectangle:

public abstract class Shape {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly float Radius;

        public Circle(float radius) {
            Radius = radius;
        }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }
}

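This is also why I chose to place Circle and Rectangle as nested classes of Shape: nested classes can access Shape's private constructor, while classes declared outside Shape cannot, so no other code can derive from it.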

So far we have a general structure and constructors. But we're not done yet! Given a Shape, how do we know whether it's a Circle or a Rectangle? How do we get to its data (radius, height, width)?

Any red-blooded object-oriented programmer would at this point be yelling "implement a Visitor!". Languages with first-class support for ADTs, like the ML dialects, use pattern matching instead. Continuing with the same MSDN sample, we calculate the area of a Shape:

let shape = Circle 2.0

let area =
    match shape with
    | Circle radius -> System.Math.PI * radius * radius
    | Rectangle (h, w) -> h * w

System.Console.WriteLine area

Again we open this with ILSpy and see the following C# code (I edited it a bit to remove name mangling):

Shape shape = new Shape.Circle(2.0f);
double arg_67_0;
if (shape is Shape.Rectangle)
{
    var rectangle = (Shape.Rectangle)shape;
    double w = rectangle.Width;
    double h = rectangle.Height;
    arg_67_0 = h * w;
}
else
{
    var circle = (Shape.Circle)shape;
    double radius = circle.Radius;
    arg_67_0 = 3.1415926535897931 * radius * radius;
}
Console.WriteLine(arg_67_0);

It does runtime type testing and downcasting! Type tests like these can be faster than a virtual (vtable) dispatch, and the F# compiler takes advantage of that. If we had more cases in the discriminated union, we'd see the F# compiler simply nest ifs to test each case in turn. This is fast, but it's not very practical to write such code by hand in C#. Also, the C# compiler can't statically check whether such a chain of type tests is exhaustive. What we really want is a method on Shape (we'll call it Match) that takes two functions as parameters: one to handle the Circle case, another to handle the Rectangle case.

At this point we have several alternatives. We can implement a full Visitor pattern, as mentioned earlier, or we can take advantage of the fact that Shape is a closed type and simply encapsulate the type testing and downcasting, like this:

public T Match<T>(Func<float, T> circle, Func<double, double, T> rectangle) {
    // Shape is closed: anything that isn't a Circle must be a Rectangle.
    var c = this as Circle;
    if (c != null) {
        return circle(c.Radius);
    }
    var r = (Rectangle)this;
    return rectangle(r.Width, r.Height);
}
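Since Match has to be a member of Shape, here's a self-contained sketch that puts it together with the sealed hierarchy from above. It uses as instead of is plus a cast, which avoids a second type test; that's a stylistic choice, not something the F# output does:

```csharp
using System;

public abstract class Shape {
    // Only Shape and its nested classes can call this constructor.
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly float Radius;

        public Circle(float radius) {
            Radius = radius;
        }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }

    // Both handlers are required, so every caller must cover both cases.
    public T Match<T>(Func<float, T> circle, Func<double, double, T> rectangle) {
        var c = this as Circle;
        if (c != null) {
            return circle(c.Radius);
        }
        // The hierarchy is closed, so this cast can't fail.
        var r = (Rectangle)this;
        return rectangle(r.Width, r.Height);
    }
}
```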

And now we can calculate the area of a Shape like this:

Shape shape = new Shape.Circle(2.0f);
var area = shape.Match<double>(circle: radius => Math.PI * radius * radius,
                               rectangle: (width, height) => width * height);
Console.WriteLine(area);

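Note how named arguments in C# 4 make this a bit more readable. Also, because both handlers are required parameters, using the Match method ensures that we always cover all cases. The explicit <double> type argument isn't strictly needed, though: since Match's signature fixes the lambdas' parameter types, the C# compiler can infer T from the lambdas' return types.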
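For comparison, here's roughly what the Visitor-based alternative mentioned earlier could look like for Shape. IShapeVisitor, Accept and AreaVisitor are hypothetical names; this is the classic double-dispatch pattern, not anything the F# compiler generates:

```csharp
using System;

// One method per case: adding a case forces every visitor to be updated.
public interface IShapeVisitor<T> {
    T VisitCircle(float radius);
    T VisitRectangle(double height, double width);
}

public abstract class Shape {
    private Shape() {}

    // Each case dispatches to its own Visit method through a virtual call.
    public abstract T Accept<T>(IShapeVisitor<T> visitor);

    public sealed class Circle : Shape {
        public readonly float Radius;
        public Circle(float radius) { Radius = radius; }
        public override T Accept<T>(IShapeVisitor<T> visitor) {
            return visitor.VisitCircle(Radius);
        }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
        public override T Accept<T>(IShapeVisitor<T> visitor) {
            return visitor.VisitRectangle(Height, Width);
        }
    }
}

// Computing the area now takes a whole class instead of two lambdas.
public sealed class AreaVisitor : IShapeVisitor<double> {
    public double VisitCircle(float radius) { return Math.PI * radius * radius; }
    public double VisitRectangle(double height, double width) { return height * width; }
}
```

The visitor gives the same exhaustiveness guarantee as Match, at the cost of a named class for every operation.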

Another alternative is to pass the whole object to the handling functions instead of just its data, e.g.

public T Match<T>(Func<Circle, T> circle, Func<Rectangle, T> rectangle)
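A possible body for that signature, again as a self-contained sketch on top of the same sealed hierarchy:

```csharp
using System;

public abstract class Shape {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly float Radius;
        public Circle(float radius) { Radius = radius; }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
    }

    // Handlers receive the typed case object, so they read its fields by
    // name instead of relying on the order of positional parameters.
    public T Match<T>(Func<Circle, T> circle, Func<Rectangle, T> rectangle) {
        var c = this as Circle;
        if (c != null) {
            return circle(c);
        }
        // The hierarchy is closed, so this cast can't fail.
        return rectangle((Rectangle)this);
    }
}
```

With this variant, a handler can't accidentally swap Height and Width, since it reads r.Height and r.Width directly.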

Now we have a usable algebraic data type. When compared to F#, the boilerplate required in C# is considerable and tedious, but still worth it in my opinion. However, if you also need equality and comparison you might as well just use F# instead ;-)