Playing with monads in Python, 1

There are a number of topics in computer science that I have on my list of "stuff to understand at some point". One of them is monads, which first perplexed me when I came across Peter Thatcher's article about a year ago.

My favorite way to structure the code is to start with a base class Monad, which requires every subclass (that is, every type of monad we define) to implement the methods bind and unit.
from abc import ABC, abstractmethod
from types import FunctionType as Function


# Inheriting from ABC is what makes @abstractmethod actually enforced
class Monad(ABC):

    @abstractmethod
    def bind(self, f: Function) -> 'Monad':
        raise NotImplementedError

    def __rshift__(self, f: Function) -> 'Monad':
        return self.bind(f)

    @classmethod
    @abstractmethod
    def unit(cls, *args) -> 'Monad':
        raise NotImplementedError
As Thatcher did in his own implementation, I've overloaded the special method __rshift__ to call bind, which gets us closer to Haskell's bind operator (>>=).
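For comparison, Haskell gives these two operations the types `return :: a -> m a` and `(>>=) :: m a -> (a -> m b) -> m b`. The following is a rough sketch of how those types might be spelled with Python's typing module; `Box` is a hypothetical name used only for illustration and is not part of the code in this post. Python's own `>>=` is an augmented assignment statement and can't be chained inside an expression, which is presumably why `>>` is the closest available spelling.

```python
from typing import Callable, Generic, TypeVar

A = TypeVar('A')
B = TypeVar('B')


class Box(Generic[A]):
    """Hypothetical container, used only to illustrate the types."""

    def __init__(self, value: A) -> None:
        self.value = value

    @classmethod
    def unit(cls, value: A) -> 'Box[A]':
        # unit :: a -> m a
        return cls(value)

    def bind(self, f: Callable[[A], 'Box[B]']) -> 'Box[B]':
        # bind :: m a -> (a -> m b) -> m b
        return f(self.value)

    def __rshift__(self, f: Callable[[A], 'Box[B]']) -> 'Box[B]':
        # `box >> f` is just sugar for `box.bind(f)`
        return self.bind(f)


shown = Box.unit(3) >> (lambda n: Box(str(n) * 2))
print(shown.value)  # 33
```

Note that the action passed to bind may change the wrapped type (here from int to str), as long as it returns a value wrapped in the same kind of container.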

I liked Boyer's SimpleMonad [1], so we'll play with that. All it basically does is store a value. In the following I've also added the special method __eq__ to compare the values of two instances of SimpleMonad, and a getter.
class SimpleMonad(Monad):
    """ A really simple monad """
    def __init__(self, x: int) -> None:
        self.x = x

    def bind(self, f: Function) -> 'SimpleMonad':
        return f(self.x)

    @classmethod
    def unit(cls, value: int) -> 'SimpleMonad':
        return cls(value)

    def get_value(self) -> int:
        return self.x

    def __eq__(self, other: 'SimpleMonad') -> bool:
        """ Only for tests """
        return self.get_value() == other.get_value()
Now we can define two example actions as follows.
def xor(x: int) -> SimpleMonad:
    return SimpleMonad(x ^ 2)

def digits(x: int) -> SimpleMonad:
    return SimpleMonad(len(str(x)))
We can operate upon an integer in the monadic context using bind
value1 = SimpleMonad(5).bind(xor).bind(digits)
print(value1.get_value())  # 1
and we can also write this as
value2 = SimpleMonad(5) >> xor >> digits
print(value2.get_value())  # 1
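Because every bind returns another monadic value, a whole list of actions can be folded over a starting value. The following is a minimal sketch of that idea, not something from the sources I read: `Simple` is a compact stand-in for SimpleMonad (repeated here so the snippet runs on its own) and `pipeline` is a hypothetical helper name.

```python
from functools import reduce
from typing import Callable, Iterable


class Simple:
    """Compact stand-in for SimpleMonad, repeated so the sketch runs alone."""

    def __init__(self, x: int) -> None:
        self.x = x

    def bind(self, f: Callable[[int], 'Simple']) -> 'Simple':
        return f(self.x)


def xor(x: int) -> Simple:
    return Simple(x ^ 2)


def digits(x: int) -> Simple:
    return Simple(len(str(x)))


def pipeline(start: Simple, actions: Iterable[Callable[[int], Simple]]) -> Simple:
    # Fold `bind` over the actions, left to right:
    # pipeline(m, [f, g]) is the same as m.bind(f).bind(g)
    return reduce(lambda m, f: m.bind(f), actions, start)


result = pipeline(Simple(5), [xor, digits])
print(result.x)  # 1
```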
Most sources I've read mention that a type must satisfy certain "laws" (left identity, right identity and associativity) to be a monadic type, but they don't usually list or discuss them. For my own understanding, I wanted to check that SimpleMonad is indeed a monad, so I wrote the following tests.
from typing import Any


def check_laws(monad: Any, value: Any, fun1: Function, fun2: Function) -> None:
    name = monad.__name__

    # Left identity
    assert monad.unit(value) >> fun1 == fun1(value), \
        name + ' does not satisfy left identity axiom'

    # Right identity
    assert monad(value) >> monad.unit == monad(value), \
        name + ' does not satisfy right identity axiom'

    # Associativity
    assert ((monad(value) >> fun1) >> fun2) == (monad(value) >> (lambda v: fun1(v) >> fun2)), \
        name + ' does not satisfy associativity'

    print('** All tests passed for ' + name)


if __name__ == '__main__':
    check_laws(SimpleMonad, 6, lambda x: SimpleMonad(x + 6), lambda x: SimpleMonad(x - 1))

    # ** All tests passed for SimpleMonad
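I find the three laws easier to remember when they are phrased in terms of composing actions: unit acts as an identity for composition, and composition is associative. Below is a small sketch of that reading, under a few assumptions: `Simple` is again a compact stand-in for SimpleMonad, and `compose`, `add6`, `dec` and `dbl` are hypothetical names I made up for the example. It only samples a range of integers, so it is a sanity check rather than a proof.

```python
from typing import Callable


class Simple:
    """Compact stand-in for SimpleMonad, repeated so the sketch runs alone."""

    def __init__(self, x: int) -> None:
        self.x = x

    @classmethod
    def unit(cls, x: int) -> 'Simple':
        return cls(x)

    def bind(self, f: Callable[[int], 'Simple']) -> 'Simple':
        return f(self.x)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Simple) and self.x == other.x


Action = Callable[[int], Simple]


def compose(f: Action, g: Action) -> Action:
    # Hypothetical helper: run f, then feed its result to g via bind
    return lambda v: f(v).bind(g)


def add6(x: int) -> Simple:
    return Simple(x + 6)


def dec(x: int) -> Simple:
    return Simple(x - 1)


def dbl(x: int) -> Simple:
    return Simple(x * 2)


for v in range(-100, 101):
    # unit is a left and a right identity for compose
    assert compose(Simple.unit, add6)(v) == add6(v)
    assert compose(add6, Simple.unit)(v) == add6(v)
    # compose is associative
    assert compose(compose(add6, dec), dbl)(v) == compose(add6, compose(dec, dbl))(v)

print('laws hold for all sampled values')
```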
EDIT: I also want to talk a little about another monad I use every day: the list monad. I came across an excellent article on a Haskell wiki that describes this monad and the types of bind and unit, and wrote the following class (see also multiple inheritance).
from typing import Any

class ListMonad(list, Monad):

    ## Define `bind` and `unit` for `list` type so it can be passed to `check_laws`

    def bind(self, f: Function) -> 'ListMonad':
        # `concat (map f xs)` in Haskell: apply f to every element and
        # concatenate the resulting lists into a new ListMonad, keeping
        # all of the results and leaving the original list untouched
        return ListMonad([y for x in self for y in f(x)])

    @classmethod
    def unit(cls, item: Any) -> 'ListMonad':
        return cls([item])
Now actions such as the following can act on the list's elements.
def double(x: int) -> 'ListMonad':
    return ListMonad([2 * x])


def divider(x: int) -> 'ListMonad':
    return ListMonad([x // 2])
Again, this list doesn't have to contain objects of type int; it could hold any type, and these functions could perform any operation on that type. If I've done everything correctly, the following indicates that this is indeed a monad as well.
check_laws(ListMonad, list(range(1)), lambda x: ListMonad([2 * x]), lambda x: ListMonad([x + 1]))

# ** All tests passed for ListMonad
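The list monad gets more interesting once an action returns several results, or none at all, for a single input. The sketch below illustrates that with plain Python lists rather than the ListMonad class; `flat_map`, `neighbours` and `only_even` are hypothetical names chosen for the example.

```python
from typing import Callable, List, TypeVar

A = TypeVar('A')
B = TypeVar('B')


def flat_map(xs: List[A], f: Callable[[A], List[B]]) -> List[B]:
    # Apply f to every element and concatenate the resulting lists,
    # i.e. `concat (map f xs)` in Haskell terms
    return [y for x in xs for y in f(x)]


def neighbours(x: int) -> List[int]:
    # Two results per input
    return [x - 1, x + 1]


def only_even(x: int) -> List[int]:
    # One result for even inputs, none for odd ones
    return [x] if x % 2 == 0 else []


step1 = flat_map([1, 10], neighbours)
step2 = flat_map(step1, only_even)
print(step1)  # [0, 2, 9, 11]
print(step2)  # [0, 2]
```

An action that returns an empty list simply drops that branch, which is how the list monad can model computations with a varying number of outcomes.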
I've also come across an interesting thread on Stack Overflow discussing Python's context managers (that is, the with-statement) as monads. Some answers suggest the with-statement may be too general to be a monad. It's an interesting thought nonetheless, because Haskell expresses IO through the IO monad. Python is obviously not a pure functional language, but the boundaries between many of these concepts seem blurred.

Additional Resources

[1] Super Quick Intro to Monads, by Stephen Boyer
[2] Monads and Gonads, Google Tech Talk by Douglas Crockford