
Flattening and unflattening a nested list of numpy arrays


I was looking for a solution to flatten and unflatten nested lists of numpy arrays, but only found this unanswered question, so I came up with this:

```python
import numpy as np


def _flatten(values):
    if isinstance(values, np.ndarray):
        yield values.flatten()
    else:
        for value in values:
            yield from _flatten(value)


def flatten(values):
    # flatten nested lists of np.ndarray to np.ndarray
    return np.concatenate(list(_flatten(values)))


def _unflatten(flat_values, prototype, offset):
    if isinstance(prototype, np.ndarray):
        shape = prototype.shape
        new_offset = offset + prototype.size  # number of elements in this array
        value = flat_values[offset:new_offset].reshape(shape)
        return value, new_offset
    else:
        result = []
        for value in prototype:
            value, offset = _unflatten(flat_values, value, offset)
            result.append(value)
        return result, offset


def unflatten(flat_values, prototype):
    # unflatten np.ndarray to nested lists with structure of prototype
    result, offset = _unflatten(flat_values, prototype, 0)
    assert offset == len(flat_values)
    return result
```

Example:

```python
a = [
    np.random.rand(1),
    [
        np.random.rand(2, 1),
        np.random.rand(1, 2, 1),
    ],
    [[]],
]
b = flatten(a)
# 'c' will have values of 'b' and structure of 'a'
c = unflatten(b, a)
```

Output:

```
a:
[array([ 0.26453544]), [array([[ 0.88273824],
       [ 0.63458643]]), array([[[ 0.84252894],
        [ 0.91414218]]])], [[]]]
b:
[ 0.26453544  0.88273824  0.63458643  0.84252894  0.91414218]
c:
[array([ 0.26453544]), [array([[ 0.88273824],
       [ 0.63458643]]), array([[[ 0.84252894],
        [ 0.91414218]]])], [[]]]
```
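One side effect of the slice-and-reshape step in `_unflatten`: when `flat_values` is a contiguous 1-D array such as `b`, each array in `c` should be a view into `b` rather than a copy, so writing to one is visible through the other. The sketch below isolates that single step; `flat` and `piece` are hypothetical names standing in for `b` and one entry of `c`.

```python
import numpy as np

flat = np.arange(5.0)  # hypothetical stand-in for the flattened vector `b`
# the same slice-and-reshape step _unflatten performs for a prototype of shape (2, 1)
piece = flat[1:3].reshape((2, 1))

print(np.shares_memory(piece, flat))  # True: basic slicing and this reshape return views
flat[1] = 100.0
print(piece[0, 0])  # 100.0: the write to `flat` is visible through `piece`
```

If independent arrays are wanted, calling `.copy()` on the reshaped slice inside `_unflatten` would break that link.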

License: WTFPL


Here is what I came up with, which turned out to be ~30x faster than iterating over the nested list and loading each element individually.

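```python
import itertools

import numpy as np


def flatten(nl):
    # nl: a list of lists of 1-D sequences (exactly two levels of nesting)
    l1 = [len(s) for s in itertools.chain.from_iterable(nl)]  # length of each innermost sequence
    l2 = [len(s) for s in nl]  # number of sequences in each sublist
    nl = list(itertools.chain.from_iterable(
        itertools.chain.from_iterable(nl)))
    return nl, l1, l2


def reconstruct(nl, l1, l2):
    # split the flat values back into the innermost arrays (the last piece is empty)
    pieces = np.split(np.asarray(nl), np.cumsum(l1))
    # regroup with plain list slicing; splitting the ragged list of pieces with
    # np.split would require a ragged array, which recent NumPy versions reject
    bounds = np.cumsum([0] + l2)
    return [pieces[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]


L_flat, l1, l2 = flatten(L)
L_reconstructed = reconstruct(L_flat, l1, l2)
```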

A better solution would work iteratively for an arbitrary number of nesting levels, rather than the fixed two levels handled here.
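One possible way to generalize this to any depth is sketched below, assuming the leaves are always `np.ndarray` objects and the top level is a list. Instead of the two length lists `l1`/`l2`, it records a nested "structure" in which every array is replaced by its shape, and it walks the input with an explicit stack instead of recursion. The names `flatten_nested`, `unflatten_nested` and `_DONE` are hypothetical.

```python
import numpy as np

_DONE = object()  # sentinel returned by next() when a level is exhausted


def flatten_nested(nested):
    """Return (flat, structure): arrays become shapes, lists become lists of structures."""
    chunks = []
    structure = []
    stack = [(iter(nested), structure)]  # (items left at this level, record being filled)
    while stack:
        items, record = stack[-1]
        item = next(items, _DONE)
        if item is _DONE:
            stack.pop()  # this level is finished, return to the parent
        elif isinstance(item, np.ndarray):
            chunks.append(item.ravel())
            record.append(item.shape)
        else:
            sub_record = []
            record.append(sub_record)
            stack.append((iter(item), sub_record))  # descend one level
    flat = np.concatenate(chunks) if chunks else np.empty(0)
    return flat, structure


def unflatten_nested(flat, structure):
    """Rebuild the nested lists from `flat`, following `structure` with an explicit stack."""
    offset = 0
    result = []
    stack = [(iter(structure), result)]
    while stack:
        items, out = stack[-1]
        item = next(items, _DONE)
        if item is _DONE:
            stack.pop()
        elif isinstance(item, tuple):  # a tuple is an array shape
            size = int(np.prod(item))
            out.append(flat[offset:offset + size].reshape(item))
            offset += size
        else:
            sub_out = []
            out.append(sub_out)
            stack.append((iter(item), sub_out))
    assert offset == len(flat), "structure does not match the flat vector"
    return result


a = [np.arange(3.0), [np.ones((2, 2)), [np.zeros((1, 2))]], [[]]]
flat, structure = flatten_nested(a)
restored = unflatten_nested(flat, structure)
print(flat.shape)  # (9,)
print(structure)   # [(3,), [(2, 2), [(1, 2)]], [[]]]
```

Recording shapes rather than lengths lets multidimensional leaves round-trip as well. In this sketch plain Python lists of numbers are not accepted as leaves; they would need converting with `np.asarray` first.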


You're building a paradox: you want to flatten the object, but you also don't want to flatten it, because its structural information has to be retained somewhere in the object.

So the Pythonic way to do this is not to flatten the object at all, but to write a class with an `__iter__` method that lets you go through the elements of the underlying object sequentially (i.e. in a flat manner). That will be about as fast as converting it to a flat object (if each element is visited only once), and you neither duplicate nor change the original nested container.
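A minimal sketch of that idea, assuming the leaves are `np.ndarray` objects and everything else is an iterable container; the class name `FlatView` is made up for illustration. It stores only a reference to the nested list and yields scalars in flat order on demand:

```python
import numpy as np


class FlatView:
    """Flat, read-only iteration over a nested list of np.ndarray (no flat copy is built)."""

    def __init__(self, nested):
        self.nested = nested  # stored by reference; never copied or modified

    def __iter__(self):
        return self._walk(self.nested)

    def _walk(self, node):
        if isinstance(node, np.ndarray):
            yield from node.flat  # scalars of this array in row-major (C) order
        else:
            for child in node:  # any other object is treated as an iterable container
                yield from self._walk(child)


a = [np.array([1.0]), [np.array([[2.0], [3.0]]), np.array([[[4.0], [5.0]]])], [[]]]
values = [float(v) for v in FlatView(a)]
print(values)  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

Because nothing is copied, later iterations see any changes made to the original arrays. Random access by flat index would need extra bookkeeping (for example cumulative sizes), which this sketch leaves out.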