This is a game of type tetris. Let's consider the types in your first two (correct) branches.
a :: a
f :: a -> Expr b
----------------
Expr b
We read this as "given all of the things above the line, we must produce something of the type listed below the line". In this case, the answer is obvious: we just apply f to a.
a :: a
f :: a -> Expr b
----------------
f a :: Expr b
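A minimal sketch of the setup, assuming the question's type is data Expr a = Var a | Val Int | Add (Expr a) (Expr a). The constructor names come from the final instance, but the exact declaration is an assumption. The Var branch on its own is written here as a hypothetical helper, bindVar, and it is just f applied to a:

```haskell
-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- Hypothetical helper for the Var branch alone: the only way to turn
-- an `a` and an `a -> Expr b` into an `Expr b` is to apply f to a.
bindVar :: a -> (a -> Expr b) -> Expr b
bindVar a f = f a

-- Hypothetical continuation used for illustration: wrap a variable
-- name in an addition with itself.
double :: String -> Expr String
double v = Add (Var v) (Var v)
```

For example, bindVar "x" double evaluates to Add (Var "x") (Var "x").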
For Val n, we can repeat the process:
n :: Int
f :: a -> Expr b
----------------
Expr b
At first this seems impossible, but we have elided a few things "above the line" that are actually available. Importantly, there's this:
n :: Int
f :: a -> Expr b
Val :: forall x . Int -> Expr x
--------------------
Expr b
and since x is not specific (it's bound by forall), it can unify with b.
n :: Int
f :: a -> Expr b
Val :: forall x . Int -> Expr x
--------------------
Val n :: Expr b
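Here is the Val branch in isolation as a sketch, again assuming data Expr a = Var a | Val Int | Add (Expr a) (Expr a). bindVal, lit1 and lit2 are hypothetical names. The Val n on the right-hand side is a freshly built value. Returning the matched value itself (for example, through an as-pattern v@(Val _) with result v) would be rejected, because that v has type Expr a, not Expr b:

```haskell
-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- Hypothetical helper for the Val branch alone. The continuation is
-- ignored because there is no `a` to feed it. `Val n` is rebuilt so
-- that Val's forall-bound type variable unifies with the result type b.
bindVal :: Int -> (a -> Expr b) -> Expr b
bindVal n _ = Val n

-- Because Val is polymorphic, the same literal fits any element type.
lit1 :: Expr String
lit1 = Val 3

lit2 :: Expr Bool
lit2 = Val 3
```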
For the final case, we have
x :: Expr a
y :: Expr a
f :: a -> Expr b
Add :: forall x . Expr x -> Expr x -> Expr x
--------------------------------------------
Expr b
Again, this seems impossible. We could reassemble the Add using x and y, but that just gives us a value of type Expr a instead of Expr b. We can't apply f either, because it needs an a, not an Expr a.
So the trick is that with recursive data types you're almost certainly going to use recursive definitions of functions. So let's bring in one other thing we have from the environment: (>>=) itself.
x :: Expr a
y :: Expr a
f :: a -> Expr b
Add :: forall x . Expr x -> Expr x -> Expr x
(>>=) :: forall y z . Expr y -> (y -> Expr z) -> Expr z
-------------------------------------------------------
Expr b
Again, due to the forall, we can use (>>=) on Expr types whatever their variable is. Almost immediately we have only one way to go forward: the only value we have that could be the second argument to (>>=) is f.
x :: Expr a
y :: Expr a
Add :: forall x . Expr x -> Expr x -> Expr x
(>>= f) :: Expr a -> Expr b
-------------------------------------------------------
Expr b
And now we could put either x or y on the left side of (>>= f) to get a value of type Expr b. Unfortunately, both of these are wrong. One way to be sure of this is that highly general functions like (>>=) almost never throw away information: each argument should be used non-trivially.
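To see why, here is one of those candidates as a hypothetical standalone function, badBind. It assumes data Expr a = Var a | Val Int | Add (Expr a) (Expr a). The function typechecks, but it silently drops y. A simple leaf count (leaves, also hypothetical) exposes the lost information:

```haskell
-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- A well-typed but wrong candidate: the Add branch keeps only x.
badBind :: Expr a -> (a -> Expr b) -> Expr b
badBind (Var a)   f = f a
badBind (Val n)   _ = Val n
badBind (Add x _) f = badBind x f   -- y is thrown away

-- Number of leaves (Var or Val nodes) in an expression.
leaves :: Expr a -> Int
leaves (Var _)   = 1
leaves (Val _)   = 1
leaves (Add x y) = leaves x + leaves y

example :: Expr String
example = Add (Var "x") (Val 1)
```

Suppose Var is taken as the monad's return. That is an assumption, though Var is the only constructor that wraps an a. Then the right-identity law e >>= return = e should hold. Here, however, badBind example Var gives Var "x" instead of example, and the leaf count drops from 2 to 1.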
Fortunately, if we have two Expr b values, we can use Add to combine them:
x :: Expr a
y :: Expr a
Add :: Expr b -> Expr b -> Expr b
(>>= f) :: Expr a -> Expr b
-------------------------------------------------------
Expr b
and now we have a way to use both x and y non-trivially:
x :: Expr a
y :: Expr a
Add :: Expr b -> Expr b -> Expr b
(>>= f) :: Expr a -> Expr b
-------------------------------------------------------
Add (x >>= f) (y >>= f) :: Expr b
noting that we keep x and y in the same order.
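Here the three branches are collected into a hypothetical standalone function, bindExpr, so they can be tested without writing any instances. It assumes data Expr a = Var a | Val Int | Add (Expr a) (Expr a). Unlike the rejected candidate, binding with Var returns the expression unchanged, and the leaves keep their left-to-right order:

```haskell
-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- Hypothetical standalone version of (>>=), built from the three
-- branches derived above.
bindExpr :: Expr a -> (a -> Expr b) -> Expr b
bindExpr (Var a)   f = f a
bindExpr (Val n)   _ = Val n
bindExpr (Add x y) f = Add (bindExpr x f) (bindExpr y f)  -- x stays before y

example :: Expr String
example = Add (Var "x") (Add (Val 1) (Var "y"))
```

For instance, bindExpr example (\v -> if v == "x" then Val 10 else Var v) replaces only the x leaf and gives Add (Val 10) (Add (Val 1) (Var "y")).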
So this whole game of type tetris is a bit long-winded, but it demonstrates that with a few principles:
- Definitions over recursive types tend to need recursion
- Definitions of general functions will use every argument non-trivially if possible
we can get to the answer to how to define (>>=) almost entirely mechanically. A good definition ought to feel satisfying, like there literally is no other choice.
instance Monad Expr where
  Var a   >>= f = f a
  Val n   >>= _ = Val n
  Add x y >>= f = Add (x >>= f) (y >>= f)
  ...
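On GHC 7.10 and later (base 4.8), Applicative is a superclass of Monad, so Expr also needs Functor and Applicative instances. Below is a minimal sketch that assumes data Expr a = Var a | Val Int | Add (Expr a) (Expr a). It also assumes pure = Var, which is a natural choice because Var is the only constructor that wraps an a, but the answer does not state it. fmap is written by hand with the same two principles (recurse on the recursive fields, use every argument). (<*>) reuses (>>=) through ap from Control.Monad, and return then defaults to pure:

```haskell
import Control.Monad (ap)

-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- Same principles as (>>=): recurse into Add, rebuild Val, use g on leaves.
instance Functor Expr where
  fmap g (Var a)   = Var (g a)
  fmap _ (Val n)   = Val n
  fmap g (Add x y) = Add (fmap g x) (fmap g y)

instance Applicative Expr where
  pure  = Var   -- assumption: Var is the unit of this monad
  (<*>) = ap    -- defined in terms of (>>=)

instance Monad Expr where
  Var a   >>= f = f a
  Val n   >>= _ = Val n
  Add x y >>= f = Add (x >>= f) (y >>= f)
```

With these instances, do-notation works. Each v <- e step substitutes the rest of the block into every Var leaf of e.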
We can interpret this definition as "reacting" f with the leaf nodes, the Var leaves, where the a values live. If no a exists (a Val leaf), then we're free to do nothing. If recursive Exprs exist, then we just "push the call to (>>= f) down" and rebuild the recursive type.
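One concrete reading of this interpretation is variable substitution. The sketch below assumes data Expr a = Var a | Val Int | Add (Expr a) (Expr a) and uses a standalone bindExpr with the same three branches, so it needs no instances. subst, eval and the association-list environment are hypothetical illustrations, not part of the answer:

```haskell
import Data.Maybe (fromMaybe)

-- Assumed shape of the expression type from the question.
data Expr a = Var a | Val Int | Add (Expr a) (Expr a)
  deriving (Show, Eq)

-- Same three branches as the (>>=) above, as a standalone function.
bindExpr :: Expr a -> (a -> Expr b) -> Expr b
bindExpr (Var a)   f = f a                                -- react f with the leaf
bindExpr (Val n)   _ = Val n                              -- no a: nothing to do
bindExpr (Add x y) f = Add (bindExpr x f) (bindExpr y f)  -- push the call down

-- Hypothetical substitution: replace every variable bound in the
-- environment; unbound names stay as Var leaves.
subst :: [(String, Expr String)] -> Expr String -> Expr String
subst env e = bindExpr e (\v -> fromMaybe (Var v) (lookup v env))

-- Hypothetical evaluator: succeeds only when no Var leaves remain.
eval :: Expr a -> Maybe Int
eval (Var _)   = Nothing
eval (Val n)   = Just n
eval (Add x y) = (+) <$> eval x <*> eval y
```

Substituting x := Val 2 and y := Add (Val 3) (Val 4) into Add (Var "x") (Add (Val 1) (Var "y")) leaves only Val leaves, so eval returns Just 10. Substituting only x leaves the y leaf in place, and eval returns Nothing.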