Recherche…


Paramètres

Paramètre Détails
func fonction python, qui prend en entrée les tableaux numpy et renvoie les tableaux numpy en sortie
inp liste des tenseurs (entrées)
Tout liste des types de données tensorflow pour les sorties de func

Exemple de base

L' tf.py_func(func, inp, Tout) crée une opération TensorFlow qui appelle une fonction Python, func sur une liste de tenseurs inp .

Voir la documentation de tf.py_func(func, inp, Tout) .

Attention : L'opération tf.py_func() ne fonctionnera que sur CPU. Si vous utilisez TensorFlow distribué, l'opération tf.py_func() doit être placée sur un périphérique CPU dans le même processus que le client.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Pourquoi utiliser tf.py_func

L'opérateur tf.py_func() vous permet d'exécuter du code Python arbitraire au milieu d'un graphique TensorFlow. C'est particulièrement pratique pour encapsuler des opérateurs NumPy personnalisés pour lesquels aucun opérateur TensorFlow équivalent n'existe encore. L'ajout de tf.py_func() est une alternative à l'utilisation des sess.run() dans le graphique.

Une autre façon de faire est de couper le graphique en deux parties:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

Avec tf.py_func c'est beaucoup plus facile:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow