Suche…


Einführung

Verschiedene Beispiele zeigen, wie Tensorflow die Indizierung in Tensoren unterstützt, wobei Unterschiede und Ähnlichkeiten mit der numerischen Indizierung hervorgehoben werden, wo dies möglich ist.

Extrahieren Sie eine Scheibe aus einem Tensor

Ausführliche Informationen finden Sie in der tf.slice(input, begin, size) zu tf.slice(input, begin, size) .

Argumente:

  • input : Tensor
  • begin : Startort für jede input
  • size : Anzahl der Elemente für jede input , wobei -1 alle übrigen Elemente enthält

Wackeliges Schneiden:

# x has shape [2, 3, 2]
x = tf.constant([[[1., 2.], [3., 4. ], [5. , 6. ]],
                 [[7., 8.], [9., 10.], [11., 12.]]])

# Extracts x[0, 1:2, :] == [[[ 3.,  4.]]]
res = tf.slice(x, [0, 1, 0], [1, 1, -1])

Verwenden Sie die negative Indizierung, um das letzte Element in der dritten Dimension abzurufen:

# Extracts x[0, :, -1:] == [[[2.], [4.], [6.]]]
last_indice = x.get_shape().as_list()[2] - 1
res = tf.slice(x, [0, 1, last_indice], [1, -1, -1])

Extrahieren Sie nicht zusammenhängende Schichten aus der ersten Dimension eines Tensors

Im Allgemeinen ermöglicht tf.gather den Zugriff auf Elemente in der ersten Dimension eines Tensors (z. B. Reihen 1, 3 und 7 in einem 2-dimensionalen Tensor). Wenn Sie Zugriff auf eine andere Dimension als die erste benötigen oder wenn Sie nicht das gesamte Slice benötigen, sondern z. B. nur den fünften Eintrag in der ersten, dritten und siebten Zeile, ist die Verwendung von tf.gather_nd besser tf.gather_nd Beispiel dafür).

tf.gather Argumente:

  • params : Ein Tensor, aus dem Sie Werte extrahieren möchten.
  • indices : Ein Tensor, der die Indizes angibt, die in params

Ausführliche Informationen finden Sie in der Dokumentation zu tf.gather (Parameter, Indizes) .


Wir wollen die 1. und 4. Reihe in einem 2-dimensionalen Tensor extrahieren.

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          ...
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)

selected hat die Form [2, 6] und der Druck ergibt den Wert

[[ 0  1  2  3  4  5]
 [18 19 20 21 22 23]]

indices können auch nur ein Skalar sein (dürfen jedoch keine negativen Indizes enthalten). ZB im obigen Beispiel:

tf.gather(params, tf.constant(3))

würde drucken

[18 19 20 21 22 23]

Beachten Sie, dass indices eine beliebige Form haben können, aber die in indices gespeicherten Elemente beziehen sich immer nur auf die erste Dimension von params . Wenn Sie beispielsweise sowohl die 1. und 3. Reihe als auch die 2. und 4. Reihe gleichzeitig abrufen möchten, können Sie Folgendes tun:

indices = tf.constant([[0, 2], [1, 3]])
selected = tf.gather(params, indices)

Die selected Form hat jetzt die Form [2, 2, 6] und ihr Inhalt lautet:

[[[ 0  1  2  3  4  5]
  [12 13 14 15 16 17]]

 [[ 6  7  8  9 10 11]
  [18 19 20 21 22 23]]]

Sie können tf.gather , um eine Permutation zu berechnen. params Beispiel werden alle params :

indices = tf.constant(list(range(4, -1, -1)))
selected = tf.gather(params, indices)

selected ist jetzt

[[24 25 26 27 28 29]
 [18 19 20 21 22 23]
 [12 13 14 15 16 17]
 [ 6  7  8  9 10 11]
 [ 0  1  2  3  4  5]]

Wenn Sie Zugriff auf eine andere als die erste Dimension benötigen, können Sie dies mit tf.transpose : ZB um Spalten anstelle von Zeilen in unserem Beispiel zu tf.transpose , können Sie tf.transpose tun:

indices = tf.constant([0, 2])
selected = tf.gather(tf.transpose(params, [1, 0]), indices)
selected_t = tf.transpose(selected, [1, 0]) 

selected_t hat die Form [5, 2] und lautet:

[[ 0  2]
 [ 6  8]
 [12 14]
 [18 20]
 [24 26]]

tf.transpose ist jedoch ziemlich teuer, daher ist es möglicherweise besser, für diesen Anwendungsfall tf.gather_nd zu verwenden.

Würfelartige Indizierung mit Tensoren

Dieses Beispiel basiert auf diesem Beitrag: TensorFlow - Numpy-artige Tensor-Indizierung .

In Numpy können Sie Arrays verwenden, um in ein Array zu indizieren. Um beispielsweise die Elemente bei (1, 2) und (3, 2) in einem 2-dimensionalen Array auszuwählen, können Sie Folgendes tun:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)

Dies wird drucken:

[ 8 20]

Um das gleiche Verhalten in Tensorflow zu erhalten, können Sie tf.gather_nd , eine Erweiterung von tf.gather . Das obige Beispiel kann wie folgt geschrieben werden:

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
        
with tf.Session() as sess:
    print(sess.run(result))

Dies wird drucken:

[ 8 20]

tf.stack ist das Äquivalent von np.asarray und stapelt in diesem Fall die beiden np.asarray entlang der letzten Dimension (die in diesem Fall die 1. ist), um np.asarray zu erzeugen:

[[1 2]
 [3 2]]

Wie benutze ich tf.gather_nd

tf.gather_nd ist eine Erweiterung von tf.gather in dem Sinne, dass Sie nicht nur auf die 1. Dimension eines Tensors zugreifen können, sondern möglicherweise auf alle.

Argumente:

  • params : Ein Tensor von Rang P der den Tensor darstellt, in den wir indexieren möchten
  • indices : Ein Tensor von Rang Q der die Indizes in params wir zugreifen möchten

Die Ausgabe der Funktion hängt von der Form der indices . Wenn die innerste Dimension der indices die Länge P , werden einzelne Elemente aus params . Wenn es weniger als P , sammeln wir wie bei tf.gather Scheiben, jedoch ohne die Einschränkung, dass wir nur auf die 1. Dimension zugreifen können.


Sammeln von Elementen aus einem Tensor von Rang 2

Um auf das Element in (1, 2) in einer Matrix zuzugreifen, können wir Folgendes verwenden:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])

wobei das result wie erwartet nur 8 wird. Beachten Sie, wie sich dies von tf.gather : Die gleichen Indizes, die an tf.gather(x, [1, 2]) hätten als 2. und 3. Zeile aus den data .

Wenn Sie mehrere Elemente gleichzeitig abrufen möchten, übergeben Sie einfach eine Liste von Indexpaaren:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

die zurückkehren wird [ 8 27 17]


Sammeln von Reihen aus einem Tensor von Rang 2

Wenn Sie im obigen Beispiel anstelle von Elementen Zeilen (dh Segmente) sammeln möchten, passen Sie den Parameter indices wie folgt an:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

Dies gibt Ihnen die 2. und 4. Reihe von data , dh

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]

Sammeln von Elementen aus einem Tensor von Rang 3

Das Konzept des Zugriffs auf die Tensor-2-Tensoren wird direkt in höher dimensionierte Tensoren übersetzt. Um auf Elemente in einem Straff-3-Tensor zugreifen zu können, muss die innerste Dimension von indices die Länge 3 haben.

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result sieht nun so aus: [ 0 11]


Sammeln von Stapelreihen aus einem Tensor von Rang 3

Betrachten wir einen Tensor der Stufe 3 als eine Charge von Matrizen (batch_size, m, n) . Wenn Sie die erste und zweite Zeile für jedes Element im Stapel sammeln möchten, können Sie Folgendes verwenden:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

was dazu führen wird:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]

Beachten Sie, wie die Form von indices die Form des Ausgangstensors beeinflusst. Wenn wir einen Tensor der Stufe 2 für das indices :

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

die Ausgabe wäre gewesen

[[0 1]
 [2 3]
 [6 7]
 [8 9]]


Modified text is an extract of the original Stack Overflow Documentation
Lizenziert unter CC BY-SA 3.0
Nicht angeschlossen an Stack Overflow