Recherche…


Paramètres

Paramètre Détails
pred un tenseur TensorFlow de type bool
fn1 une fonction appelable, sans argument
fn2 une fonction appelable, sans argument
prénom (facultatif) nom de l'opération

Remarques

  • pred ne peut pas être juste True ou False , il doit être un tensionneur
  • La fonction fn1 et fn2 devrait renvoyer le même nombre de sorties, avec les mêmes types.

Exemple de base

x = tf.constant(1.)
bool = tf.constant(True)

res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.

Lorsque f1 et f2 renvoient plusieurs tenseurs

Les deux fonctions fn1 et fn2 peuvent renvoyer plusieurs tenseurs, mais elles doivent renvoyer exactement le même nombre et le même type de sorties.

x = tf.constant(1.)
bool = tf.constant(True)

def fn1():
    return tf.add(x, 1.), x

def fn2():
    return tf.add(x, 10.), x

res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]

définir et utiliser les fonctions f1 et f2 avec des paramètres

Vous pouvez passer des paramètres aux fonctions dans tf.cond () en utilisant lambda et le code est comme ci-dessous.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
  return tf.mul(a, b)

def fn2(a, b):
  return tf.add(a, b)

pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

Alors vous pouvez l'appeler comme beuglant:

with tf.Session() as sess:
  print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
  # The result is 2.0
  print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
  # The result is 5.0


Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow