
Tail call optimization in Scala

2015-10-02 15:29
Recursion provides a clean way to solve some algorithmic problems. One drawback, though, is the memory consumed by the recursion call stack, which matters especially where memory is limited, as in mobile applications.

One alternative is to use a while loop instead of recursion, but a loop is usually harder to write than the equivalent recursion. So most of the time we first come up with a recursive solution and then try to convert it into a while-loop style algorithm. Some languages can do this conversion for you: Scala has built-in support for compiling a "special" kind of recursion into while-loop style code. "Special" here means that the recursive call must always be in tail position, that is, the caller does nothing other than return the value of the recursive call.
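To make "tail position" concrete, here is a minimal sketch using a hypothetical sum-of-1-to-n function (the names `sumTo` and `sumToTail` are illustrative, not from the original). In `sumTo` the addition happens after the recursive call returns, so the call is not in tail position; in `sumToTail` the running total travels in an accumulator and the recursive call is the very last action.

```scala
object TailPosition {
  // Not tail-recursive: after sumTo(n - 1) returns we still have to add n,
  // so each call's stack frame must stay alive until the recursion bottoms out.
  def sumTo(n: Int): Long =
    if (n <= 0) 0L
    else n + sumTo(n - 1)

  // Tail-recursive: the recursive call is the last thing the method does,
  // and the partial sum is carried along in the accumulator `acc`.
  // Because this method lives in an object (so it cannot be overridden),
  // scalac compiles the self-call into a loop.
  def sumToTail(n: Int, acc: Long = 0L): Long =
    if (n <= 0) acc
    else sumToTail(n - 1, acc + n)

  def main(args: Array[String]): Unit = {
    println(sumTo(100))
    // A depth of one million is fine here because no stack frames accumulate.
    println(sumToTail(1000000))
  }
}
```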

Let's take the Fibonacci number problem as an example. A typical solution looks like this:

def fibonacci(x: Int): Int = {
  if (x <= 2) x - 1                          // fibonacci(1) = 0, fibonacci(2) = 1
  else fibonacci(x - 1) + fibonacci(x - 2)   // (1): the sum is computed after both calls return
}
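To see what this definition computes and how much stack it uses, the sketch below wraps an instrumented copy of it in a hypothetical `NaiveFibStats` object that counts calls and tracks the maximum nesting depth. With the `x - 1` base case the sequence starts 0, 1, 1, 2, 3, ..., and the depth grows linearly with `x` (one pending frame per level), while the number of calls grows much faster.

```scala
object NaiveFibStats {
  // Counters used only for this illustration.
  private var calls = 0
  private var depth = 0
  private var maxDepth = 0

  // Same logic as the naive fibonacci, plus bookkeeping.
  private def fib(x: Int): Int = {
    calls += 1
    depth += 1
    maxDepth = math.max(maxDepth, depth)
    val result =
      if (x <= 2) x - 1
      else fib(x - 1) + fib(x - 2) // not a tail call: the addition is pending
    depth -= 1
    result
  }

  // Returns (value, total number of calls, maximum call depth) for fib(x).
  def run(x: Int): (Int, Int, Int) = {
    calls = 0; depth = 0; maxDepth = 0
    val value = fib(x)
    (value, calls, maxDepth)
  }

  def main(args: Array[String]): Unit = {
    println((1 to 10).map(x => run(x)._1).mkString(", "))
    for (x <- List(10, 20, 25)) {
      val (v, c, d) = run(x)
      println(s"fibonacci($x) = $v, calls = $c, max depth = $d")
    }
  }
}
```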


In this naive version, the recursive call at position (1) is not in tail position: after the two recursive calls return, the caller still has to add their results. Every pending call therefore keeps its own stack frame to hold that intermediate state, so the call stack grows with x. If we write the function in a tail-recursive way instead, things are different.

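// acc1 and acc2 hold two consecutive terms; call as fibonacci(x, 0, 1) for x >= 2
def fibonacci(x: Int, acc1: Int, acc2: Int): Int = {
  if (x <= 2) acc2
  else fibonacci(x - 1, acc2, acc1 + acc2)  // (2): a tail call, its result is returned as-is
}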
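The original leaves the initial accumulator values implicit. A minimal sketch, assuming we want to match the naive version (fibonacci(1) = 0, fibonacci(2) = 1), seeds them with 0 and 1 through a hypothetical wrapper `fib`. It also adds the standard `@tailrec` annotation, which asks the compiler to verify that the recursive call really is in tail position and to fail compilation if it is not.

```scala
import scala.annotation.tailrec

object TailFib {
  // Same tail-recursive function as above, now checked by the compiler.
  @tailrec
  def fibonacci(x: Int, acc1: Int, acc2: Int): Int = {
    if (x <= 2) acc2
    else fibonacci(x - 1, acc2, acc1 + acc2)
  }

  // Hypothetical convenience wrapper: seeds the accumulators with the first
  // two terms of the naive version. Valid for x >= 2.
  def fib(x: Int): Int = fibonacci(x, 0, 1)

  def main(args: Array[String]): Unit =
    println((2 to 10).map(fib).mkString(", "))
}
```

Keeping the accumulators behind a wrapper means callers cannot pass inconsistent starting values by accident.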


Notice that at position (2) the caller does nothing but return the result of the recursive call. In this case the Scala compiler turns the tail-recursive function into a while-style loop, which is exactly what we want.
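As a rough picture of what that compilation achieves, here is a hand-written while loop with the same behaviour as the tail-recursive `fibonacci(x, acc1, acc2)`. This is an illustrative equivalent, not the literal code scalac generates: each loop iteration plays the role of one recursive call, reassigning the parameters instead of pushing a new stack frame.

```scala
object LoopFib {
  // Loop version of fibonacci(x, acc1, acc2); the parameters become vars.
  def fibonacciLoop(x0: Int, acc1Init: Int, acc2Init: Int): Int = {
    var x = x0
    var acc1 = acc1Init
    var acc2 = acc2Init
    // Same condition as the recursive case: keep going while x > 2.
    while (x > 2) {
      val next = acc1 + acc2 // the new acc2 of the "recursive call"
      acc1 = acc2
      acc2 = next
      x -= 1
    }
    acc2 // same value the base case (x <= 2) returns
  }

  def main(args: Array[String]): Unit =
    println(fibonacciLoop(10, 0, 1))
}
```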
Now consider a "super Fibonacci" sequence, where f(x) = f(x-1) + f(x-2) + f(x-3). Using the same approach, the code looks like this:

// acc1, acc2 and acc3 hold three consecutive terms, seeded with f(1), f(2), f(3)
def sfibonacci(x: Int, acc1: Int, acc2: Int, acc3: Int): Int = {
  if (x <= 3) acc3
  else sfibonacci(x - 1, acc2, acc3, acc1 + acc2 + acc3)
}
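The original does not fix the first three values of the super Fibonacci sequence. The sketch below assumes f(1) = 0, f(2) = 0, f(3) = 1 (a common "tribonacci" convention, chosen here purely for illustration) and checks the tail-recursive `sfibonacci` against a naive three-way recursion. `naive` and `superFib` are hypothetical helper names.

```scala
import scala.annotation.tailrec

object SuperFib {
  // Naive definition, for comparison only.
  // ASSUMPTION: base values f(1) = 0, f(2) = 0, f(3) = 1.
  def naive(x: Int): Int =
    if (x <= 2) 0
    else if (x == 3) 1
    else naive(x - 1) + naive(x - 2) + naive(x - 3)

  // Same tail-recursive function as in the text.
  @tailrec
  def sfibonacci(x: Int, acc1: Int, acc2: Int, acc3: Int): Int = {
    if (x <= 3) acc3
    else sfibonacci(x - 1, acc2, acc3, acc1 + acc2 + acc3)
  }

  // Seeds acc1..acc3 with the assumed f(1), f(2), f(3). Valid for x >= 3.
  def superFib(x: Int): Int = sfibonacci(x, 0, 0, 1)

  def main(args: Array[String]): Unit =
    println((3 to 12).map(superFib).mkString(", "))
}
```

Only the seeds depend on the chosen base values; the tail-recursive step itself is unchanged.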


Writing your recursion in tail-call style is often a simple but effective optimization, since it avoids the memory cost of a deep call stack. Note, however, that in Scala this only works when the function calls itself directly. The following two cases will not be optimized:

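// case 1: the recursive call goes through the function value `a` (a.apply),
// not through the method itself, so it is not turned into a loop
val a = fibonacci _
def fibonacci(x: Int, acc1: Int, acc2: Int): Int = {
  if (x <= 2) acc2
  else a(x - 1, acc2, acc1 + acc2)
}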

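// case 2: fibonacci2 calls fun, which calls fibonacci2 (mutual recursion);
// only direct self-calls are turned into a loop
def fibonacci2(x: Int, acc1: Int, acc2: Int): Int = {
  if (x <= 2) acc2
  else fun(x - 1, acc2, acc1 + acc2)
}
def fun(x: Int, acc1: Int, acc2: Int): Int = {
  fibonacci2(x, acc1, acc2)
}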
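Putting `@tailrec` on either version above would turn this silent non-optimization into a compile error, since the compiler finds no direct self-call in tail position. For mutual recursion like case 2, one possible workaround (a different technique from the compiler's loop rewrite) is the standard library's trampoline in `scala.util.control.TailCalls`. In this sketch each function returns a `TailRec` value describing the next step instead of calling the other function directly, and `.result` evaluates those steps in a loop, trading stack frames for heap allocations.

```scala
import scala.util.control.TailCalls._

object MutualFib {
  // Case 2 rewritten with TailCalls: done(...) ends the computation,
  // tailcall(...) defers the next call instead of making it on the stack.
  def fibonacci2(x: Int, acc1: Int, acc2: Int): TailRec[Int] =
    if (x <= 2) done(acc2)
    else tailcall(fun(x - 1, acc2, acc1 + acc2))

  def fun(x: Int, acc1: Int, acc2: Int): TailRec[Int] =
    tailcall(fibonacci2(x, acc1, acc2))

  def main(args: Array[String]): Unit =
    println(fibonacci2(10, 0, 1).result)
}
```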