Theano Dimshuffle equivalent in Google's TensorFlow?
There are three relevant ops for implementing Theano's `dimshuffle` in TensorFlow:

- `tf.transpose()` is used to permute the dimensions of a tensor. If the pattern specified in the arguments to `dimshuffle` is a permutation of the input tensor's dimensions (i.e. there is no `'x'` or missing dimension), you can use `tf.transpose()` to implement `dimshuffle()`.
- `tf.expand_dims()` is used to add a size-1 dimension to a tensor (call it once per new dimension to add several). This handles the case where `'x'` is specified as part of the `dimshuffle()` pattern, but it does not reorder the existing dimensions.
- `tf.squeeze()` is used to remove one or more size-1 dimensions from a tensor. This handles the case where a dimension is omitted from a `dimshuffle()` pattern, but it does not reorder the existing dimensions.
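The three ops above compose in a fixed order: squeeze the dropped axes, transpose the remaining ones, then insert a size-1 axis at every `'x'` position. Below is a minimal sketch of that composition, written with NumPy so it can be checked without building a TensorFlow graph; `np.squeeze`, `np.transpose` and `np.expand_dims` behave like their TensorFlow namesakes for this purpose, so swapping `np.*` for `tf.*` should carry over. The helper name `dimshuffle_np` is hypothetical, and it assumes (as Theano does) that every dropped axis has size 1.

```python
import numpy as np


def dimshuffle_np(x, pattern):
    """Hypothetical NumPy sketch of Theano's dimshuffle: squeeze, transpose, expand_dims."""
    kept = [p for p in pattern if p != "x"]
    # 1. Squeeze every input axis the pattern omits (raises if its size is not 1).
    dropped = tuple(i for i in range(x.ndim) if i not in kept)
    if dropped:
        x = np.squeeze(x, axis=dropped)
    # Renumber the kept axes to account for the axes removed in front of them.
    kept = [k - sum(d < k for d in dropped) for k in kept]
    # 2. Reorder the remaining axes into the order the pattern lists them.
    x = np.transpose(x, kept)
    # 3. Insert a size-1 axis at each 'x' position, left to right.
    for pos, p in enumerate(pattern):
        if p == "x":
            x = np.expand_dims(x, pos)
    return x


print(dimshuffle_np(np.zeros((128, 32)), (1, "x", 0)).shape)  # (32, 1, 128)
```

Inserting the `'x'` axes from left to right is what keeps each position valid: by the time position `pos` is reached, every axis in front of it is already in its final place.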
Assuming that the input is a vector, your example (`dimshuffle(0, 'x')`) can be expressed using `tf.expand_dims()` only:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

input = tf.placeholder(tf.float32, [None])  # Defines an arbitrary-sized vector.
result = tf.expand_dims(input, 1)
print(result.get_shape())  # ==> TensorShape([Dimension(None), Dimension(1)])
```
Taking a more complicated example, `dimshuffle(1, 'x', 0)` applied to a matrix would be:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

input = tf.placeholder(tf.float32, [128, 32])  # Defines a matrix.
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1)
print(output.get_shape())  # ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)])
```
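Both snippets above use the TensorFlow 1.x graph API; `tf.placeholder` is not available in TensorFlow 2.x outside `tf.compat.v1`. Here is a minimal eager-mode sketch of the same patterns, assuming TensorFlow 2.x and using concrete zero tensors in place of the placeholders. It also shows the `tf.squeeze()` case for a dropped axis:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

# dimshuffle(0, 'x') on a vector: append a size-1 axis.
vec = tf.zeros([7])
col = tf.expand_dims(vec, 1)
print(col.shape.as_list())  # [7, 1]

# dimshuffle(1, 'x', 0) on a matrix: swap the axes, then insert a size-1 axis.
mat = tf.zeros([128, 32])
out = tf.expand_dims(tf.transpose(mat, perm=[1, 0]), axis=1)
print(out.shape.as_list())  # [32, 1, 128]

# dimshuffle(1) on a (1, 5) matrix: drop the size-1 axis 0.
row = tf.zeros([1, 5])
flat = tf.squeeze(row, axis=[0])
print(flat.shape.as_list())  # [5]
```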
I implemented `dimshuffle` for TensorFlow in our framework Returnn (here). The code is this:
```python
import tensorflow as tf


def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
    """
    :param tf.Tensor x:
    :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
    :param str name: scope name
    :return: y where we have a new broadcast axis for each axis in axes
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        # Insert in ascending order so that earlier insertions don't shift later target positions.
        for i in sorted(axes):
            x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
        return x


def dimshuffle(x, axes, name="dimshuffle"):
    """
    Like Theano's dimshuffle.
    Combines tf.transpose, tf.expand_dims and tf.squeeze.

    :param tf.Tensor x:
    :param list[int|str]|tuple[int|str] axes:
    :param str name: scope name
    :rtype: tf.Tensor
    """
    with tf.name_scope(name):
        assert all([i == "x" or isinstance(i, int) for i in axes])
        real_axes = [i for i in axes if isinstance(i, int)]
        bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
        if x.get_shape().ndims is None:
            x_shape = tf.shape(x)
            x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)])  # will have static ndims
        assert x.get_shape().ndims is not None

        # First squeeze missing axes, renumbering the kept axes behind each removed one.
        i = 0
        while i < x.get_shape().ndims:
            if i not in real_axes:
                x = tf.squeeze(x, axis=i)
                real_axes = [(j if (j < i) else (j - 1)) for j in real_axes]
            else:
                i += 1

        # Now permute.
        assert list(sorted(real_axes)) == list(range(x.get_shape().ndims))
        if real_axes != list(range(x.get_shape().ndims)):
            x = tf.transpose(x, real_axes)

        # Now add broadcast dimensions.
        if bc_axes:
            x = expand_multiple_dims(x, bc_axes)
        assert len(axes) == x.get_shape().ndims
        return x
```
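If TensorFlow is your backend, Keras's `K.permute_dimensions` should do for a pure permutation of the axes:

```python
from keras import backend as K

y = K.permute_dimensions(x, (1, 0))  # x is an existing tensor; (1, 0) swaps its two axes
```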