Ricerca…


Parametri

Parametro Dettagli
func funzione python, che accetta gli array numpy come input e restituisce gli array numpy come output
inp elenco di Tensors (input)
tout lista dei tipi di dati tensorflow per le uscite di func

Esempio di base

Il tf.py_func(func, inp, Tout) operatore crea un'operazione tensorflow che chiama una funzione Python, func su un elenco di tensori inp .

Vedere la documentazione per tf.py_func(func, inp, Tout) .

Attenzione : l'operazione tf.py_func() verrà eseguita solo sulla CPU. Se si utilizza TensorFlow distribuito, l'operazione tf.py_func() deve essere posizionata su un dispositivo CPU nello stesso processo del client.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Perché usare tf.py_func

L'operatore tf.py_func() consente di eseguire codice Python arbitrario nel mezzo di un grafico TensorFlow. È particolarmente utile per il wrapping di operatori NumPy personalizzati per i quali non esiste (ancora) un operatore TensorFlow equivalente. L'aggiunta di tf.py_func() è un'alternativa all'uso sess.run() chiamate sess.run() all'interno del grafico.

Un altro modo per farlo è quello di tagliare il grafico in due parti:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

Con tf.py_func questo è molto più semplice:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
Autorizzato sotto CC BY-SA 3.0
Non affiliato con Stack Overflow