How to implement tail calls in a custom VM
26-09-2019
Question
How can I implement tail calls in a custom virtual machine?
I know that I need to pop off the original function's local stack, then its arguments, then push on the new arguments. But if I pop off the function's local stack, how am I supposed to push on the new arguments? They've just been popped off the stack.
Solution
I take it for granted that we're discussing a traditional "stack-based" virtual machine here.
You pop off the current function's local stack while preserving the still-relevant parts in non-stack "registers" (the "relevant parts" being the arguments for the forthcoming recursive tail call). Then, once all of the function's local stack and arguments have been cleaned up, you push the arguments for the recursive call. For example, suppose the function you're optimizing is something like:
def aux(n, tot):
    if n <= 1: return tot
    return aux(n-1, tot * n)
which without optimization might produce byte-code symbolically like:
AUX:   LOAD_VAR N
       LOAD_CONST 1
       COMPARE
       JUMPIF_GT LAB
       LOAD_VAR TOT
       RETURN_VAL
LAB:   LOAD_VAR N
       LOAD_CONST 1
       SUBTRACT
       LOAD_VAR TOT
       LOAD_VAR N
       MULTIPLY
       CALL_FUN2 AUX
       RETURN_VAL
where CALL_FUN2 means "call a function with two arguments". With the optimization, the final CALL_FUN2 AUX / RETURN_VAL pair could become something like:
POP_KEEP 2
POP_DISCARD 2
PUSH_KEPT 2
JUMP AUX
Of course I'm making up my symbolic bytecodes as I go along, but I hope the intent is clear: POP_DISCARD n is the normal pop that just discards the top n entries from the stack, while POP_KEEP n is a variant that keeps them "somewhere" (e.g. in an auxiliary stack accessible only to the VM's own machinery, not directly to the application; storage of this kind is sometimes called "a register" when discussing VM implementation). The matching PUSH_KEPT n empties the "registers" back onto the VM's normal stack.
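To make the sequence concrete, here is a minimal sketch of an interpreter for the made-up bytecodes above. Everything about it is an assumption for illustration: the `run` function, the `VAR_SLOT` layout (arguments at the bottom of a single frame), the tuple encoding of instructions, and COMPARE pushing -1/0/1 are hypothetical choices, not part of any real VM.

```python
# Hypothetical mini stack VM; opcode names follow the symbolic bytecode above.
VAR_SLOT = {"N": 0, "TOT": 1}  # assumed layout: arguments sit at the frame base

AUX = [
    ("LOAD_VAR", "N"),       # 0  AUX:
    ("LOAD_CONST", 1),
    ("COMPARE", None),
    ("JUMPIF_GT", "LAB"),
    ("LOAD_VAR", "TOT"),
    ("RETURN_VAL", None),
    ("LOAD_VAR", "N"),       # 6  LAB:
    ("LOAD_CONST", 1),
    ("SUBTRACT", None),
    ("LOAD_VAR", "TOT"),
    ("LOAD_VAR", "N"),
    ("MULTIPLY", None),
    ("POP_KEEP", 2),         # stash the two new arguments in VM "registers"
    ("POP_DISCARD", 2),      # drop the old arguments
    ("PUSH_KEPT", 2),        # the new arguments become the frame
    ("JUMP", "AUX"),         # re-enter the function without a call
]
LABELS = {"AUX": 0, "LAB": 6}


def run(code, labels, args):
    """Execute `code` with `args` as the initial frame.

    Returns (result, peak stack depth) so the constant depth is visible.
    """
    stack = list(args)
    kept = []            # VM-private storage, invisible to the program
    pc = 0
    peak = len(stack)
    while True:
        op, arg = code[pc]
        pc += 1
        if op == "LOAD_VAR":
            stack.append(stack[VAR_SLOT[arg]])
        elif op == "LOAD_CONST":
            stack.append(arg)
        elif op == "COMPARE":
            b, a = stack.pop(), stack.pop()
            stack.append((a > b) - (a < b))   # -1, 0 or 1
        elif op == "JUMPIF_GT":
            if stack.pop() > 0:
                pc = labels[arg]
        elif op == "SUBTRACT":
            b, a = stack.pop(), stack.pop()
            stack.append(a - b)
        elif op == "MULTIPLY":
            b, a = stack.pop(), stack.pop()
            stack.append(a * b)
        elif op == "POP_KEEP":
            kept = stack[-arg:]
            del stack[-arg:]
        elif op == "POP_DISCARD":
            del stack[-arg:]
        elif op == "PUSH_KEPT":
            stack.extend(kept[-arg:])
            del kept[-arg:]
        elif op == "JUMP":
            pc = labels[arg]
        elif op == "RETURN_VAL":
            return stack.pop(), peak
        else:
            raise ValueError(f"unknown opcode {op!r}")
        peak = max(peak, len(stack))
```

In this sketch, `run(AUX, LABELS, (5, 1))` yields a result of 120, and the peak stack depth does not grow with `n`, because each "call" replaces the frame instead of stacking a new one.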
OTHER TIPS
I think you're looking at this the wrong way. Instead of popping the old variables off the stack and then pushing the new ones, simply reassign the ones already there (carefully). This is roughly the same optimization that would happen if you rewrote the code to be the equivalent iterative algorithm.
For this code:
int fact(int x, int total=1) {
    if (x == 1)
        return total;
    return fact(x-1, total*x);
}
the resulting "assembly" would be:
fact:
    jmpne x, 1, fact_cont   # if x != 1, jump to multiply
    retrn total             # return total
fact_cont:                  # update variables for "recursion"
    mul total, x, total     # total = total * x
    sub x, 1, x             # x = x - 1
    jmp fact                # "recurse"
There's no need to pop or push anything on the stack, merely reassign.
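The "carefully" matters because each new argument may depend on the old value of another. A small sketch of the hazard follows; the function names are hypothetical, and the loop tests `x > 1` rather than `x != 1` so that inputs below 1 terminate. The safe version computes every new argument from the old values before overwriting any slot, while the second version overwrites `x` too early.

```python
def fact_in_place(x, total=1):
    """Tail call as in-place update: compute all new arguments, then overwrite."""
    while x > 1:
        new_x, new_total = x - 1, total * x   # both read the OLD x
        x, total = new_x, new_total           # now overwrite the argument slots
    return total


def fact_wrong_order(x, total=1):
    """Same loop, but x is overwritten before total has read it."""
    while x > 1:
        x = x - 1
        total = total * x   # bug: uses the already-decremented x
    return total


print(fact_in_place(5))     # 120
print(fact_wrong_order(5))  # 24, because the factor 5 is never multiplied in
```

The "assembly" above avoids the problem by ordering the instructions so that `total` is updated before `x`; a compiler that cannot find such an order for a given call would need a temporary.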
This can be optimized further by placing the exit test after the update, so that each iteration executes one jump instead of two. Execution still begins at the fact label:
fact_cont:                  # update variables for "recursion"
    mul total, x, total     # total = total * x
    sub x, 1, x             # x = x - 1
fact:
    jmpne x, 1, fact_cont   # if x != 1, jump to multiply
    retrn total             # return total
Looking again, this "assembly" more closely reflects the following C++, which avoids the recursive calls entirely (the loop tests x > 1 rather than x != 1, which behaves the same for x >= 1):
int fact(int x, int total=1) {
    for( ; x>1; --x)
        total*=x;
    return total;
}