tensorflow
Creación de una operación personalizada con tf.py_func (solo CPU)
Buscar..
Parámetros
Parámetro | Detalles |
---|---|
función | función python, que toma matrices numpy como sus entradas y devuelve matrices numpy como sus salidas |
En p | Lista de tensores (entradas) |
Revendedor | Lista de tipos de datos de tensorflow para las salidas de func |
Ejemplo basico
El tf.py_func(func, inp, Tout)
crea una operación TensorFlow que llama a una función de Python, func
en una lista de tensores inp
.
Consulte la documentación para tf.py_func(func, inp, Tout)
.
Advertencia : la operación tf.py_func()
solo se ejecutará en la CPU. Si está utilizando TensorFlow distribuido, la operación tf.py_func()
debe colocarse en un dispositivo CPU en el mismo proceso que el cliente.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Por qué usar tf.py_func
El operador tf.py_func()
permite ejecutar código Python arbitrario en medio de un gráfico TensorFlow. Es particularmente conveniente para envolver operadores NumPy personalizados para los cuales no existe un operador TensorFlow equivalente (aún). Agregar tf.py_func()
es una alternativa al uso de llamadas sess.run()
dentro del gráfico.
Otra forma de hacerlo es cortar la gráfica en dos partes:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Con tf.py_func
esto es mucho más fácil:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)