tensorflow
Angepasste Operation mit tf.py_func erstellen (nur CPU)
Suche…
Parameter
Parameter | Einzelheiten |
---|---|
func | Python-Funktion, die numpy-Arrays als Eingänge verwendet und numpy-Arrays als Ausgänge zurückgibt |
inp | Liste der Tensoren (Eingänge) |
Schlepper | Liste der Tensorflow-Datentypen für die Ausgaben von func |
Grundlegendes Beispiel
Die tf.py_func(func, inp, Tout)
schafft Bediener eine TensorFlow Operation , die eine Python - Funktion, ruft func
auf einer Liste von Tensoren inp
.
Siehe die Dokumentation für tf.py_func(func, inp, Tout)
.
Warnung : Die Operation tf.py_func()
kann nur auf der CPU ausgeführt werden. Wenn Sie verteiltes TensorFlow verwenden, muss die Operation tf.py_func()
auf einem CPU-Gerät im selben Prozess wie der Client platziert werden.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Warum tf.py_func verwenden?
Mit dem Operator tf.py_func()
können Sie beliebigen Python-Code in der Mitte eines TensorFlow-Diagramms ausführen. Dies ist besonders praktisch für das Umschließen von benutzerdefinierten NumPy-Operatoren, für die noch kein entsprechender TensorFlow-Operator existiert. Das Hinzufügen von tf.py_func()
ist eine Alternative zur Verwendung von sess.run()
Aufrufen innerhalb des Diagramms.
Eine andere Möglichkeit besteht darin, die Grafik in zwei Teile zu schneiden:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Mit tf.py_func
ist das viel einfacher:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)