tensorflow
Skapa en anpassad operation med tf.py_func (endast CPU)
Sök…
parametrar
Parameter | detaljer |
---|---|
func | pythonfunktion, som tar numpy arrays som sina ingångar och returnerar numpy arrays som dess utgångar |
i P | lista över Tensorer (ingångar) |
Tout | lista med tensorflowdatatyper för utgångar från func |
Grundläggande exempel
tf.py_func(func, inp, Tout)
skapar en TensorFlow-operation som kallar en Python-funktion, func
på en lista med inp
.
Se dokumentationen för tf.py_func(func, inp, Tout)
.
Varning : tf.py_func()
-operationen körs endast på CPU. Om du använder distribuerad TensorFlow tf.py_func()
operationen tf.py_func()
placeras på en CPU-enhet i samma process som klienten.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Varför man använder tf.py_func
tf.py_func()
låter dig köra godtycklig Python-kod mitt i ett TensorFlow-diagram. Det är särskilt bekvämt för inslagning av anpassade NumPy-operatörer för vilka ingen motsvarande TensorFlow-operatör (ännu) finns. Att lägga till tf.py_func()
är ett alternativ till att använda sess.run()
inuti diagrammet.
Ett annat sätt att göra det är att skära grafen i två delar:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Med tf.py_func
detta mycket lättare:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)