How to Flatten a List

Many programming languages have a standard flatten() function that takes a list with any level of nesting and returns a single, non-nested list. For example:

>>> flatten([1, [2, 3], [[4], 5], 6, [[[7]]]])
[1, 2, 3, 4, 5, 6, 7]

Python has no built-in function like this, so here is a possible implementation:

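def flatten(nested_list):
    # Base case: an empty list flattens to an empty list.
    try:
        head = nested_list[0]
    except IndexError:
        return []
    # Flatten the first element if it is itself a list (otherwise wrap it
    # in a list), then append the flattened remainder.
    return ((flatten(head) if isinstance(head, list) else [head]) +
            flatten(nested_list[1:]))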

Note that this code only flattens nested lists. Any other collection, such as a tuple, is kept as a single element. To support other collections, add their types to the isinstance() call. For example, isinstance(head, (list, tuple)) makes the function flatten nested tuples as well.
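One way to avoid editing the function every time is to pass the container types in as an argument. The sketch below is a variant of flatten() with that change. The `types` parameter and its default of `(list, tuple)` are assumptions made for this example and are not part of the original code.

```python
def flatten(nested_list, types=(list, tuple)):
    """Recursively flatten nested_list, descending into every element whose
    type is in `types` (a hypothetical parameter added for this sketch)."""
    try:
        head = nested_list[0]
    except IndexError:
        # An empty sequence flattens to an empty list.
        return []
    # Flatten the head if it is one of the container types, otherwise wrap
    # it in a list, then append the flattened tail.
    first = flatten(head, types) if isinstance(head, types) else [head]
    return first + flatten(nested_list[1:], types)


print(flatten([1, (2, 3), [(4,), 5]]))       # [1, 2, 3, 4, 5]
print(flatten([1, (2, 3)], types=(list,)))  # [1, (2, 3)]
```

Passing `types=(list,)` restores the list-only behaviour of the original function.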

An alternative implementation using a generator:

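def iflatten(nested_list):
    # Base case: an empty sequence yields nothing.
    try:
        head = nested_list[0]
    except IndexError:
        return
    # Strings can be indexed too, but "a"[0] is "a" again, so recursing into
    # them would continue until RecursionError. Treat them as single values.
    if isinstance(head, str):
        yield head
    else:
        try:
            # Try to flatten the head as a nested sequence...
            yield from iflatten(head)
        except TypeError:
            # ...and if it cannot be indexed (e.g. an int), yield it as is.
            yield head
    # Then flatten the rest of the sequence.
    yield from iflatten(nested_list[1:])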

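The generator uses less memory when the elements are consumed one at a time, because it produces them lazily and never builds the whole flattened list. It accepts any sequence that supports indexing and slicing. It also tries to descend into every element instead of checking isinstance(), so nested tuples are flattened without changing the code. It does not work on arbitrary iterables such as sets or generators, which cannot be indexed.

Both versions recurse once per element and copy the rest of the list with a slice at every step, so they are best suited to short lists. A list with more than about a thousand elements can exceed Python's default recursion limit. For large volumes of numeric data with a regular shape, consider NumPy's ndarray.flatten() method, which returns a one-dimensional copy of an array.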
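The recursive functions above index and slice their input, so they need sequences such as lists or tuples, and their recursion depth grows with the number of elements. A different technique avoids both constraints: it walks the structure with an explicit stack of iterators instead of recursion, so it accepts any iterable and handles long or deeply nested inputs. The name `iter_flatten` is made up for this sketch. Treating `str` and `bytes` as single values is also an assumption, because otherwise each one-character string would be descended into endlessly. Dictionaries would contribute only their keys.

```python
from collections.abc import Iterable


def iter_flatten(iterable):
    """Lazily flatten arbitrarily nested iterables without recursion.

    Strings and bytes are yielded as single values (an assumption of this
    sketch). Objects that are iterable only through __getitem__ are not
    detected by the Iterable check and are yielded as single values.
    """
    # The top of the stack is the iterator for the innermost level being walked.
    stack = [iter(iterable)]
    while stack:
        try:
            item = next(stack[-1])
        except StopIteration:
            # This level is exhausted, so resume the enclosing one.
            stack.pop()
            continue
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            # Descend into the nested iterable.
            stack.append(iter(item))
        else:
            yield item


print(list(iter_flatten([1, [2, 3], [[4], 5], 6, [[[7]]]])))  # [1, 2, 3, 4, 5, 6, 7]
print(list(iter_flatten(["ab", (1, {2}), range(3, 5)])))      # ['ab', 1, 2, 3, 4]
```

The stack grows with the nesting depth, not with the length of the input. A flat list of 100,000 elements therefore needs a stack of just one iterator.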
