Tail call optimization in Nix (today)

The following is a reflection of sorts on Nix code beyond stdenv.mkDerivation, or general-purpose Nix, if you will. What I present is not necessarily practical, but it may be interesting.

With Nix being a purely functional programming language, recursion is the obvious way to solve most problems. However, since the Nix evaluator doesn’t perform optimizations such as tail call elimination, handling arbitrarily large input can be an issue:

  • Stack size is limited, so recursion can’t be arbitrarily deep. This is made worse by the fact that thunks take up a lot of space on the stack. This can easily be seen by comparing builtins.valueSize {} and let foo = {}; in builtins.deepSeq foo (builtins.valueSize foo). The available workaround, deepSeqing at every recursion step, sadly kills any semblance of performance.
  • Since Nix is pure and non-optimizing, any operation on a data structure copies it, even where an in-place update would be possible. Thus extending or updating lists and attribute sets is quite expensive.

As a result, the art of Nix programming is (ab)using builtins to the greatest extent possible. These are implemented in C++ and thus not constrained by the semantics of the Nix language: some things may happen in place, and builtins like foldl' use loops instead of recursion, which eliminates the stack overflow issue.
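
To make the contrast concrete, here is a small sketch comparing a hand-written recursive sum with one built on builtins.foldl'. The names sumRec and sumFold are made up for illustration. The recursive version is only evaluated on a short list, since it nests one call per element and, because builtins.tail copies the rest of the list, also does quadratic work. The snippet can be evaluated with nix-instantiate --eval --strict.

```nix
let
  # Plain recursion: one nested call per element, and builtins.tail
  # copies the remaining list on every step.
  sumRec = xs:
    if xs == [ ] then 0 else builtins.head xs + sumRec (builtins.tail xs);

  # builtins.foldl' loops in C++ and evaluates the accumulator after
  # every step, so the stack depth does not grow with the list length.
  sumFold = builtins.foldl' (acc: x: acc + x) 0;
in {
  small = sumRec (builtins.genList (i: i) 100);
  big = sumFold (builtins.genList (i: i) 100000);
}
```

This gives small = 4950 and big = 4999950000.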

One very powerful but undocumented primop of this kind is builtins.genericClosure, which was initially conceived to compute the dependency closure of, say, a package. It receives a startSet (a list of attribute sets, each of which needs a unique key attribute) and an operator (a function that receives one such element and returns a list of new such elements). The operator is then repeatedly applied to the elements of the growing set, starting with startSet, until no new elements are produced (the operator is applied to each unique element only once). There’s also a nice usage example in a Nix issue (concerning its documentation) which may help you get a sense of how it works. Apart from its intended use (computing a dependency closure of sorts), it can be used for basically any recursive algorithm that produces a list in the end; for example, using it made my pure Nix UTF-8 decoder much more efficient (where “much” is an understatement).
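
A minimal sketch of the intended use, with a made-up dependency graph standing in for real packages (the names below are illustrative only). Each element carries just a key, and operator looks up that package’s direct dependencies:

```nix
let
  # A made-up dependency graph: package name -> direct dependencies.
  deps = {
    hello = [ "glibc" "gettext" ];
    gettext = [ "glibc" "libiconv" ];
    libiconv = [ "glibc" ];
    glibc = [ ];
  };

  closure = builtins.genericClosure {
    # Every element needs a `key`; elements whose key was already seen are dropped.
    startSet = [ { key = "hello"; } ];
    # Called once per unique element; returns the elements it leads to.
    operator = item: map (name: { key = name; }) deps.${item.key};
  };
in
  map (item: item.key) closure
```

The result contains hello, gettext, glibc and libiconv exactly once each: glibc is reached three times, but since its key has already been seen, it is only kept (and passed to operator) once. I wouldn’t rely on the order of the resulting list.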

It’s relatively obvious that we can in fact express any recursive algorithm in terms of builtins.genericClosure. But instead of leaving it as the proverbial exercise to the reader, we can show it for real by writing a Nix function that converts a recursive function into one that uses builtins.genericClosure. This will also constitute (semi-manual) tail call optimization for Nix, hence this post’s title.

For this we’ll have to impose the following restrictions on the function we’ll take as an input:

  1. We’ll need to inject some custom code to break up the normal recursion of the function, so we require that the self reference necessary for the recursion is realized via a fixed point we can manipulate (as opposed to Nix’s lexical scope).
  2. Also, the function needs to be tail recursive (i. e. its return value must either be a call to itself or a value that doesn’t depend on the function’s self reference); otherwise we can’t really inject anything useful.

So we may need to rewrite our recursive functions a bit, for example:

  # instead of
  facSimple = n: if n == 0 then 1 else n * facSimple (n - 1);

  # we'll use a fixed point and tail recursion
  fac = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);

# Compute 15! without any kind of optimization
lib.fix fac 1 15
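
lib.fix comes from nixpkgs. For a snippet that evaluates without nixpkgs, the same knot-tying definition can be written inline. The local fix binding below mirrors nixpkgs’ lib.fix; it feeds fac its own result as self, which works because let bindings are lazy:

```nix
let
  # Same definition as nixpkgs' lib.fix: x is defined in terms of itself,
  # which is fine as long as f doesn't force its argument right away.
  fix = f: let x = f x; in x;

  fac = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);
in
  fix fac 1 15
```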

The basic idea for the optimization is to pass a fake version of the function that merely captures and returns all information pertaining to the call.

  fakeFac = acc: n: { __tailCall = true; args = [ acc n ]; };

fac fakeFac 1 15
# => { __tailCall = true; args = [ 15 14 ]; }

fac fakeFac 1307674368000 0
# => 1307674368000

As you can see, if the function tries to recurse, it just returns an attribute set containing the arguments it wanted to pass. When we reach the base case, however, it returns normally.

This allows us to limit each call of the function in operator to a single recursion step. The next recursion step is then performed in the next invocation of operator, using the arguments we just captured. The __tailCall attribute is an imperfect way of distinguishing an ordinary return value from a captured tail call: a function that legitimately returns a set with __tailCall = true would be mistaken for a tail call.

This should already allow us to wire something up with builtins.genericClosure, but it’s not really pretty that we have to set up fakeFac manually, so we’ll take a small detour and figure out some hacks in order to do this automatically.
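
Before automating it, here is a sketch of what such manual wiring could look like for fac alone, reusing the hand-written fakeFac. The step helper and the key/outcome attribute names are my own choices; keys just count steps, since genericClosure only requires them to be unique:

```nix
let
  fac = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);

  # Hand-written stand-in for the self reference, as in the text.
  fakeFac = acc: n: { __tailCall = true; args = [ acc n ]; };

  # Perform exactly one fac step with a list of captured arguments.
  step = args: fac fakeFac (builtins.elemAt args 0) (builtins.elemAt args 1);

  steps = builtins.genericClosure {
    startSet = [ { key = 0; outcome = step [ 1 15 ]; } ];
    # Keep stepping while the outcome is a captured tail call.
    operator = { key, outcome }:
      if builtins.isAttrs outcome && outcome.__tailCall or false
      then [ { key = key + 1; outcome = step outcome.args; } ]
      else [ ];
  };
in
  # The steps form a single chain; the last one holds the real return value.
  (builtins.elemAt steps (builtins.length steps - 1)).outcome
```

operator stops producing elements as soon as an outcome is no longer a captured tail call, so the last element holds the final value.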

First of all, we’ll need to figure out how many arguments a function receives. We’ll do that by feeding it builtins.throw until it starts to explode:

  argCount = f:
    let
      # N.B. since we are only interested in whether the result of calling is
      # a function as opposed to a normal value or evaluation failure, we never
      # need to check success, as value will be false (i.e. not a function) in
      # the failure case.
      called = builtins.tryEval (
        f (builtins.throw "You should never see this error message")
      );
    in
    if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
    then 0
    else 1 + argCount called.value;

argCount (lib.fix fac)
# => 2
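
A self-contained version of argCount, with fix inlined and a few extra probes of my own (a three-argument lambda, an attribute set with __functor, and a plain integer), shows how the counting bottoms out:

```nix
let
  fix = f: let x = f x; in x;
  fac = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);

  argCount = f:
    let
      # Apply f to a poisoned argument; tryEval turns the resulting throw
      # (if the argument gets forced) into { success = false; value = false; }.
      called = builtins.tryEval (
        f (builtins.throw "You should never see this error message")
      );
    in
    if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
    then 0
    else 1 + argCount called.value;
in {
  facArgs = argCount (fix fac);
  three = argCount (a: b: c: a + b + c);
  # An attribute set with __functor can be called like a function.
  functor = argCount { __functor = self: x: x; };
  notAFunction = argCount 42;
}
```

This yields 2, 3, 1 and 0 respectively. builtins.tryEval only catches errors raised by throw (and failed assertions), which is why the probe argument is a builtins.throw and not, say, an abort.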

Then we’ll need a way to make a function that collects its n arguments into a list. I’ve called this unapply because of a) a lack of imagination on my part and b) the fact that, together with apply (which I’ll show later), it can be used to form the identity function:

  unapply =
    let
      unapply' = acc: n: f: x:
        if n == 1
        then f (acc ++ [ x ])
        else unapply' (acc ++ [ x ]) (n - 1) f;
    in
    unapply' [ ];

unapply 3 lib.id 1 2 3
# => [ 1 2 3 ]

This allows us to generate the fakeFac function we had above programmatically:

unapply (argCount (lib.fix fac)) (args: {
  __tailCall = true;
  inherit args;
})
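
To check that the generated stand-in behaves like the hand-written fakeFac, here is a self-contained sketch that puts argCount and unapply (as defined above) next to both versions. generatedFakeFac is just a name I picked for the expression above:

```nix
let
  fix = f: let x = f x; in x;
  fac = self: acc: n: if n == 0 then acc else self (n * acc) (n - 1);

  argCount = f:
    let
      called = builtins.tryEval (f (builtins.throw "never seen"));
    in
    if !(builtins.isFunction f || builtins.isFunction (f.__functor or null))
    then 0
    else 1 + argCount called.value;

  unapply =
    let
      unapply' = acc: n: f: x:
        if n == 1
        then f (acc ++ [ x ])
        else unapply' (acc ++ [ x ]) (n - 1) f;
    in
    unapply' [ ];

  # Written by hand, as earlier in the text.
  fakeFac = acc: n: { __tailCall = true; args = [ acc n ]; };

  # Derived from fac's arity instead.
  generatedFakeFac = unapply (argCount (fix fac)) (args: {
    __tailCall = true;
    inherit args;
  });
in {
  manual = fac fakeFac 1 15;
  generated = fac generatedFakeFac 1 15;
}
```

Both attributes evaluate to { __tailCall = true; args = [ 15 14 ]; }.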

apply calls a function with such a list of arguments, which we’ll have to do eventually:

  apply = f: args: builtins.foldl' (f: x: f x) f args;

apply builtins.sub [ 10 20 ]
# => -10
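
Here is the identity mentioned earlier, as a sketch: for an n-argument function f, unapply n (apply f) behaves like f, and apply (unapply n g) behaves like g on lists of length n. The roundTrip and listRoundTrip names are only for illustration:

```nix
let
  unapply =
    let
      unapply' = acc: n: f: x:
        if n == 1
        then f (acc ++ [ x ])
        else unapply' (acc ++ [ x ]) (n - 1) f;
    in
    unapply' [ ];

  apply = f: args: builtins.foldl' (f: x: f x) f args;
in {
  # Curried function -> list function -> curried function again.
  roundTrip = unapply 2 (apply builtins.sub) 10 20;
  # List function -> curried function -> list function again.
  listRoundTrip = apply (unapply 3 (xs: xs)) [ 1 2 3 ];
}
```

roundTrip is -10, just like builtins.sub 10 20, and listRoundTrip is [ 1 2 3 ].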

Now it’s time to unceremoniously paste the entire monstrosity that puts everything together. I’ve tried to comment the code a bit with explanations, but feel free to ask about anything unclear below.

  tailCallOpt = f:
    let
      argc = argCount (lib.fix f);

      # This function simulates being f for f's self reference. Instead of
      # recursing, it will just return the arguments received as a specially
      # tagged set, so the recursion step can be performed later.
      fakef = unapply argc (args: {
        __tailCall = true;
        inherit args;
      });
      # Pass fakef to f so that it'll be called instead of recursing, ensuring
      # only one recursion step is performed at a time.
      encodedf = f fakef;

      # This is the main function, implementing the “optimized” recursion
      opt = args:
        let
          steps = builtins.genericClosure {
            # This is how we encode a (tail) call: A set with final == false
            # and the list of arguments to pass to be found in args.
            startSet = [
              {
                key = "0";
                id = 0;
                final = false;
                inherit args;
              }
            ];

            operator =
              { id, final, ... }@state:
              let
                # Generate a new, unique key to make genericClosure happy
                newIds = {
                  key = toString (id + 1);
                  id = id + 1;
                };

                # Perform recursion step
                call = apply encodedf state.args;

                # If call encodes a new call, return the new encoded call,
                # otherwise signal that we're done.
                newState =
                  if builtins.isAttrs call && call.__tailCall or false
                  then newIds // {
                    final = false;
                    inherit (call) args;
                  } else newIds // {
                    final = true;
                    value = call;
                  };
              in
              if final
              then [ ] # end condition for genericClosure
              else [ newState ];
          };
        in
        # The returned list contains intermediate steps we need to ignore
        (builtins.head (builtins.filter (x: x.final) steps)).value;
    in
    # make it look like a normal function again
    unapply argc opt;

tailCallOpt fac 1 15
# => 1307674368000

Okay, so we have something that behaves just like lib.fix fac does, but does it help us beyond that? In terms of stack overflows, fac is not a good example: integers take very little space on the stack, so the result will overflow the integer range long before the stack overflows.

So let’s take a function that has something big on the stack (i. e. a thunk):

  emphasize = self: acc: before: after: n:
    if n == 0
    then before + " " + acc + after
    else self (acc + "very ") before after (n - 1);

lib.fix emphasize "" "I like Nix" "much" 2
# => "I like Nix very very much"

How does it measure up against its optimized version? Well, let’s see:

  # pass the starting accumulator already for convenience
  emphPlain = lib.fix emphasize "";
  emphOpt = tailCallOpt emphasize "";

emphPlain "This is" "cursed" 20 == emphOpt "This is" "cursed" 20
# => true 

emphPlain "This is" "cursed" 10000
# error: stack overflow (possible infinite recursion)

emphOpt "This is" "cursed" 10000
# => "This is very very very very very very very very very very very very…

So there it is, tail call optimization in Nix! To return to the initial point about practicality, however, it is pretty much useless. In most situations where this would help you, the “optimized” algorithm would probably be dealing with an attribute set or list and be slowed down incredibly by the constant copying involved in list/set operations. Additionally, any hand-rolled use of builtins.genericClosure is probably considerably faster, as it would save a lot of function applications and intermediate data structures for a large enough n. It was a fun exercise, though, and hopefully interesting to you as well.
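
For comparison, here is a sketch of what a hand-rolled builtins.genericClosure version of emphasize might look like; emphLoop and its state layout are my own and not benchmarked against tailCallOpt. It also forces the accumulator at every step by making the key depend on it via builtins.seq, since genericClosure has to evaluate each key:

```nix
let
  # Hand-rolled loop: one genericClosure element per step, with the
  # remaining count doubling as the (unique) key.
  emphLoop = before: after: count:
    let
      steps = builtins.genericClosure {
        startSet = [ { key = count; acc = ""; } ];
        operator = { key, acc }:
          let next = acc + "very "; in
          if key == 0
          then [ ]
          # genericClosure evaluates every key, so routing it through
          # builtins.seq forces `next` instead of piling up thunks.
          else [ { key = builtins.seq next (key - 1); acc = next; } ];
      };
      # The steps form a single chain, so the last element is the final one.
      last = builtins.elemAt steps (builtins.length steps - 1);
    in
    before + " " + last.acc + after;
in {
  small = emphLoop "I like Nix" "much" 2;
  bigLength = builtins.stringLength (emphLoop "This is" "cursed" 1000);
}
```

small is "I like Nix very very much", matching lib.fix emphasize. For large counts this still copies the growing string on every step, and every intermediate state stays in the list returned by genericClosure, so memory use grows quickly.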

As an additional note, it is still possible to overflow the stack despite tailCallOpt if you pick a big enough n. My theory is that the accumulator thunk can grow so big that it overflows the available stack space even without any recursion in the function itself. You can work around this by using builtins.deepSeq on your accumulator argument, and it’s even possible to make a strict version of tailCallOpt that does this for you. I haven’t toyed with this, since in practice I’d just run out of memory before the computation finished. I’m not sure whether there’s a way past that with the right GC settings.
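
A sketch of the deepSeq workaround applied to emphasize (emphasizeStrict is a hypothetical name). builtins.deepSeq returns its second argument unchanged, so the recursive call stays in tail position and the function can still be passed to tailCallOpt. The snippet below only checks the plain result with an inline fix; how far this pushes the limit for very large n is untested here:

```nix
let
  fix = f: let x = f x; in x;

  # Like emphasize, but the accumulator is fully evaluated before each
  # step, so it never becomes a long chain of nested thunks.
  emphasizeStrict = self: acc: before: after: n:
    builtins.deepSeq acc (
      if n == 0
      then before + " " + acc + after
      else self (acc + "very ") before after (n - 1)
    );
in
  fix emphasizeStrict "" "I like Nix" "much" 2
```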