tensorflow
Utilisation de la condition if dans le graphe TensorFlow avec tf.cond
Recherche…
Paramètres
Paramètre | Détails |
---|---|
pred | un tenseur TensorFlow de type bool |
fn1 | une fonction appelable, sans argument |
fn2 | une fonction appelable, sans argument |
prénom | (facultatif) nom de l'opération |
Remarques
-
pred
ne peut pas être justeTrue
ouFalse
, il doit être un tensionneur - La fonction
fn1
etfn2
devrait renvoyer le même nombre de sorties, avec les mêmes types.
Exemple de base
x = tf.constant(1.)
bool = tf.constant(True)
res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
Lorsque f1 et f2 renvoient plusieurs tenseurs
Les deux fonctions fn1
et fn2
peuvent renvoyer plusieurs tenseurs, mais elles doivent renvoyer exactement le même nombre et le même type de sorties.
x = tf.constant(1.)
bool = tf.constant(True)
def fn1():
return tf.add(x, 1.), x
def fn2():
return tf.add(x, 10.), x
res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]
définir et utiliser les fonctions f1 et f2 avec des paramètres
Vous pouvez passer des paramètres aux fonctions dans tf.cond () en utilisant lambda et le code est comme ci-dessous.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
def fn1(a, b):
return tf.mul(a, b)
def fn2(a, b):
return tf.add(a, b)
pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
Alors vous pouvez l'appeler comme beuglant:
with tf.Session() as sess:
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
# The result is 2.0
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
# The result is 5.0
Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow