tensorflow
Creazione di un'operazione personalizzata con tf.py_func (solo CPU)
Ricerca…
Parametri
Parametro | Dettagli |
---|---|
func | funzione python, che accetta gli array numpy come input e restituisce gli array numpy come output |
inp | elenco di Tensors (input) |
tout | lista dei tipi di dati tensorflow per le uscite di func |
Esempio di base
Il tf.py_func(func, inp, Tout)
operatore crea un'operazione tensorflow che chiama una funzione Python, func
su un elenco di tensori inp
.
Vedere la documentazione per tf.py_func(func, inp, Tout)
.
Attenzione : l'operazione tf.py_func()
verrà eseguita solo sulla CPU. Se si utilizza TensorFlow distribuito, l'operazione tf.py_func()
deve essere posizionata su un dispositivo CPU nello stesso processo del client.
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
Perché usare tf.py_func
L'operatore tf.py_func()
consente di eseguire codice Python arbitrario nel mezzo di un grafico TensorFlow. È particolarmente utile per il wrapping di operatori NumPy personalizzati per i quali non esiste (ancora) un operatore TensorFlow equivalente. L'aggiunta di tf.py_func()
è un'alternativa all'uso sess.run()
chiamate sess.run()
all'interno del grafico.
Un altro modo per farlo è quello di tagliare il grafico in due parti:
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
Con tf.py_func
questo è molto più semplice:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)