I have a 3D tensor in which I need to preserve the vectors at certain positions along the second dimension and zero out the remaining vectors. The positions are given as a 1D array, with one position per batch element. I think the best way to do this is to multiply the tensor by a binary mask.
Here's a simple NumPy version:
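```python
import numpy as np

# A: array of shape (b, n, m); indices: integer array of shape (b,)
b, n, m = A.shape
mask = np.zeros(A.shape)
for i in range(b):
    mask[i, indices[i]] = 1  # keep row indices[i] of the i-th n x m matrix
result = A * mask
```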
So for each n×m matrix in A, I need to keep the row specified by indices and zero out all the others.
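For reference, the loop can also be written with NumPy integer-array indexing, which fills all b masks in one assignment. This is a minimal sketch assuming A has shape (b, n, m) and indices is an integer array of length b; the sizes, the random A, and the names mask_loop, keep, and result_bool are illustrative only.

```python
import numpy as np

b, n, m = 3, 5, 4                      # example sizes (illustrative)
rng = np.random.default_rng(0)
A = rng.standard_normal((b, n, m))
indices = np.array([2, 1, 4])          # one row to keep per batch element

# Loop version, for comparison.
mask_loop = np.zeros(A.shape)
for i in range(b):
    mask_loop[i, indices[i]] = 1

# Vectorized: pair each batch index with its row index; the selection has
# shape (b, m), so the scalar 1 fills each selected row entirely.
mask = np.zeros(A.shape)
mask[np.arange(b), indices] = 1
result = A * mask

# Equivalent broadcasting form: compare row positions against indices to get
# a (b, n) boolean mask, then add a trailing axis so it broadcasts over m.
keep = np.arange(n)[None, :] == indices[:, None]
result_bool = A * keep[:, :, None]
```

The (batch, row) pairing used in `mask[np.arange(b), indices]` is the same kind of coordinate that `tf.scatter_nd` expects in its index tensor.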
I'm trying to do this in TensorFlow with the tf.scatter_nd op, but I can't figure out the correct shape for indices:
```python
import tensorflow as tf

shape = tf.constant([3, 5, 4])
A = tf.random_normal(shape)
indices = tf.constant([2, 1, 4])  # ???
updates = tf.ones((3, 4))
mask = tf.scatter_nd(indices, updates, shape)
result = A * mask
```
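A possible answer, sketched under the assumption of TensorFlow 2.x with eager execution (so `tf.random.normal` replaces the TF 1.x `tf.random_normal`): `tf.scatter_nd` needs one full coordinate per update. Because each update here is an entire row of length m, each coordinate is a (batch, row) pair, so the index tensor has shape (b, 2) and updates has shape (b, m). The names `rows` and `scatter_idx` are illustrative. A `tf.one_hot` mask is shown as an alternative that avoids building coordinates at all.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

b, n, m = 3, 5, 4
A = tf.random.normal((b, n, m))
rows = tf.constant([2, 1, 4])  # one row index per batch element, shape (b,)

# Pair each batch index with its row index: [[0, 2], [1, 1], [2, 4]], shape (b, 2).
scatter_idx = tf.stack([tf.range(b), rows], axis=1)

# Each (batch, row) coordinate addresses a whole row of length m,
# so there is one update vector of shape (m,) per coordinate.
updates = tf.ones((b, m))
mask = tf.scatter_nd(scatter_idx, updates, shape=(b, n, m))
result = A * mask

# Alternative without scatter_nd: a (b, n) one-hot mask with a trailing
# axis added so it broadcasts over the last dimension of A.
mask_onehot = tf.one_hot(rows, depth=n)[:, :, tf.newaxis]
result_onehot = A * mask_onehot
```

The one-hot version is usually shorter to read, while the `scatter_nd` version shows the index layout the original attempt was missing.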