Buscar..


Parámetros

Parámetro Detalles
función función python, que toma matrices numpy como sus entradas y devuelve matrices numpy como sus salidas
En p Lista de tensores (entradas)
Revendedor Lista de tipos de datos de tensorflow para las salidas de func

Ejemplo basico

El tf.py_func(func, inp, Tout) crea una operación TensorFlow que llama a una función de Python, func en una lista de tensores inp .

Consulte la documentación para tf.py_func(func, inp, Tout) .

Advertencia : la operación tf.py_func() solo se ejecutará en la CPU. Si está utilizando TensorFlow distribuido, la operación tf.py_func() debe colocarse en un dispositivo CPU en el mismo proceso que el cliente.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Por qué usar tf.py_func

El operador tf.py_func() permite ejecutar código Python arbitrario en medio de un gráfico TensorFlow. Es particularmente conveniente para envolver operadores NumPy personalizados para los cuales no existe un operador TensorFlow equivalente (aún). Agregar tf.py_func() es una alternativa al uso de llamadas sess.run() dentro del gráfico.

Otra forma de hacerlo es cortar la gráfica en dos partes:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

Con tf.py_func esto es mucho más fácil:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
Licenciado bajo CC BY-SA 3.0
No afiliado a Stack Overflow