Recursion¶
In C++, a function that calls itself, either directly or indirectly, is said to be recursive. For example:
void a() {
cout << "Hello!\n";
a();
}
In theory, a() runs forever because it keeps calling itself. In practice, it will probably crash with an “out of memory” run-time error (a stack overflow), because every call to a() uses a little bit of memory to store the address where the program should resume running after that call finishes.
Note
It is possible that a()
could run forever, depending on whether
or not your C++ compiler uses a performance optimization trick called
tail-call elimination.
Basically, a smart compiler can replace the call to a()
with a loop,
and then a()
really would run forever (or until you turn off your
computer).
The compiler we’re using, g++, doesn’t do tail-call elimination by default. But if you run g++
with the -O2
flag, it will try to. With -O2
, a()
really does
run forever for me!
The function b()
is a variation of a
:
void b(int n) {
cout << n << ": hello!\n";
b(n + 1); // notice n + 1 is passed as the parameter
}
This function prints how many times it’s called, which lets us see how many
times it can be executed until all the computer’s memory is used up. Notice
that if the second line of the function were b(n), then the value of n
wouldn’t change.
Having a function run until it crashes isn’t very useful. In this version, we
stop the function when n
is 10 or bigger:
void c(int n) {
if (n >= 10) {
// all done: do nothing
} else {
cout << n << ": hello!\n";
c(n + 1);
}
}
For example, calling c(4)
prints this:
4: hello!
5: hello!
6: hello!
7: hello!
8: hello!
9: hello!
Calling c(10)
prints nothing:
(nothing printed)
Function c is useful, but stopping at 10 is arbitrary and inflexible. A
better way to write it is like this:
void d(int n) {
if (n <= 0) { // n is decreasing, so check when it gets to 0 or lower
// all done: do nothing
} else {
cout << n << ": hello!\n";
d(n - 1); // subtract 1 instead of add 1
}
}
For example, calling d(5)
prints this:
5: hello!
4: hello!
3: hello!
2: hello!
1: hello!
This version of the function counts down from n
to 1, where n
is any
int
. Notice that if you call something like d(-3)
, nothing is printed.
With another small change, we can make a function that prints “hello” n times:
void say_hello(int n) {
if (n > 0) {
cout << n << ": hello!\n";
say_hello(n - 1);
}
}
Or more generally:
void say(const string& msg, int n) {
if (n > 0) {
cout << n << ": " << msg << "\n";
say(msg, n - 1);
}
}
This prints any string exactly n
times. Notice that there is no else-block
here: if n > 0
is not true, then the flow of control skips to after the
if-block, and so does nothing.
As another example, let’s write a recursive function that prints the numbers
from n down to 1 on the screen. For example, count_down(5)
prints
this:
5
4
3
2
1
It’s easy to modify the say_hello
function to do this:
// prints the numbers from n down to 1
void count_down(int n) {
if (n > 0) {
cout << n << "\n";
count_down(n - 1);
}
}
Now consider the opposite problem: write a recursive function that prints the
numbers from 1 up to n. For example, count_up(5)
prints this:
1
2
3
4
5
This is a bit trickier than counting downwards.
Here is a solution:
// prints the numbers from 1 up to n
void count_up(int n) {
if (n > 0) {
count_up(n - 1);
cout << n << "\n"; // printing comes after the recursive call
}
}
The essential difference between count_up
and count_down
is when the
recursive call is made. In count_down
, it’s made after printing, and in
count_up
it’s before. That one little change makes a big difference!
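To see both orderings in a single function, here is a small sketch that prints once before the recursive call and once after it. The name down_and_up and the extra out parameter (so the output can go to any stream, not just cout) are assumptions made for this example.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Prints n before the recursive call (on the way "down") and again after
// the recursive call returns (on the way back "up").
// `out` is an added parameter for this sketch; pass std::cout to print
// on the screen.
void down_and_up(int n, std::ostream& out) {
    if (n > 0) {
        out << "down " << n << "\n";   // runs before the deeper calls start
        down_and_up(n - 1, out);
        out << "up " << n << "\n";     // runs after the deeper calls finish
    }
}
```

Calling down_and_up(3, std::cout) prints the "down" lines in the order 3, 2, 1 and then the "up" lines in the order 1, 2, 3, which is the count_down order followed by the count_up order.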
Of course, in practice, for-loops or while-loops would be the best way to implement any of these functions. But our goal here is to understand recursion, and so it is best to start with simple — if impractical — functions.
Recurrence Relations: Recursive Functions in Mathematics¶
Recursive functions are commonly used in mathematics. For example, consider the function \(S(n)\) defined as follows:
\[
S(0) = 0, \qquad S(n) = n + S(n - 1) \quad \text{for } n > 0
\]
\(S(0) = 0\) is called the base case, and \(S(n) = n + S(n - 1)\) is the recursive case. Any useful recursive function needs at least one base case and one recursive case.
Using these two cases — which we’ll call rules — we can calculate \(S(n)\) for any non-negative integer \(n\).
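\(S(0)\) is easy: it is simply 0, as defined by the base case. \(S(1)\) is a little more work:
\[
S(1) = 1 + S(0) = 1 + 0 = 1
\]
For \(S(2)\), we apply the recursive rule a couple of times:
\[
S(2) = 2 + S(1) = 2 + (1 + S(0)) = 2 + (1 + 0) = 3
\]
And \(S(3)\):
\[
S(3) = 3 + S(2) = 3 + (2 + S(1)) = 3 + (2 + (1 + S(0))) = 3 + (2 + (1 + 0)) = 6
\]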
You can see the pattern: \(S(n) = n + ((n-1) + ((n-2) + ... + (2 + (1 + 0))))\). Since addition can be done in any order, this is the same as \(S(n) = 1 + 2 + ... + n\).
We can implement \(S(n)\) directly like this:
// returns 1 + 2 + ... + n (assuming n >= 0)
int S(int n) {
if (n == 0) { // base case
return 0;
} else {
return n + S(n - 1); // recursive case
}
}
Notice how similar this is to the mathematical definition of \(S(n)\). Indeed, when writing a recursive function it is often helpful to work out the cases on paper, and then translate them into code.
The base case is essential in a recursive function because it determines when the recursion stops. It plays the same role as the condition in a for-loop or while-loop.
Tracing the calls and returns made by a recursive function is often useful.
For example, here’s the trace of the call S(5)
:
S(5) entered ...
  S(4) entered ...
    S(3) entered ...
      S(2) entered ...
        S(1) entered ...
          S(0) entered ...
          ... S(0) exited
        ... S(1) exited
      ... S(2) exited
    ... S(3) exited
  ... S(4) exited
... S(5) exited
You can see here that S
gets called exactly 6 times, and that it exits
exactly 6 times. The indentation shows which calls go with which exits; notice
that S(5)
is the first to be called but the last to exit.
See the end of this note for how to use cmpt_trace.h to generate these tracings.
Fibonacci Numbers¶
A recursive definition can have more than one base case. For example, the Fibonacci numbers are 1, 1, 2, 3, 5, 8, 13, …. The rule is that the first two numbers of the sequence are 1 and 1, and then after that each number is the sum of the two before it. More mathematically, we can define them like this:
\[
f(1) = 1, \qquad f(2) = 1, \qquad f(n) = f(n - 1) + f(n - 2) \quad \text{for } n > 2
\]
Converting this definition to C++ is not too hard:
// Returns the nth Fibonacci number (assuming n > 0)
int f(int n) {
if (n == 1) { // base case
return 1;
} else if (n == 2) { // base case
return 1;
} else {
return f(n-1) + f(n-2); // recursive case
}
}
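Let’s try calculating \(f(5)\) by hand:
\[
\begin{aligned}
f(5) &= f(4) + f(3) \\
     &= (f(3) + f(2)) + (f(2) + f(1)) \\
     &= ((f(2) + f(1)) + f(2)) + (f(2) + f(1)) \\
     &= ((1 + 1) + 1) + (1 + 1) \\
     &= 5
\end{aligned}
\]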
This is more work than calculating \(S(5)\) because there are two
recursive calls to \(f\) in the recursive case. For large values of n,
those two calls could cause 2 more calls each, i.e. 4 more calls. Then those 4
calls could cause 2 more calls each, i.e. 8 calls. For large values of n, the
number of recursive calls increases exponentially, which means that f
will
take a long time to calculate all but the smallest values of n
.
This is one of the problems with recursive functions: they can make a lot of function calls, and that eats up a lot of time and memory.
Tracing f
shows a more elaborate pattern of entry/exit messages:
f(5) entered ...
  f(4) entered ...
    f(3) entered ...
      f(2) entered ...
      exited f(2)
      f(1) entered ...
      exited f(1)
    exited f(3)
    f(2) entered ...
    exited f(2)
  exited f(4)
  f(3) entered ...
    f(2) entered ...
    exited f(2)
    f(1) entered ...
    exited f(1)
  exited f(3)
exited f(5)
There are 9 calls to f
here, many of them dumbly re-calculating values
that have already been calculated.
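One way to see how quickly the number of calls grows is to count them. The sketch below uses the same recursive rule as f, plus a counter passed by reference; the names f_counted and count_calls are made up for this example.

```cpp
// Same recursive definition as f, plus a counter (added for this sketch)
// that records how many times the function is entered.
int f_counted(int n, long long& calls) {
    ++calls;
    if (n == 1 || n == 2) {   // base cases
        return 1;
    }
    return f_counted(n - 1, calls) + f_counted(n - 2, calls);  // recursive case
}

// Returns the number of calls made while computing the nth Fibonacci
// number recursively (assuming n > 0).
long long count_calls(int n) {
    long long calls = 0;
    f_counted(n, calls);
    return calls;
}
```

count_calls(5) returns 9, matching the trace above, while count_calls(20) already returns 13529.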
A non-recursive function for computing Fibonacci numbers is much faster and uses much less memory:
// Returns nth Fibonacci number (assuming n > 0)
int f2(int n) {
if (n == 1 || n == 2) {
return 1;
} else {
int a = 1;
int b = 1;
int c = 0;
for (int i = 2; i < n; ++i) {
c = a + b;
a = b;
b = c;
}
return c;
}
}
A disadvantage of f2
is that it’s harder to understand. If you were given
just f2
with no explanation of what it’s about, it might take a minute or
two to realize that it computes Fibonacci numbers.
Note
The nth Fibonacci number \(f(n)\) can also be directly calculated using the non-recursive formula
\[
f(n) = \frac{\phi^n - \psi^n}{\sqrt 5}
\]
where \(\phi = \frac{1 + \sqrt 5}{2}\) and \(\psi = \frac{1 - \sqrt 5}{2}\).
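As a rough sketch, the closed-form formula \(f(n) = (\phi^n - \psi^n) / \sqrt 5\) can be evaluated with floating-point arithmetic and then rounded. The function name f_closed_form is an assumption for this example, and because a double has limited precision, the result should only be trusted for fairly small n.

```cpp
#include <cmath>

// Computes the nth Fibonacci number (assuming n > 0) using the closed-form
// formula. std::round corrects the small floating-point error; this is only
// reliable while the answer is small enough to be represented accurately
// by a double and to fit in an int.
int f_closed_form(int n) {
    const double sqrt5 = std::sqrt(5.0);
    const double phi = (1.0 + sqrt5) / 2.0;
    const double psi = (1.0 - sqrt5) / 2.0;
    const double value = (std::pow(phi, n) - std::pow(psi, n)) / sqrt5;
    return static_cast<int>(std::round(value));
}
```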
Recursion on Vectors¶
Suppose we want to sum the numbers in a vector. We can do that recursively as follows:
- Base case: the empty vector has sum 0.
- Recursive case: the sum of all the elements in v is v[0] + sum(rest(v)); the function rest(v) returns a copy of the original vector with its first element removed.
This definition is precise enough that we can trace examples by hand. For instance:
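As an illustration, here is a hand trace for the vector {3, 1, 4} (the values are an arbitrary example chosen here), with rest(v) written out explicitly at each step:

\[
\begin{aligned}
\text{sum}(\{3, 1, 4\}) &= 3 + \text{sum}(\{1, 4\}) \\
                        &= 3 + (1 + \text{sum}(\{4\})) \\
                        &= 3 + (1 + (4 + \text{sum}(\{\}))) \\
                        &= 3 + (1 + (4 + 0)) \\
                        &= 8
\end{aligned}
\]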
To implement this in C++, we could write the rest
function like this:
// Returns a new vector w of size v.size() - 1 such that
// w[0] == v[1], w[1] == v[2], ..., w[v.size() - 2] == v[v.size() - 1].
// In other words, it returns a copy of v with the first element
// removed.
vector<int> rest(const vector<int>& v) {
vector<int> result;
for (int i = 1; i < v.size(); ++i) { // i starts at 1
result.push_back(v[i]);
}
return result;
}
Now we can write sum
as follows:
int sum1(const vector<int>& v) {
if (v.empty()) { // base case
return 0;
} else { // recursive case
return v[0] + sum1(rest(v));
}
}
This works! Unfortunately, the rest
function is extremely inefficient:
for every call to sum1
we end up making a new copy of almost the entire
passed-in vector.
A more efficient approach is to simulate rest
by re-writing sum1
to
accept begin and end parameters specifying the range of values we want
summed. Then we can efficiently access any sub-vector:
// returns v[begin] + v[begin + 1] + ... + v[end - 1]
int sum2(const vector<int>& v, int begin, int end) {
if (begin >= end) {
return 0;
} else {
return v[begin] + sum2(v, begin + 1, end);
}
}
// returns the sum of all the elements in v
int sum2(const vector<int>& v) {
return sum2(v, 0, v.size());
}
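Because sum2 takes explicit begin and end parameters, it can sum part of a vector as easily as the whole thing. The helper sum_without_ends below is hypothetical (it is not part of these notes) and is only meant to show a call with a range other than 0 to v.size().

```cpp
#include <vector>

// returns v[begin] + v[begin + 1] + ... + v[end - 1]
int sum2(const std::vector<int>& v, int begin, int end) {
    if (begin >= end) {
        return 0;
    } else {
        return v[begin] + sum2(v, begin + 1, end);
    }
}

// Hypothetical helper: sums every element except the first and the last.
// For vectors with fewer than 3 elements the range is empty, so the
// result is 0.
int sum_without_ends(const std::vector<int>& v) {
    int n = static_cast<int>(v.size());  // convert once to avoid unsigned arithmetic
    return sum2(v, 1, n - 1);
}
```

For example, sum_without_ends applied to {4, 1, 2, 9} adds only the 1 and the 2.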
Adding extra parameters in this way is a standard trick when writing recursive
functions. Notice that we don’t even need to include end
in this case,
because it never changes. So we could have just written this:
// returns v[begin] + v[begin + 1] + ... + v[v.size() - 1]
int sum3(const vector<int>& v, int begin) {
if (begin >= v.size()) {
return 0;
} else {
return v[begin] + sum3(v, begin + 1);
}
}
// returns the sum of all the elements in v
int sum3(const vector<int>& v) {
return sum3(v, 0);
}
Here’s another example. Suppose we want a recursive function that returns
true
if all the numbers in a vector are even, and false
otherwise:
// Pre-condition:
// all ints in v are >= 0
// Post-condition:
// If v[0], v[1], ... v[n-1] are all even (n is v's size),
// true is returned. Otherwise, false is returned.
// If v is empty, true is returned
bool all_even(const vector<int>& v) {
// ...
}
The recursive idea for implementing this function is the same as for the
sum
function: we check if the first number is even, and then recursively
call all_even
to check that the rest of the numbers are even. As with
sum
, we will use a helper function with an extra parameter to keep track
of the sub-vector of v
that is being processed:
// Pre-condition:
// begin >= 0
// all ints in v are >= 0
// Post-condition:
// returns true if v[begin], v[begin+1], ... v[n-1] are all even,
// where n is the size of v; false otherwise
bool all_even(const vector<int>& v, int begin) {
if (begin >= v.size()) {
return true;
} else if (v[begin] % 2 == 0) {
return all_even(v, begin + 1);
} else {
return false;
}
}
The first condition of the if-statement checks if the sub-vector being
processed is empty. If begin
is equal to the size of the vector, or is
greater than the size, we consider that to be an empty vector, and so return
true.
The next condition checks if the first element of the sub-vector is even.
Since begin marks the start of the sub-vector, v[begin] is the first element
of the sub-vector (not v[0]!). If the first element is indeed even, then
all_even is called on the rest of the vector, i.e. from location begin + 1
onwards.
Finally, the else part of the if-statement occurs just when v[begin] is
odd. In that case, we know that the entire vector cannot be all even, and so
false is returned immediately.
For convenience, we provide a function that doesn’t require a starting index:
bool all_even(const vector<int>& v) {
return all_even(v, 0);
}
Wrapper functions like this, which hide some of the arguments, are quite common when writing recursive functions. Testing if an entire vector is all even is probably the most common case, and so it is helpful to make this common case easier, and less error-prone, to use.
Finally, note how we specified the behaviour of the function using a pre-condition and a post-condition. A pre-condition for a function states exactly what must be true before the function is called. If the pre-condition is not true, then the function might not work properly. It’s the responsibility of the code that calls the function to ensure the pre-condition is true. The post-condition states what will be true after the function finishes (assuming the pre-condition was true when it was called).
Together, pre-conditions and post-conditions have proven to be a very good way to precisely define the behaviour of functions, and so we will use them more and more.
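One possible way to catch a violated pre-condition early is to check it with assert from the standard header cassert. The following is only a sketch: the name all_even_checked is made up for this example, and assert checks are switched off when the program is compiled with NDEBUG defined.

```cpp
#include <cassert>
#include <vector>

// Hypothetical variant of all_even that checks its pre-conditions as it goes.
// Pre-condition:
//    begin >= 0
//    all ints in v are >= 0
// Post-condition:
//    returns true if v[begin], v[begin+1], ... v[n-1] are all even,
//    where n is the size of v; false otherwise
bool all_even_checked(const std::vector<int>& v, int begin) {
    assert(begin >= 0);                        // pre-condition on begin
    if (begin >= static_cast<int>(v.size())) {
        return true;                           // empty sub-vector
    }
    assert(v[begin] >= 0);                     // only checks elements that are actually visited
    if (v[begin] % 2 == 0) {
        return all_even_checked(v, begin + 1);
    }
    return false;
}
```

Note that this version only checks the elements it actually examines, so a negative number after the first odd number would go unnoticed.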
Why Recursion?¶
Many students wonder why we teach recursion. A lot of professional programmers would have a hard time pointing to even a single example of where they have used recursion outside of school. In practice, iteration (i.e. loops) is far more common than recursion. However, in theory, recursion is one of the most important ideas in all of computer science.
Some of the benefits to learning recursion are:
- For a few kinds of algorithms, such as parsers and the (important!) sorting algorithms quicksort and mergesort, recursion is the most common implementation method. Non-recursive versions of these algorithms are usually harder to understand and implement.
- There are some programming languages, such as Haskell and Erlang, that have no loops! You have no choice but to use recursion, or functions that are based on recursion.
- It is often easier to reason about a recursive function than an iterative one. Recursive functions are often mathematical definitions in disguise, and so you may be able to use that mathematics to help better understand your function’s correctness, performance, or memory usage.
- Recursive functions often result in source code that is shorter, simpler, and more elegant than non-recursive functions. This can make your programs more readable, and less likely to have bugs (bugs love to hide in hard-to-read code).
- Recursion plays a fundamental role in theoretical computer science. Recursive functions can be used as the basis for all computation, e.g. any loop can be translated into a recursive function that does the same thing.
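To illustrate the last point, here is a sketch of one common way to translate a loop into a recursive function: each loop variable becomes a parameter, and each iteration becomes one call. The function names sum_loop and sum_rec are made up for this example.

```cpp
// Iterative version: returns 1 + 2 + ... + n using a while-loop.
int sum_loop(int n) {
    int total = 0;
    int i = 1;
    while (i <= n) {
        total += i;
        ++i;
    }
    return total;
}

// Recursive translation of the loop above. The loop variables i and total
// become parameters; the loop condition becomes the base-case test; and
// one pass through the loop body becomes one recursive call.
int sum_rec(int i, int n, int total) {
    if (i > n) {                         // loop condition is false: stop
        return total;
    }
    return sum_rec(i + 1, n, total + i); // one "iteration"
}
```

The call sum_rec(1, n, 0) starts with the same initial values the loop uses, and returns the same result as sum_loop(n).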
In practice, recursion is probably best thought of as one of many tools a programmer can use to solve programming problems. Use it when it makes sense, and avoid it when it doesn’t.
Extra: Recursive Linear Search¶
Linear search is an algorithm that takes as input a vector v
and a target
element x
and returns either the index location of x
in v
, or -1
if x
is not in v
. If x
occurs more than once, then we’ll return
the location of the first x
.
Using begin and end values, we specify the range we want to search like this:
// Returns i such that v[i] == x and begin <= i < end;
// otherwise returns -1 if x is not in v.
int linear_search(int x, const vector<int>& v, int begin, int end)
A recursive implementation of linear_search
goes like this:
- Base case 1: if the range being searched is empty (begin >= end), return -1 (x not found)
- Base case 2: if v[begin] == x, return begin (x found at location begin)
- Recursive case: return the value of linear_search(x, v, begin + 1, end)
Here’s a C++ implementation:
// Returns i such that v[i] == x and begin <= i < end;
// otherwise returns -1 if x is not in v.
int linear_search(int x, const vector<int>& v, int begin, int end) {
if (begin >= end) { // base case 1: range is empty
return -1;
} else if (v[begin] == x) { // base case 2: x is at the front of v
return begin;
} else { // recursive case
return linear_search(x, v, begin + 1, end); // note it's begin + 1
}
}
int linear_search(int x, const vector<int>& v) {
return linear_search(x, v, 0, v.size());
}
Extra: Recursive Binary Search¶
If you have a vector that is in sorted order, i.e. the numbers are arranged from smallest to biggest, then binary search is an extremely efficient way to search for a target value. The idea is straightforward: check to see if the middle element equals the target value. If it does, we’re done. Otherwise, re-do the binary search either on the left half of the vector or the right half, depending on whether the target was smaller or bigger than the middle element.
Here’s a recursive implementation of binary search:
// Pre-condition:
// v[begin] <= v[begin+1] <= ... <= v[end - 1]
// Post-condition:
// Returns index i such that v[i] == x; otherwise, returns -1
// if x is not in v.
int binary_search(int x, const vector<int>& v, int begin, int end) {
    if (begin >= end) {              // base case 1: range is empty
        return -1;
    } else {
        int mid = (begin + end) / 2;
        if (x == v[mid]) {           // base case 2: x is found
            return mid;
        } else if (x < v[mid]) {     // recursive case 1: search left half
            return binary_search(x, v, begin, mid);
        } else {                     // recursive case 2 (x > v[mid]): search right half
            return binary_search(x, v, mid + 1, end);
        }
    }
}
int binary_search(int x, const vector<int>& v) {
return binary_search(x, v, 0, v.size());
}
Be careful implementing binary search! It has a lot of little details that you must get right. Test it thoroughly!
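As one possible way to test it, the sketch below compares the recursive binary search against a plain scan for every target in a range of values. The search function is renamed rec_binary_search here (an assumption of this example) to avoid any confusion with std::binary_search, and agrees_with_scan is a hypothetical test helper. The helper uses loops, which is fine because it is test code and not part of the recursive algorithm.

```cpp
#include <vector>

// Recursive binary search, same logic as above but renamed for this sketch.
// Pre-condition: v[begin] <= v[begin+1] <= ... <= v[end - 1]
int rec_binary_search(int x, const std::vector<int>& v, int begin, int end) {
    if (begin >= end) {                 // range is empty
        return -1;
    }
    int mid = begin + (end - begin) / 2;
    if (x == v[mid]) {
        return mid;
    } else if (x < v[mid]) {
        return rec_binary_search(x, v, begin, mid);
    } else {
        return rec_binary_search(x, v, mid + 1, end);
    }
}

// Hypothetical test helper: for every x from lo to hi, checks that binary
// search finds x exactly when a plain scan finds it, and that any index
// it returns really holds x. Returns true if all checks pass.
bool agrees_with_scan(const std::vector<int>& v, int lo, int hi) {
    for (int x = lo; x <= hi; ++x) {
        bool present = false;
        for (int e : v) {
            if (e == x) present = true;
        }
        int i = rec_binary_search(x, v, 0, static_cast<int>(v.size()));
        if (present) {
            if (i < 0 || v[i] != x) return false;
        } else if (i != -1) {
            return false;
        }
    }
    return true;
}
```

Good vectors to try include the empty vector, a one-element vector, and a vector with repeated values, with targets below, inside, and above the range of stored values.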
Extra: Recursive Acronyms¶
Recursive acronyms are a fun example of recursion. Perhaps you’ve heard of the GNU project (they make g++, and much other software). GNU is an acronym that expands like this:
GNU = GNU's Not Unix
If you keep replacing GNU with its expansion you get an infinitely long string:
GNU = GNU's Not Unix
= (GNU's Not Unix)'s Not Unix
= ((GNU's Not Unix)'s Not Unix)'s Not Unix
= (((GNU's Not Unix)'s Not Unix)'s Not Unix)'s Not Unix
= ...
There is no base case, so it expands forever.
Here are a few more recursive acronyms:
YOPY = Your Own Personal YOPY
LAME = LAME Ain’t an MP3 Encoder
These acronyms are co-recursive:
- HURD = Hird of Unix-Replacing Daemons
- HIRD = Hurd of Interfaces Representing Depth
MOMS is doubly recursive and so expands very quickly:
MOMS = MOMS Offering MOMS Support
= MOMS Offering MOMS Support Offering MOMS Offering MOMS Support Support
= MOMS Offering MOMS Support Offering MOMS Offering MOMS Support Support Offering MOMS Offering MOMS Support Offering MOMS Offering MOMS Support Support Support
= ...
Here is a function that expands MOMS:
const string acronym = "MOMS";
const string expansion = "MOMS Offering MOMS Support";
string expand_MOMS(int n) {
    if (n == 0) {
        return "";
    } else if (n == 1) {
        return expansion;
    } else { // n > 1
        string prev = expand_MOMS(n - 1);
        // replace every occurrence of "MOMS" with its expansion;
        // find returns a size_t, and string::npos does not fit in an int
        size_t i = prev.find(acronym);
        while (i != string::npos) {
            prev.replace(i, acronym.size(), expansion);
            // search for next acronym after the current one
            i += expansion.size(); // skip over the just-replaced string
                                   // so the acronyms within it are not expanded
            i = prev.find(acronym, i);
        }
        return prev;
    } // if
}
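To get a feel for how fast the expansion grows, the sketch below counts how many times "MOMS" appears after n expansions. It repeats expand_MOMS so that it is self-contained, and adds a hypothetical helper count_acronyms that is not part of these notes.

```cpp
#include <string>

const std::string acronym = "MOMS";
const std::string expansion = "MOMS Offering MOMS Support";

// Same idea as expand_MOMS above, repeated so this sketch is self-contained.
std::string expand_MOMS(int n) {
    if (n == 0) {
        return "";
    } else if (n == 1) {
        return expansion;
    }
    std::string prev = expand_MOMS(n - 1);
    std::size_t i = prev.find(acronym);
    while (i != std::string::npos) {
        prev.replace(i, acronym.size(), expansion);
        i += expansion.size();        // skip the text that was just inserted
        i = prev.find(acronym, i);
    }
    return prev;
}

// Hypothetical helper: counts the occurrences of "MOMS" in s.
int count_acronyms(const std::string& s) {
    int count = 0;
    std::size_t i = s.find(acronym);
    while (i != std::string::npos) {
        ++count;
        i = s.find(acronym, i + acronym.size());
    }
    return count;
}
```

Each expansion replaces every "MOMS" with a string containing two of them, so the count doubles each time: 2 after one expansion, 4 after two, and 8 after three.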
Recursion Practice Questions¶
Implement the following using recursion (and no loops, or library functions that do the hard part). For some questions, you may want to create a helper function with extra parameters. Sample solutions are here.
- The product of the integers from 1 to \(n\), i.e. the factorial function \(n! = 1 \cdot 2 \cdot \ldots \cdot n\), for \(n \geq 0\). Note that \(0! = 1\).
- The sum of the first n squares, i.e. \(1^2 + 2^2 + \ldots + n^2\), for \(n \geq 0\).
- Print the numbers from n down to 1 on cout, one number per line. Assume \(n \geq 1\).
- Print the numbers from 1 up to n on cout, one number per line. Assume \(n \geq 1\).
- The sum of just the positive numbers in a vector. For example, the sum of the positive numbers in {1, -3, -2, 6} is 7.
- The number of times x occurs in a vector. For example, in {5, 2, 1, 5, 5}, 5 occurs three times.
- Print the elements of a vector<int>, one number per line.
- Find the biggest number in a vector<int>.
- Write a function similar to expand_MOMS(n) that expands, n times, the acronym YOPY = Your Own Personal YOPY.
A Tracing Class¶
cmpt_trace.h contains a helpful class called cmpt::Trace that can be used
to log when a function is called and when it exits. For example, you can use
it to print a trace of the Fibonacci function as follows:
#include "cmpt_trace.h"
int f(int n) {
cmpt::Trace trace{"f(" + to_string(n) + ")"};
if (n == 1) { // base case
return 1;
} else if (n == 2) { // base case
return 1;
} else {
return f(n-1) + f(n-2); // recursive case
}
}
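The contents of cmpt_trace.h are not shown in these notes, so the following is only a guess at how such a class might work, not the real cmpt::Trace. The sketch uses a constructor to print the "entered" message and a destructor to print the "exited" message; the names SimpleTrace and S_traced, and the extra out parameter, are assumptions made for this example.

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for a tracing class (NOT the real cmpt::Trace).
// The constructor runs when the function is entered, and the destructor
// runs automatically when the function returns.
class SimpleTrace {
public:
    SimpleTrace(const std::string& name, std::ostream& out)
        : name_(name), out_(out) {
        out_ << std::string(2 * depth_, ' ') << name_ << " entered ...\n";
        ++depth_;  // deeper calls are indented further
    }

    ~SimpleTrace() {
        --depth_;
        out_ << std::string(2 * depth_, ' ') << "... " << name_ << " exited\n";
    }

private:
    std::string name_;
    std::ostream& out_;
    static int depth_;  // current nesting depth, shared by all traces
};

int SimpleTrace::depth_ = 0;

// S from earlier in these notes, with tracing added.
int S_traced(int n, std::ostream& out) {
    SimpleTrace trace("S(" + std::to_string(n) + ")", out);
    if (n == 0) {
        return 0;
    }
    return n + S_traced(n - 1, out);
}
```

Calling S_traced(5, std::cout) prints an indented trace in the same style as the S(5) trace shown earlier.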