Buscar..


Introducción

Varios ejemplos que muestran cómo Tensorflow admite la indexación en tensores, destacando las diferencias y similitudes con la indexación tipo numpy siempre que sea posible.

Extraer una rebanada de un tensor

Consulte la documentación de tf.slice(input, begin, size) para obtener información detallada.

Argumentos:

  • input : tensor
  • begin : ubicación inicial para cada dimensión de input
  • size : número de elementos para cada dimensión de input , utilizando -1 incluye todos los elementos restantes

Rebanadas de tipo numpy

# x has shape [2, 3, 2]
x = tf.constant([[[1., 2.], [3., 4. ], [5. , 6. ]],
                 [[7., 8.], [9., 10.], [11., 12.]]])

# Extracts x[0, 1:2, :] == [[[ 3.,  4.]]]
res = tf.slice(x, [0, 1, 0], [1, 1, -1])

Usando indexación negativa, para recuperar el último elemento en la tercera dimensión:

# Extracts x[0, :, -1:] == [[[2.], [4.], [6.]]]
last_indice = x.get_shape().as_list()[2] - 1
res = tf.slice(x, [0, 1, last_indice], [1, -1, -1])

Extraiga segmentos no contiguos de la primera dimensión de un tensor

En general, tf.gather le da acceso a los elementos en la primera dimensión de un tensor (por ejemplo, las filas 1, 3 y 7 en un Tensor bidimensional). Si necesita acceder a cualquier otra dimensión que no sea la primera, o si no necesita toda la división, pero, por ejemplo, solo la quinta entrada en la 1ª, 3ª y 7ª fila, es mejor que use tf.gather_nd (vea más adelante ejemplo para esto).

tf.gather argumentos:

  • params : Un tensor del que desea extraer valores.
  • indices : un tensor que especifica los índices que apuntan a params

Consulte la documentación de tf.gather (parámetros, índices) para obtener información detallada.


Queremos extraer la 1ª y 4ª fila en un tensor bidimensional.

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          ...
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)

selected tiene forma [2, 6] e imprimiendo su valor da

[[ 0  1  2  3  4  5]
 [18 19 20 21 22 23]]

indices también pueden ser solo escalares (pero no pueden contener índices negativos). Por ejemplo, en el ejemplo anterior:

tf.gather(params, tf.constant(3))

imprimiría

[18 19 20 21 22 23]

Tenga en cuenta que los indices pueden tener cualquier forma, pero los elementos almacenados en los indices siempre se refieren a la primera dimensión de los params . Por ejemplo, si desea recuperar la primera y la tercera fila y la segunda y la cuarta fila al mismo tiempo, puede hacer esto:

indices = tf.constant([[0, 2], [1, 3]])
selected = tf.gather(params, indices)

Ahora selected tendrá forma [2, 2, 6] y su contenido dice:

[[[ 0  1  2  3  4  5]
  [12 13 14 15 16 17]]

 [[ 6  7  8  9 10 11]
  [18 19 20 21 22 23]]]

Puedes usar tf.gather para calcular una permutación. Por ejemplo, lo siguiente invierte todas las filas de params :

indices = tf.constant(list(range(4, -1, -1)))
selected = tf.gather(params, indices)

selected es ahora

[[24 25 26 27 28 29]
 [18 19 20 21 22 23]
 [12 13 14 15 16 17]
 [ 6  7  8  9 10 11]
 [ 0  1  2  3  4  5]]

Si necesita acceder a cualquier otro que no sea la primera dimensión, puede tf.transpose utilizando tf.transpose : Por ejemplo, para reunir columnas en lugar de filas en nuestro ejemplo, puede hacer esto:

indices = tf.constant([0, 2])
selected = tf.gather(tf.transpose(params, [1, 0]), indices)
selected_t = tf.transpose(selected, [1, 0]) 

selected_t es de forma [5, 2] y lee:

[[ 0  2]
 [ 6  8]
 [12 14]
 [18 20]
 [24 26]]

Sin embargo, tf.transpose es bastante caro, por lo que podría ser mejor usar tf.gather_nd para este caso de uso.

Indización numpy como tensores

Este ejemplo se basa en esta publicación: TensorFlow : indexación de tensor similar a un número .

En Numpy puede usar matrices para indexar en una matriz. Por ejemplo, para seleccionar los elementos en (1, 2) y (3, 2) en una matriz bidimensional, puede hacer esto:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)

Esto imprimirá:

[ 8 20]

Para obtener el mismo comportamiento en Tensorflow, puedes usar tf.gather_nd , que es una extensión de tf.gather . El ejemplo anterior se puede escribir así:

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
        
with tf.Session() as sess:
    print(sess.run(result))

Esto imprimirá:

[ 8 20]

tf.stack es el equivalente de np.asarray y en este caso apila los dos vectores de índice a lo largo de la última dimensión (que en este caso es la primera) para producir:

[[1 2]
 [3 2]]

Cómo usar tf.gather_nd

tf.gather_nd es una extensión de tf.gather en el sentido de que le permite no solo acceder a la primera dimensión de un tensor, sino a todos potencialmente.

Argumentos:

  • params : un tensor de rango P representa el tensor al que queremos indexar
  • indices : un tensor de rango Q representa los índices en params que queremos acceder

La salida de la función depende de la forma de los indices . Si la dimensión más interna de los indices tiene la longitud P , estamos recolectando elementos individuales de los params . Si es menor que P , estamos recolectando segmentos, al igual que con tf.gather pero sin la restricción de que solo podemos acceder a la primera dimensión.


Recogiendo elementos de un tensor de rango 2.

Para acceder al elemento en (1, 2) en una matriz, podemos usar:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])

donde el result será 8 como se esperaba. Observe en qué se diferencia esto de tf.gather : los mismos índices pasados ​​a tf.gather(x, [1, 2]) se habrían dado como la segunda y tercera fila de los data .

Si desea recuperar más de un elemento al mismo tiempo, simplemente pase una lista de pares de índices:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

el cual volverá [ 8 27 17]


Recogiendo filas de un tensor de rango 2

Si en el ejemplo anterior desea recopilar filas (es decir, segmentos) en lugar de elementos, ajuste el parámetro de indices siguiente manera:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

Esto le dará la segunda y cuarta fila de data , es decir,

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]

Recogiendo elementos de un tensor de rango 3.

El concepto de cómo acceder a los tensores de rango 2 se traduce directamente en tensores de mayor dimensión. Por lo tanto, para acceder a los elementos en un tensor de rango 3, la dimensión más interna de los indices debe tener la longitud 3.

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result ahora se verá así: [ 0 11]


Recogiendo filas por lotes de un tensor de rango 3

Pensemos en un tensor de rango 3 como un lote de matrices con forma (batch_size, m, n) . Si desea recopilar la primera y la segunda fila para cada elemento del lote, puede usar esto:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

que resultará en esto:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]

Observe cómo la forma de los indices influye en la forma del tensor de salida. Si hubiéramos usado un tensor de rango 2 para el argumento de los indices :

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

la salida hubiera sido

[[0 1]
 [2 3]
 [6 7]
 [8 9]]


Modified text is an extract of the original Stack Overflow Documentation
Licenciado bajo CC BY-SA 3.0
No afiliado a Stack Overflow