Recursive Functions in Python
Recursive functions call themselves to solve problems by breaking them into smaller subproblems. They're useful for tasks like traversing trees, calculating factorials, and processing nested structures.
Basic recursion
A recursive function calls itself with modified arguments until it reaches a base case.
def countdown(n):
    if n <= 0:
        print("Done!")
    else:
        print(n)
        countdown(n - 1)
countdown(3)
>>> 3
>>> 2
>>> 1
>>> Done!
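To see the calls stack up and then unwind, a minimal sketch can record a line when each call starts and another when it finishes, indented by recursion depth. The name `countdown_traced` and its `depth` parameter are illustrative additions, not part of the original example. The sketch returns the trace as a list so it is easy to inspect.

```python
def countdown_traced(n, depth=0):
    """Return trace lines showing each call being entered and left."""
    indent = "  " * depth  # indent by recursion depth
    lines = [f"{indent}enter countdown({n})"]
    if n <= 0:
        lines.append(f"{indent}base case reached")
    else:
        # Recursive case: the inner call's lines nest inside this call's
        lines.extend(countdown_traced(n - 1, depth + 1))
    # This runs only after the inner call has returned
    lines.append(f"{indent}leave countdown({n})")
    return lines


for line in countdown_traced(2):
    print(line)
```

The "leave" lines appear in reverse order of the "enter" lines, which is the call stack unwinding.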
Base case and recursive case
Every recursive function needs a base case that stops the recursion and a recursive case that calls itself.
def sum_numbers(n):
    # Base case (n <= 0 also stops negative inputs from recursing forever)
    if n <= 0:
        return 0
    # Recursive case
    return n + sum_numbers(n - 1)
print(sum_numbers(5))
>>> 15
Without a proper base case, the recursion never terminates. Python does not let the call stack grow without bound: once the recursion limit is exceeded, it raises a RecursionError.
def infinite_recursion():
    infinite_recursion()  # No base case!
# infinite_recursion()  # Calling this raises RecursionError
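In CPython, runaway recursion is stopped with a `RecursionError`, a subclass of `RuntimeError`, rather than by crashing the process. A minimal sketch that triggers and catches it, assuming the default recursion limit is in effect; `runaway` and `hits_recursion_limit` are hypothetical names standing in for `infinite_recursion`.

```python
def runaway(depth=0):
    # No base case: every call makes another call until Python intervenes
    return runaway(depth + 1)


def hits_recursion_limit():
    """Return True if calling runaway() raises RecursionError."""
    try:
        runaway()
    except RecursionError:
        return True
    return False


print(hits_recursion_limit())  # True
```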
Factorial example
Factorial is a classic recursive problem: n! = n × (n-1)!, with 0! = 1! = 1 as the base cases.
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)
print(factorial(5))
>>> 120
The recursive approach mirrors the mathematical definition.
# 5! = 5 × 4 × 3 × 2 × 1
# = 5 × 4!
# = 5 × 4 × 3!
# ... and so on
print(factorial(0))
>>> 1
print(factorial(1))
>>> 1
print(factorial(6))
>>> 720
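The `factorial` function above returns 1 for any n <= 1, including negative numbers, where the factorial is undefined. A hedged sketch that rejects negative input and cross-checks its results against the standard library's `math.factorial`; `factorial_checked` is a hypothetical name.

```python
import math


def factorial_checked(n):
    """Recursive factorial that rejects negative input."""
    if n < 0:
        raise ValueError("factorial is not defined for negative numbers")
    if n <= 1:  # base case: 0! = 1! = 1
        return 1
    return n * factorial_checked(n - 1)  # recursive case: n! = n * (n-1)!


# Cross-check against the standard library for small inputs
for k in range(10):
    assert factorial_checked(k) == math.factorial(k)

print(factorial_checked(6))  # 720
```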
Fibonacci sequence
The Fibonacci sequence is another classic recursive problem where each number is the sum of the two preceding ones.
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)
print(fibonacci(7))
>>> 13
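This naive implementation is inefficient for large values because it recomputes the same subproblems repeatedly: fibonacci(n - 2) is evaluated once directly and again inside fibonacci(n - 1), so the number of calls grows exponentially with n.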
# First 10 Fibonacci numbers
for i in range(10):
    print(f"F({i}) = {fibonacci(i)}")
>>> F(0) = 0
>>> F(1) = 1
>>> F(2) = 1
>>> F(3) = 2
>>> F(4) = 3
>>> F(5) = 5
>>> F(6) = 8
>>> F(7) = 13
>>> F(8) = 21
>>> F(9) = 34
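To make the cost of the naive version concrete, a small sketch can count the calls. The call count obeys its own recurrence (one call for n, plus the calls for n - 1 and n - 2), so it grows roughly as fast as the Fibonacci numbers themselves. `fibonacci_counted` and its `counter` argument, a one-element list used as a mutable tally, are illustrative.

```python
def fibonacci_counted(n, counter):
    """Naive Fibonacci that tallies every call in counter[0]."""
    counter[0] += 1
    if n <= 1:
        return n
    return fibonacci_counted(n - 1, counter) + fibonacci_counted(n - 2, counter)


for n in (10, 20, 25):
    calls = [0]
    value = fibonacci_counted(n, calls)
    print(f"F({n}) = {value} took {calls[0]} calls")
```

For n = 10 this reports 177 calls, and for n = 20 it reports 21891, even though the values themselves are only 55 and 6765.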
Working with nested structures
Recursion excels at processing nested data structures like lists, dictionaries, and trees.
def flatten_list(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):
            result.extend(flatten_list(item))
        else:
            result.append(item)
    return result
nested = [1, [2, 3], [4, [5, 6]], 7]
print(flatten_list(nested))
>>> [1, 2, 3, 4, 5, 6, 7]
Recursion handles arbitrary nesting depth naturally.
deeply_nested = [1, [2, [3, [4, [5]]]]]
print(flatten_list(deeply_nested))
>>> [1, 2, 3, 4, 5]
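`flatten_list` builds a new list at every level and then copies it into its parent with `extend`. A generator-based variant, sketched here under the hypothetical name `iter_flatten`, uses `yield from` to pass items up lazily instead. Like the original, it treats only `list` instances as nested.

```python
def iter_flatten(nested_list):
    """Lazily yield the non-list items of an arbitrarily nested list."""
    for item in nested_list:
        if isinstance(item, list):
            # Recursive case: delegate to a generator for the sublist
            yield from iter_flatten(item)
        else:
            # Base case: a non-list item is yielded as-is
            yield item


print(list(iter_flatten([1, [2, 3], [4, [5, 6]], 7])))  # [1, 2, 3, 4, 5, 6, 7]
```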
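The same approach works for nested dictionaries. This search also descends into any lists it finds along the way, and returns None if the key is not present.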
def find_value(data, key):
    if isinstance(data, dict):
        if key in data:
            return data[key]
        for value in data.values():
            result = find_value(value, key)
            if result is not None:
                return result
    elif isinstance(data, list):
        for item in data:
            result = find_value(item, key)
            if result is not None:
                return result
    return None
nested_dict = {
    "a": 1,
    "b": {
        "c": 2,
        "d": {
            "e": 3
        }
    }
}
print(find_value(nested_dict, "e"))
>>> 3
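The opening of this section also mentions trees. As a sketch, assuming a minimal hypothetical `Node` class holding a value and a list of children, the same base-case/recursive-case split gives short functions for the sum of all values and the height of the tree.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    value: int
    children: list[Node] = field(default_factory=list)


def tree_sum(node):
    # For a leaf the generator is empty, so sum() contributes 0
    return node.value + sum(tree_sum(child) for child in node.children)


def tree_height(node):
    # Base case: a leaf has height 1
    if not node.children:
        return 1
    # Recursive case: one more than the tallest subtree
    return 1 + max(tree_height(child) for child in node.children)


tree = Node(1, [Node(2, [Node(4)]), Node(3)])
print(tree_sum(tree))     # 10
print(tree_height(tree))  # 3
```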
Tail recursion
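Tail recursion occurs when the recursive call is the last operation in the function, so nothing is left to do after it returns. Python does not perform tail-call optimisation, so each call still adds a stack frame, but the pattern is useful because it translates directly into a loop.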
def factorial_tail(n, accumulator=1):
    if n <= 1:
        return accumulator
    return factorial_tail(n - 1, n * accumulator)
print(factorial_tail(5))
>>> 120
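The accumulator pattern carries the partial result forward as an argument instead of combining it after the recursive call returns. Because Python does not eliminate tail calls, this does not reduce stack depth: with the default recursion limit, sum_tail on a list of a few thousand items still raises RecursionError.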
def sum_tail(numbers, accumulator=0):
    if not numbers:
        return accumulator
    # numbers[1:] builds a new list on every call, copying the remaining items
    return sum_tail(numbers[1:], accumulator + numbers[0])
print(sum_tail([1, 2, 3, 4, 5]))
>>> 15
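Because nothing happens after the recursive call, a tail-recursive function can be rewritten as a loop: the parameters become loop variables and each recursive call becomes a reassignment. A sketch of that conversion for `factorial_tail` and `sum_tail` (the `_loop` names are illustrative). Unlike the recursive versions, these run in constant stack depth, so inputs such as `factorial_loop(5000)` do not hit the recursion limit.

```python
def factorial_loop(n):
    accumulator = 1
    # Each iteration replaces the call factorial_tail(n - 1, n * accumulator)
    while n > 1:
        n, accumulator = n - 1, n * accumulator
    return accumulator


def sum_loop(numbers):
    accumulator = 0
    index = 0
    # Walk an index instead of slicing, so the list is never copied
    while index < len(numbers):
        accumulator += numbers[index]
        index += 1
    return accumulator


print(factorial_loop(5))          # 120
print(sum_loop([1, 2, 3, 4, 5]))  # 15
```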
Common pitfalls
Recursion can be inefficient, or can exceed Python's recursion limit, if not used carefully.
# Inefficient: calculates same values multiple times
def fibonacci_naive(n):
    if n <= 1:
        return n
    return fibonacci_naive(n - 1) + fibonacci_naive(n - 2)
# Better: use memoization to store each result once it is computed
def fibonacci_memo(n, memo=None):
    # Avoid a mutable default argument: create a fresh cache per top-level call
    if memo is None:
        memo = {}
    if n in memo:
        return memo[n]
    if n <= 1:
        return n
    memo[n] = fibonacci_memo(n - 1, memo) + fibonacci_memo(n - 2, memo)
    return memo[n]
print(fibonacci_memo(40))
>>> 102334155
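The standard library can manage the cache instead of a hand-written dictionary. `functools.lru_cache` (and `functools.cache`, added in Python 3.9) stores a function's results keyed by its arguments. A minimal sketch, with `fibonacci_cached` as an illustrative name:

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded cache: each n is computed only once
def fibonacci_cached(n):
    if n <= 1:
        return n
    return fibonacci_cached(n - 1) + fibonacci_cached(n - 2)


print(fibonacci_cached(40))  # 102334155
print(fibonacci_cached.cache_info())  # hit/miss statistics for the cache
```

The decorator keeps the function body identical to the naive version, which makes it easy to add or remove memoization.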
Python caps recursion depth to keep runaway recursion from overflowing the interpreter's C stack; going past the limit raises RecursionError.
import sys
print(sys.getrecursionlimit())
>>> 1000
# You can raise it, but a value that is too high can crash the interpreter
# sys.setrecursionlimit(2000)
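If you do raise the limit, one cautious pattern is to raise it only around the call that needs it and restore the old value afterwards. The sketch below defines a hypothetical `recursion_limit` context manager from `sys.getrecursionlimit`, `sys.setrecursionlimit`, and `contextlib.contextmanager`; `recurse_to` is an illustrative function that recurses n levels deep.

```python
import sys
from contextlib import contextmanager


@contextmanager
def recursion_limit(limit):
    """Temporarily set the recursion limit, restoring the old one on exit."""
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        # Restore the previous limit even if the body raised an exception
        sys.setrecursionlimit(old_limit)


def recurse_to(n):
    # Recurses n levels deep before returning n
    return 0 if n == 0 else 1 + recurse_to(n - 1)


with recursion_limit(3000):
    print(recurse_to(2000))  # 2000
print(sys.getrecursionlimit())  # back to the previous value
```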
For deep recursion, consider an iterative solution instead: a loop when each call makes a single recursive call, or an explicit stack when the recursion branches.
def factorial_iterative(n):
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result
print(factorial_iterative(5))
>>> 120
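When the recursion does not reduce to a simple loop, an explicit stack can stand in for the call stack. A sketch of an iterative counterpart to `flatten_list`, under the illustrative name `flatten_iterative`: it keeps one iterator per partially processed list on an ordinary Python list, so nesting depth is limited by memory rather than by the recursion limit.

```python
def flatten_iterative(nested_list):
    """Flatten arbitrarily nested lists without recursion."""
    result = []
    stack = [iter(nested_list)]  # each entry tracks progress through one list
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                # Descend; the current iterator resumes after the sublist is done
                stack.append(iter(item))
                break
            result.append(item)
        else:
            # The innermost list is exhausted: return to its parent
            stack.pop()
    return result


print(flatten_iterative([1, [2, 3], [4, [5, 6]], 7]))  # [1, 2, 3, 4, 5, 6, 7]

# Nesting far deeper than the default recursion limit
deep = [1]
for i in range(2, 5001):
    deep = [deep, i]
print(flatten_iterative(deep)[-3:])  # [4998, 4999, 5000]
```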