Recursion

In C++, a function that calls itself, either directly or indirectly, is said to be recursive. For example:

void a() {
  cout << "Hello!\n";
  a();
}

In theory, a() runs forever because it keeps calling itself. In practice, though, it probably stops with an “out of memory” run-time error, usually called a stack overflow. Every call to a() uses a little bit of memory to store the address where the program should resume after that call finishes. Since no call to a() ever finishes, that memory keeps piling up until it runs out.

Note

It is possible that a() could run forever, depending on whether or not your C++ compiler uses a performance optimization trick called tail-call elimination. Basically, a smart compiler can replace the call to a() with a loop, and then a() really would run forever (or until you turn off your computer).

The compiler we’re using, g++, doesn’t do tail-call elimination by default. But if you run g++ with the -O2 flag, it will try to. With -O2, a() really does run forever for me!

The function b() is a variation of a():

void b(int n) {
  cout << n << ": hello!\n";
  b(n + 1);  // notice n + 1 is passed as the parameter
}

This function prints how many times it’s been called, which lets us see how many times it can run before all of the computer’s memory is used up. Notice that if the recursive call were b(n) instead of b(n + 1), then the value of n would never change.

Having a function run until it crashes isn’t very useful. In this version, we stop the function when n is 10 or bigger:

void c(int n) {
  if (n >= 10) {
    // all done: do nothing
  } else {
    cout << n << ": hello!\n";
    c(n + 1);
  }
}

For example, calling c(4) prints this:

4: hello!
5: hello!
6: hello!
7: hello!
8: hello!
9: hello!

Calling c(10) prints nothing:

(nothing printed)

Function c is useful, but its stopping at 10 is arbitrary and inflexible. A better way to write it is like this:

void d(int n) {
  if (n <= 0) {  // n is decreasing, so check when it gets to 0 or lower
    // all done: do nothing
  } else {
    cout << n << ": hello!\n";
    d(n - 1);   // subtract 1 instead of add 1
  }
}

For example, calling d(5) prints this:

5: hello!
4: hello!
3: hello!
2: hello!
1: hello!

This version of the function counts down from n to 1, where n is any int. Notice that if you call something like d(-3), nothing is printed.

With another small change, we can make a function that prints “hello” n times:

void say_hello(int n) {
  if (n > 0) {
    cout << n << ": hello!\n";
    say_hello(n - 1);
  }
}

Or more generally:

void say(const string& msg, int n) {
  if (n > 0) {
    cout << n << ": " << msg << "\n";
    say(msg, n - 1);
  }
}

This prints any string exactly n times. Notice that there is no else-block here: if n > 0 is not true, then the flow of control skips to after the if-block, and so does nothing.
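Recall that a function can also be recursive indirectly, by calling another function that eventually calls it back. As a small sketch, here are two hypothetical functions, is_even and is_odd (names chosen just for this example, assuming n >= 0), that call each other instead of themselves. Each one uses n == 0 as its base case:

```cpp
bool is_odd(int n);  // declared first so that is_even can call it

// returns true if n is even (assuming n >= 0)
bool is_even(int n) {
  if (n == 0) {
    return true;           // base case: 0 is even
  } else {
    return is_odd(n - 1);  // n is even exactly when n - 1 is odd
  }
}

// returns true if n is odd (assuming n >= 0)
bool is_odd(int n) {
  if (n == 0) {
    return false;          // base case: 0 is not odd
  } else {
    return is_even(n - 1); // n is odd exactly when n - 1 is even
  }
}
```

For example, is_even(4) calls is_odd(3), which calls is_even(2), and so on down to is_even(0), which returns true. This is a very slow way to test evenness (n % 2 == 0 is the practical choice), but it shows indirect recursion with a proper base case.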

As another example, let’s write a recursive function that prints the numbers from n down to 1 on the screen. For example, count_down(5) prints this:

5
4
3
2
1

It’s easy to modify the say_hello function to do this:

// prints the numbers from n down to 1
void count_down(int n) {
  if (n > 0) {
    cout << n << "\n";
    count_down(n - 1);
  }
}

Now consider the opposite problem: write a recursive function that prints the numbers from 1 up to n. For example, count_up(5) prints this:

1
2
3
4
5

This is a bit trickier than counting downwards.

Here is a solution:

// prints the numbers from 1 up to n
void count_up(int n) {
  if (n > 0) {
    count_up(n - 1);
    cout << n << "\n";  // printing comes after the recursive call
  }
}

The essential difference between count_up and count_down is when the recursive call is made: in count_down it’s made after printing, and in count_up it’s made before. That one small change has a big effect!
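To see both orders in one function, here is a sketch of a hypothetical function down_up that prints n before its recursive call (like count_down) and again after it (like count_up). The ostream parameter, which defaults to cout, is an addition for this example so the output can also be captured in a string:

```cpp
#include <iostream>
#include <sstream>  // an ostringstream can be passed as out to capture the output
using namespace std;

// prints n, n-1, ..., 1 and then 1, 2, ..., n (one number per line)
void down_up(int n, ostream& out = cout) {
  if (n > 0) {
    out << n << "\n";     // before the recursive call: printed on the way down
    down_up(n - 1, out);
    out << n << "\n";     // after the recursive call: printed on the way back up
  }
}
```

For example, down_up(3) prints 3, 2, 1, 1, 2, 3, one number per line.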

Of course, in practice, for-loops or while-loops would be the best way to implement any of these functions. But our goal here is to understand recursion, and so it is best to start with simple — if impractical — functions.
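For comparison, here is a sketch of how count_down and count_up might look with for-loops. The names count_down_loop and count_up_loop, and the ostream parameter (defaulting to cout), are hypothetical additions for this example:

```cpp
#include <iostream>
#include <sstream>  // an ostringstream can be passed as out to capture the output
using namespace std;

// prints the numbers from n down to 1, one per line
void count_down_loop(int n, ostream& out = cout) {
  for (int i = n; i > 0; --i) {
    out << i << "\n";
  }
}

// prints the numbers from 1 up to n, one per line
void count_up_loop(int n, ostream& out = cout) {
  for (int i = 1; i <= n; ++i) {
    out << i << "\n";
  }
}
```

With loops, counting up versus down is just a matter of where i starts and which way it moves, rather than where the recursive call is placed.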

Recurrence Relations: Recursive Functions in Mathematics

Recursive functions are commonly used in mathematics. For example, consider the function \(S(n)\) defined as follows:

\[\begin{split}\textrm{(base case)} \;\;\; S(0) &= 0 \\ \textrm{(recursive case)} \;\;\; S(n) &= n + S(n - 1), \;\;\; n > 0\end{split}\]

\(S(0) = 0\) is called the base case, and \(S(n) = n + S(n - 1)\) is the recursive case. Any useful recursive function needs at least one base case and one recursive case.

Using these two cases — which we’ll call rules — we can calculate \(S(n)\) for any non-negative integer \(n\).

\(S(0)\) is easy: it is simply 0, as defined by the base case. \(S(1)\) is a little more work:

\[\begin{split}S(1) &= 1 + S(0) \\ &= 1 + 0 \\ &= 1\end{split}\]

For \(S(2)\), we apply the recursive rule a couple of times:

\[\begin{split}S(2) &= 2 + S(1) \\ &= 2 + (1 + S(0)) \\ &= 2 + (1 + 0) \\ &= 3\end{split}\]

And \(S(3)\):

\[\begin{split}S(3) &= 3 + S(2) \\ &= 3 + (2 + S(1)) \\ &= 3 + (2 + (1 + S(0))) \\ &= 3 + (2 + (1 + 0)) \\ &= 6\end{split}\]

You can see the pattern: \(S(n) = n + ((n-1) + ((n-2) + ... + (2 + (1 + 0))))\). Since addition can be done in any order, this is the same as \(S(n) = 1 + 2 + ... + n\).

We can implement \(S(n)\) directly like this:

// returns 1 + 2 + ... + n (assuming n >= 0)
int S(int n) {
  if (n == 0) {            // base case
    return 0;
  } else {
    return n + S(n - 1);   // recursive case
  }
}

Notice how similar this is to the mathematical definition of \(S(n)\). Indeed, when writing a recursive function it is often helpful to work out the cases on paper, and then translate them into code.
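Since \(S(n) = 1 + 2 + \ldots + n\), its value should agree with the well-known closed form \(n(n+1)/2\). Here is a small sketch that checks this for a range of n, assuming the values stay small enough to fit in an int; the helper S_matches_formula is hypothetical:

```cpp
// recursive S from above: returns 1 + 2 + ... + n (assuming n >= 0)
int S(int n) {
  if (n == 0) {
    return 0;             // base case
  } else {
    return n + S(n - 1);  // recursive case
  }
}

// returns true if S(n) == n * (n + 1) / 2 for every n from 0 to max_n
bool S_matches_formula(int max_n) {
  for (int n = 0; n <= max_n; ++n) {
    if (S(n) != n * (n + 1) / 2) {
      return false;
    }
  }
  return true;
}
```

For example, S(5) returns 15, which equals 5 * 6 / 2.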

The base case is essential in a recursive function because it determines when the recursion stops. It plays the same role as the stopping condition of a for-loop or while-loop.

Tracing the calls and returns made by a recursive function is often useful. For example, here’s the trace of the call S(5):

S(5) entered ...
 S(4) entered ...
  S(3) entered ...
   S(2) entered ...
    S(1) entered ...
     S(0) entered ...
      ... S(0) exited
     ... S(1) exited
    ... S(2) exited
   ... S(3) exited
  ... S(4) exited
 ... S(5) exited

You can see here that S gets called exactly 6 times, and that it exits exactly 6 times. The indentation shows which calls go with which exits; notice that S(5) is the first to be called but the last to exit.

See the end of this note for how to use cmpt_trace.h to generate these tracings.

Fibonacci Numbers

A recursive function can have more than one base case. For example, the Fibonacci numbers are 1, 1, 2, 3, 5, 8, 13, …. The rule is that the first two numbers of the sequence are 1 and 1, and after that each number is the sum of the two before it. More mathematically, we can define them like this:

\[\begin{split}f(1) &= 1 \\ f(2) &= 1 \\ f(n) &= f(n-1) + f(n-2), \;\;\; n > 2\end{split}\]

Converting this definition to C++ is not too hard:

// Returns the nth Fibonacci number (assuming n > 0)
int f(int n) {
  if (n == 1) {               // base case
    return 1;
  } else if (n == 2) {        // base case
    return 1;
  } else {
    return f(n-1) + f(n-2);   // recursive case
  }
}

Let’s try calculating \(f(5)\) by hand:

\[\begin{split}f(5) &= f(4) + f(3) \\ &= (f(3) + f(2)) + (f(2) + f(1)) \\ &= ((f(2) + f(1)) + f(2)) + (f(2) + f(1)) \\ &= ((1 + 1) + 1) + (1 + 1) \\ &= 5\end{split}\]

This is more work than calculating \(S(5)\) because the recursive case makes two recursive calls to \(f\). Each of those calls can make 2 more calls, i.e. 4 calls, and each of those 4 calls can make 2 more, i.e. 8 calls, and so on. So for large values of n the number of recursive calls grows exponentially, which means that f takes a long time to calculate all but the smallest values of n.

This is one of the problems with recursive functions: they can make a lot of function calls, which eats up a lot of time and memory.

Tracing f shows a more elaborate pattern of entry/exit messages:

f(5) entered ...
 f(4) entered ...
  f(3) entered ...
   f(2) entered ...
   exited f(2)
   f(1) entered ...
   exited f(1)
  exited f(3)
  f(2) entered ...
  exited f(2)
 exited f(4)
 f(3) entered ...
  f(2) entered ...
  exited f(2)
  f(1) entered ...
  exited f(1)
 exited f(3)
exited f(5)

There are 9 calls to f here, many of them dumbly re-calculating values that have already been calculated.
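To confirm the count of 9, and to see how quickly it grows, here is a sketch of a hypothetical variation, f_counted, that takes an extra call_count parameter (passed by reference) and increments it on every call:

```cpp
// Same as f, but also adds 1 to call_count every time it is called.
// call_count is an extra parameter added just for this example.
int f_counted(int n, long long& call_count) {
  ++call_count;                 // count this call
  if (n == 1 || n == 2) {       // base cases
    return 1;
  } else {                      // recursive case: two more calls
    return f_counted(n - 1, call_count) + f_counted(n - 2, call_count);
  }
}
```

Starting with a count of 0, f_counted(5, count) leaves count at 9, matching the trace above, while f_counted(10, count) already leaves it at 109.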

A non-recursive function for computing Fibonacci numbers is much faster and uses much less memory:

// Returns nth Fibonacci number (assuming n > 0)
int f2(int n) {
  if (n == 1 || n == 2) {
    return 1;
  } else {
    int a = 1;
    int b = 1;
    int c = 0;
    for (int i = 2; i < n; ++i) {
      c = a + b;
      a = b;
      b = c;
    }
    return c;
  }
}

A disadvantage of f2 is that it’s harder to understand. If you were given just f2 with no explanation of what it’s about, it might take a minute or two to realize that it computes Fibonacci numbers.

Note

The nth Fibonacci number \(f(n)\) can also be directly calculated using the non-recursive formula

\[f(n) = \frac{\phi^n - \psi^n}{\sqrt 5}\]

where \(\phi = \frac{1 + \sqrt 5}{2}\) and \(\psi = \frac{1 - \sqrt 5}{2}\).
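As a sketch, the formula can be evaluated directly with floating-point arithmetic. Because double has limited precision, this hypothetical fib_formula function rounds to the nearest integer to absorb the tiny error, and should only be trusted for modest n; for large n the result is an approximation whose accuracy depends on the platform’s floating-point behaviour:

```cpp
#include <cmath>

// Computes the nth Fibonacci number (n > 0) using the closed-form formula.
// Uses double arithmetic, so it is only reliable for modest n.
long long fib_formula(int n) {
  const double sqrt5 = std::sqrt(5.0);
  const double phi = (1 + sqrt5) / 2;
  const double psi = (1 - sqrt5) / 2;
  // round to the nearest integer to remove small floating-point errors
  return std::llround((std::pow(phi, n) - std::pow(psi, n)) / sqrt5);
}
```

For example, fib_formula(10) returns 55, the same value as f(10).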

Recursion on Vectors

Suppose we want to sum the numbers in a vector. We can do that recursively as follows:

  • Base case: the empty vector has sum 0.
  • Recursive case: the sum of all the elements in v is v[0] + sum(rest(v)); the function rest(v) returns a copy of the original vector with its first element removed.

This definition is precise enough that we can trace examples by hand. For instance:

\[\begin{split}sum(\{8, 1, 4, 2\}) &= 8 + sum(\{1, 4, 2\}) \\ &= 8 + 1 + sum(\{4, 2\}) \\ &= 8 + 1 + 4 + sum(\{2\}) \\ &= 8 + 1 + 4 + 2 + sum(\{\}) \\ &= 8 + 1 + 4 + 2 + 0 \\ &= 15\end{split}\]

To implement this in C++, we could write the rest function like this:

// Returns a new vector w of size v.size() - 1 such that
// w[0] == v[1], w[1] == v[2], ..., w[v.size() - 2] == v[v.size() - 1].
// In other words, it returns a copy of v with the first element
// removed.
vector<int> rest(const vector<int>& v) {
   vector<int> result;
   for (int i = 1; i < v.size(); ++i) {  // i starts at 1
      result.push_back(v[i]);
   }
   return result;
}

Now we can write sum as follows:

int sum1(const vector<int>& v) {
   if (v.empty()) {  // base case
      return 0;
   } else {  // recursive case
      return v[0] + sum1(rest(v));
   }
}

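This works! Unfortunately, the rest function makes sum1 extremely inefficient: every call to sum1 makes a new copy of almost the entire passed-in vector. Summing a vector of n elements copies \((n-1) + (n-2) + \ldots + 1 + 0\) elements in total, which is proportional to \(n^2\).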

A more efficient approach is to simulate rest by re-writing sum1 to accept begin and end parameters specifying the range of values we want summed. Then we can efficiently access any sub-vector:

// returns v[begin] + v[begin + 1] + ... + v[end - 1]
int sum2(const vector<int>& v, int begin, int end) {
  if (begin >= end) {
    return 0;
  } else {
    return v[begin] + sum2(v, begin + 1, end);
  }
}

// returns the sum of all the elements in v
int sum2(const vector<int>& v) {
  return sum2(v, 0, v.size());
}

Adding extra parameters in this way is a standard trick when writing recursive functions. Notice that we don’t even need to include end in this case, because it never changes. So we could have just written this:

// returns v[begin] + v[begin + 1] + ... + v[v.size() - 1]
int sum3(const vector<int>& v, int begin) {
  if (begin >= v.size()) {
    return 0;
  } else {
    return v[begin] + sum3(v, begin + 1);
  }
}

// returns the sum of all the elements in v
int sum3(const vector<int>& v) {
  return sum3(v, 0);
}
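Keeping both begin and end, as sum2 does, also allows a different recursive strategy (divide and conquer, which these notes don’t otherwise use): split the range in half and sum each half. The sketch below, with the hypothetical name sum_halves, computes the same result as sum2, but its recursion only goes about \(\log_2 n\) calls deep instead of n, which may matter for very large vectors:

```cpp
#include <vector>
using namespace std;

// returns v[begin] + v[begin + 1] + ... + v[end - 1],
// by recursively summing the two halves of the range
int sum_halves(const vector<int>& v, int begin, int end) {
  if (begin >= end) {               // empty range
    return 0;
  } else if (end - begin == 1) {    // exactly one element
    return v[begin];
  } else {
    int mid = begin + (end - begin) / 2;
    return sum_halves(v, begin, mid) + sum_halves(v, mid, end);
  }
}

// returns the sum of all the elements in v
int sum_halves(const vector<int>& v) {
  return sum_halves(v, 0, v.size());
}
```

For example, sum_halves({8, 1, 4, 2}) splits into {8, 1} and {4, 2}, whose sums 9 and 6 add up to 15.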

Here’s another example. Suppose we want a recursive function that returns true if all the numbers in a vector are even, and false otherwise:

// Pre-condition:
//     all ints in v are >= 0
// Post-condition:
//     If v[0], v[1], ... v[n-1] are all even (n is v's size),
//     true is returned. Otherwise, false is returned.
//     If v is empty, true is returned
bool all_even(const vector<int>& v) {
  // ...
}

The recursive idea for implementing this function is the same as for the sum function: we check if the first number is even, and then recursively call all_even to check that the rest of the numbers are even. As with sum, we will use a helper function with an extra parameter to keep track of the sub-vector of v that is being processed:

// Pre-condition:
//     begin >= 0
//     all ints in v are >= 0
// Post-condition:
//     returns true if v[begin], v[begin+1], ... v[n-1] are all even,
//     where n is the size of v; false otherwise
bool all_even(const vector<int>& v, int begin) {
  if (begin >= v.size()) {
    return true;
  } else if (v[begin] % 2 == 0) {
    return all_even(v, begin + 1);
  } else {
    return false;
  }
}

The first condition of the if-statement checks if the sub-vector being processed is empty. If begin is equal to the size of the vector, or is greater than the size, we consider that to be an empty vector, and so return true.

The next condition checks if the first element of the sub-vector is even. Since begin marks the start of the sub-vector, its first element is v[begin] (not v[0]!). If that element is even, then all_even is called on the rest of the sub-vector, i.e. from location begin + 1 onwards.

Finally, the else part of the if-statement runs only when v[begin] is odd. In that case, we know that the vector cannot be all even, and so false is returned immediately.

For convenience, we provide a function that doesn’t require a starting index:

bool all_even(const vector<int>& v) {
  return all_even(v, 0);
}

Functions like this that hide some of the arguments are quite common when writing recursive functions. Testing whether an entire vector is all even is probably the most common use, and so it is helpful to make this case easier, and less error-prone, to use.

Finally, note how we specified the behaviour of the function using a pre-condition and a post-condition. A pre-condition states exactly what must be true before the function is called. If the pre-condition is not true, then the function might not work properly; it’s the responsibility of the code that calls the function to ensure the pre-condition is true. The post-condition states what will be true after the function finishes (assuming the pre-condition was true when it was called).

Together, pre-conditions and post-conditions have proven to be a very good way to precisely define the behaviour of functions, and so we will use them more and more.
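One way to make part of a pre-condition checkable at run time (a suggestion, not something the code above does) is assert from <cassert>, which stops the program with an error message if its condition is false, unless NDEBUG is defined. Here is a sketch of all_even with the cheap part of its pre-condition, begin >= 0, checked:

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Pre-condition: begin >= 0, all ints in v are >= 0
// Post-condition: returns true if v[begin], v[begin+1], ... v[n-1] are
//                 all even, where n is the size of v; false otherwise
bool all_even(const vector<int>& v, int begin) {
  assert(begin >= 0);  // check the part of the pre-condition that is cheap to test
  if (begin >= static_cast<int>(v.size())) {
    return true;                     // empty sub-vector
  } else if (v[begin] % 2 == 0) {
    return all_even(v, begin + 1);   // first element even: check the rest
  } else {
    return false;                    // found an odd number
  }
}

bool all_even(const vector<int>& v) {
  return all_even(v, 0);
}
```

Checking that every element is >= 0 would require looking at the whole vector, so that part of the pre-condition is left as documentation.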

Why Recursion?

Many students wonder why we teach recursion. A lot of professional programmers would have a hard time pointing to even a single example of where they have used recursion outside of school. In practice, iteration (i.e. loops) is far more common than recursion. However, in theory, recursion is one of the most important ideas in all of computer science.

Some of the benefits to learning recursion are:

  • For a few kinds of algorithms, such as parsers and the (important!) sorting algorithms quicksort and mergesort, recursion is the most common implementation method. Non-recursive versions of these algorithms are usually harder to understand and implement.
  • There are some programming languages, such as Haskell and Erlang, that have no loops! You have no choice but to use recursion, or functions that are based on recursion.
  • It is often easier to reason about a recursive function than an iterative one. Recursive functions are often mathematical definitions in disguise, and so you may be able to use that mathematics to help better understand your function’s correctness, performance, or memory usage.
  • Recursive functions often result in source code that is shorter, simpler, and more elegant than non-recursive functions. This can make your programs more readable, and less likely to have bugs (bugs love to hide in hard-to-read code).
  • Recursion plays a fundamental role in theoretical computer science. Recursive functions can be used as the basis for all computation, e.g. any loop can be translated into a recursive function that does the same thing.

In practice, recursion is probably best thought of as one of many tools a programmer can use to solve programming problems. Use it when it makes sense, and avoid it when it doesn’t.

Extra: Recursive Acronyms

Recursive acronyms are a fun example of recursion. Perhaps you’ve heard of the GNU project (they make g++, and much other software). GNU is an acronym that expands like this:

GNU = GNU's Not Unix

If you keep replacing GNU with its expansion you get an infinitely long string:

GNU = GNU's Not Unix
    = (GNU's Not Unix)'s Not Unix
    = ((GNU's Not Unix)'s Not Unix)'s Not Unix
    = (((GNU's Not Unix)'s Not Unix)'s Not Unix)'s Not Unix
    = ...

There is no base case, so it expands forever.
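Adding a counter gives the expansion a base case. Here is a sketch of a hypothetical function expand_GNU(n) that returns GNU after n expansions, where n == 0 means GNU is left unexpanded:

```cpp
#include <string>
using namespace std;

// Returns GNU expanded n times (assuming n >= 0).
// The parameter n supplies the base case that the acronym itself lacks.
string expand_GNU(int n) {
  if (n == 0) {
    return "GNU";                                   // base case: no expansion
  } else if (n == 1) {
    return "GNU's Not Unix";                        // one expansion
  } else {
    return "(" + expand_GNU(n - 1) + ")'s Not Unix"; // wrap the previous expansion
  }
}
```

For example, expand_GNU(2) returns (GNU's Not Unix)'s Not Unix, matching the second line of the expansion above.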

Here are a few more recursive acronyms:

  • YOPY = Your Own Personal YOPY

  • LAME = LAME Ain’t an MP3 Encoder

  • These acronyms are co-recursive:

    • HURD = Hird of Unix-Replacing Daemons
    • HIRD = Hurd of Interfaces Representing Depth
  • MOMS is doubly recursive and so expands very quickly:

    MOMS = MOMS Offering MOMS Support
         = MOMS Offering MOMS Support Offering MOMS Offering MOMS Support
           Support
         = MOMS Offering MOMS Support Offering MOMS Offering MOMS Support
           Support Offering MOMS Offering MOMS Support Offering MOMS Offering
           MOMS Support Support Support
         = ...
    

Here is a function that expands MOMS:

const string acronym = "MOMS";
const string expansion = "MOMS Offering MOMS Support";
string expand_MOMS(int n) {
   if (n == 0) {
      return "";
   } else if (n == 1) {
      return expansion;
   } else { // n > 1
      string prev = expand_MOMS(n - 1);
      // replace every occurrence of "MOMS" with its expansion
      // (i is a size_t, not an int, because find returns string::npos,
      // a size_t, when there is no match)
      size_t i = prev.find(acronym);
      while (i != string::npos) {
         prev.replace(i, acronym.size(), expansion);
         // search for next acronym after the current one
         i += expansion.size(); // skip over the just-replaced string
                                // so the acronyms within it are not expanded
         i = prev.find(acronym, i);
      }
      return prev;
   } // if
}

Recursion Practice Questions

Implement the following using recursion, with no loops and no library functions that do the hard part. For some questions, you may want to create a helper function with extra parameters. Sample solutions are here.

  1. The product of the integers from 1 to \(n\), i.e. the factorial function \(n! = 1 \cdot 2 \cdot \ldots \cdot n\), for \(n \geq 0\). Note that \(0! = 1\).
  2. The sum of the first n squares, i.e. \(1^2 + 2^2 + \ldots + n^2\), for \(n \geq 0\).
  3. Print the numbers from n down to 1 on cout, one number per line. Assume \(n \geq 1\).
  4. Print the numbers from 1 up to n on cout, one number per line. Assume \(n \geq 1\).
  5. The sum of just the positive numbers in a vector. For example, the sum of the positive numbers in {1, -3, -2, 6} is 7.
  6. The number of times x occurs in a vector. For example, in {5, 2, 1, 5, 5}, 5 occurs three times.
  7. Print the elements of a vector<int>, one number per line.
  8. Find the biggest number in a vector<int>.
  9. Write a function similar to expand_MOMS(n) that expands, n times, the acronym YOPY = Your Own Personal YOPY.

A Tracing Class

cmpt_trace.h contains a helpful class called cmpt::Trace that can be used to log when a function is called and when it exits. For example, you can use it to trace the Fibonacci function as follows:

#include "cmpt_trace.h"

int f(int n) {
  cmpt::Trace trace{"f(" + to_string(n) + ")"};
  if (n == 1) {               // base case
    return 1;
  } else if (n == 2) {        // base case
    return 1;
  } else {
    return f(n-1) + f(n-2);   // recursive case
  }
}
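The contents of cmpt_trace.h aren’t shown here. As a rough idea of how a class like cmpt::Trace can work, here is a minimal sketch using a hypothetical SimpleTrace class (not the real cmpt_trace.h, and its output may differ): the constructor prints the entry message, and the destructor, which C++ runs automatically when the function returns, prints the exit message. It imitates the format of the S trace shown earlier and requires C++17 for the inline static members:

```cpp
#include <iostream>
#include <sstream>  // an ostringstream can be assigned to out to capture the trace
#include <string>
using namespace std;

// Hypothetical stand-in for cmpt::Trace (NOT the real cmpt_trace.h).
class SimpleTrace {
public:
  static inline ostream* out = &cout;  // where trace messages are written
  static inline int depth = 0;         // current nesting level of traced calls

  explicit SimpleTrace(const string& name) : name_(name) {
    *out << string(depth, ' ') << name_ << " entered ...\n";
    ++depth;
  }

  // runs automatically when the traced function returns
  ~SimpleTrace() {
    *out << string(depth, ' ') << "... " << name_ << " exited\n";
    --depth;
  }

  // copying a trace object would print extra messages, so forbid it
  SimpleTrace(const SimpleTrace&) = delete;
  SimpleTrace& operator=(const SimpleTrace&) = delete;

private:
  string name_;
};

// S from earlier, with tracing added
int S(int n) {
  SimpleTrace trace{"S(" + to_string(n) + ")"};
  if (n == 0) {
    return 0;             // base case
  } else {
    return n + S(n - 1);  // recursive case
  }
}
```

Because the exit message comes from the destructor, it is printed after the return value has been computed, so inner calls always report exiting before the calls that made them.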