tensorflow
Een aangepaste bewerking maken met tf.py_func (alleen CPU)
Zoeken…
parameters
Parameter | Details |
---|---|
func | python-functie, die numpy arrays als zijn invoer gebruikt en numpy arrays als zijn output teruggeeft |
INP | lijst van Tensoren (ingangen) |
Tout | lijst met tensorflow-gegevenstypen voor de uitgangen van func |
Basis voorbeeld
De tf.py_func(func, inp, Tout)
maakt een TensorFlow-bewerking die een Python-functie aanroept, func
op een lijst met tensors inp
.
Raadpleeg de documentatie voor tf.py_func(func, inp, Tout)
.
Waarschuwing : de tf.py_func()
wordt alleen uitgevoerd op de CPU. Als u gedistribueerde TensorFlow gebruikt, moet de tf.py_func()
op een CPU-apparaat worden geplaatst in hetzelfde proces als de client.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Waarom tf.py_func gebruiken
Met de operator tf.py_func()
kunt u willekeurige Python-code uitvoeren in het midden van een TensorFlow-grafiek. Het is met name handig voor het verpakken van aangepaste NumPy-operators waarvoor (nog) geen vergelijkbare TensorFlow-operator bestaat. Het toevoegen van tf.py_func()
is een alternatief voor het gebruik van sess.run()
aanroepen in de grafiek.
Een andere manier om dat te doen is om de grafiek in twee delen te knippen:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Met tf.py_func
dit veel eenvoudiger:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)