Zoeken…


Invoering

Verschillende voorbeelden laten zien hoe Tensorflow indexering in tensoren ondersteunt, waarbij verschillen en overeenkomsten met numpy-achtige indexering waar mogelijk worden benadrukt.

Pak een plak van een tensor

Raadpleeg de documentatie van tf.slice(input, begin, size) voor gedetailleerde informatie.

argumenten:

  • input : Tensor
  • begin : startlocatie voor elke input
  • size : aantal elementen voor elke input , met -1 bevat alle resterende elementen

Numpy-achtige plakken:

# x has shape [2, 3, 2]
x = tf.constant([[[1., 2.], [3., 4. ], [5. , 6. ]],
                 [[7., 8.], [9., 10.], [11., 12.]]])

# Extracts x[0, 1:2, :] == [[[ 3.,  4.]]]
res = tf.slice(x, [0, 1, 0], [1, 1, -1])

Negatieve indexering gebruiken om het laatste element in de derde dimensie op te halen:

# Extracts x[0, :, -1:] == [[[2.], [4.], [6.]]]
last_indice = x.get_shape().as_list()[2] - 1
res = tf.slice(x, [0, 1, last_indice], [1, -1, -1])

Extraheer niet-aaneengesloten plakjes uit de eerste dimensie van een tensor

In het tf.gather geeft tf.gather u toegang tot elementen in de eerste dimensie van een tensor (bijv. Rijen 1, 3 en 7 in een tweedimensionale Tensor). Als u toegang nodig hebt tot een andere dimensie dan de eerste, of als u niet het hele segment nodig hebt, maar bijvoorbeeld alleen het 5e item in de 1e, 3e en 7e rij, kunt u beter tf.gather_nd (zie binnenkort voorbeeld hiervoor).

tf.gather argumenten:

  • params : een tensor waaruit u waarden wilt extraheren.
  • indices : een tensor die de indices params die naar params wijzen

Raadpleeg de documentatie van tf.gather (params, indices) voor gedetailleerde informatie.


We willen de 1e en 4e rij extraheren in een tweedimensionale tensor.

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          ...
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
params = tf.constant(data)
indices = tf.constant([0, 3])
selected = tf.gather(params, indices)

selected heeft vorm [2, 6] en afdrukken geeft de waarde ervan

[[ 0  1  2  3  4  5]
 [18 19 20 21 22 23]]

indices kunnen ook gewoon een scalaire waarde zijn (maar mogen geen negatieve indices bevatten). Bijvoorbeeld in het bovenstaande voorbeeld:

tf.gather(params, tf.constant(3))

zou afdrukken

[18 19 20 21 22 23]

Merk op dat indices elke vorm kunnen hebben, maar de elementen die in indices opgeslagen, verwijzen altijd alleen naar de eerste dimensie van params . Als u bijvoorbeeld zowel de 1e en 3e rij als de 2e en 4e rij tegelijkertijd wilt ophalen, kunt u dit doen:

indices = tf.constant([[0, 2], [1, 3]])
selected = tf.gather(params, indices)

Nu selected heeft vorm [2, 2, 6] en de inhoud is als volgt:

[[[ 0  1  2  3  4  5]
  [12 13 14 15 16 17]]

 [[ 6  7  8  9 10 11]
  [18 19 20 21 22 23]]]

U kunt tf.gather gebruiken om een permutatie te berekenen. Bijvoorbeeld: het volgende keert alle rijen params :

indices = tf.constant(list(range(4, -1, -1)))
selected = tf.gather(params, indices)

selected is nu

[[24 25 26 27 28 29]
 [18 19 20 21 22 23]
 [12 13 14 15 16 17]
 [ 6  7  8  9 10 11]
 [ 0  1  2  3  4  5]]

Als u toegang nodig hebt tot een andere dan de eerste dimensie, kunt u dit omzeilen door tf.transpose : bijvoorbeeld om kolommen te verzamelen in plaats van rijen in ons voorbeeld, kunt u dit doen:

indices = tf.constant([0, 2])
selected = tf.gather(tf.transpose(params, [1, 0]), indices)
selected_t = tf.transpose(selected, [1, 0]) 

selected_t heeft de vorm [5, 2] en leest:

[[ 0  2]
 [ 6  8]
 [12 14]
 [18 20]
 [24 26]]

tf.transpose is echter vrij duur, dus het is misschien beter om tf.gather_nd te gebruiken voor deze use case.

Numpy-achtige indexering met behulp van tensoren

Dit voorbeeld is gebaseerd op dit bericht: TensorFlow - numpy-achtige tensor-indexering .

In Numpy kunt u arrays gebruiken om in een array te indexeren. Om bijvoorbeeld de elementen bij (1, 2) en (3, 2) in een tweedimensionale array te selecteren, kunt u dit doen:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]
b = [2, 2]
selected = data[a, b]
print(selected)

Dit zal afdrukken:

[ 8 20]

Om hetzelfde gedrag in Tensorflow te krijgen, kun je tf.gather_nd , een extensie van tf.gather . Het bovenstaande voorbeeld kan als volgt worden geschreven:

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))
        
with tf.Session() as sess:
    print(sess.run(result))

Dit zal afdrukken:

[ 8 20]

tf.stack is het equivalent van np.asarray en stapelt in dit geval de twee np.asarray langs de laatste dimensie (die in dit geval de 1e is) om te produceren:

[[1 2]
 [3 2]]

Hoe tf.gather_nd te gebruiken

tf.gather_nd is een uitbreiding van tf.gather in die zin dat je hiermee niet alleen toegang hebt tot de 1e dimensie van een tensor, maar mogelijk allemaal.

argumenten:

  • params : een Tensor van rang P die de tensor vertegenwoordigt waarin we willen indexeren
  • indices : een Tensor van rang Q die de indices weergeeft in params we toegang willen hebben

De uitvoer van de functie is afhankelijk van de vorm van indices . Als de binnenste dimensie van indices lengte P , verzamelen we afzonderlijke elementen van params . Als het kleiner is dan P , verzamelen we plakjes, net als bij tf.gather maar zonder de beperking dat we alleen toegang hebben tot de 1e dimensie.


Elementen verzamelen van een tensor van rang 2

Om toegang te krijgen tot het element op (1, 2) in een matrix, kunnen we gebruiken:

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12 13 14 15 16 17],
#          [18 19 20 21 22 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])

waar het result slechts 8 zoals verwacht. Merk op hoe dit verschilt van tf.gather : dezelfde indices die zijn doorgegeven aan tf.gather(x, [1, 2]) zouden zijn gegeven als de 2e en 3e rij van data .

Als u meer dan één element tegelijkertijd wilt ophalen, geeft u gewoon een lijst met indexparen door:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

die zal terugkeren [ 8 27 17]


Rijen verzamelen van een tensor van rang 2

Als u in het bovenstaande voorbeeld rijen (dwz segmenten) wilt verzamelen in plaats van elementen, past u de parameter indices als volgt aan:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

Hiermee krijgt u de 2e en 4e rij met data , dat wil zeggen

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]

Elementen verzamelen van een tensor van rang 3

Het concept van toegang tot rang-2-tensoren vertaalt zich direct in hoger-dimensionale tensoren. Om toegang te krijgen tot elementen in een rang-3-tensor, moet de binnenste dimensie van indices dus lengte 3 hebben.

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result ziet er nu als volgt uit: [ 0 11]


Verzamelde rijen met rijen van een tensor van rang 3

Laten we een rank-3-tensor beschouwen als een batch met matrices in de vorm (batch_size, m, n) . Als u de eerste en tweede rij voor elk element in de batch wilt verzamelen, kunt u dit gebruiken:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

wat zal resulteren in dit:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]

Merk op hoe de vorm van indices de vorm van de uitvoertensor beïnvloedt. Als we een rang-2-tensor zouden hebben gebruikt voor het argument van indices :

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

de output zou zijn geweest

[[0 1]
 [2 3]
 [6 7]
 [8 9]]


Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow