Buscar..


Parámetros

Parámetro Detalles
pred Un tensor TensorFlow de tipo bool
fn1 Una función llamable, sin argumento.
fn2 Una función llamable, sin argumento.
nombre (opcional) nombre para la operación

Observaciones

  • pred no puede ser solo True o False , necesita ser un Tensor
  • La función fn1 y fn2 deben devolver el mismo número de salidas, con los mismos tipos.

Ejemplo basico

x = tf.constant(1.)
bool = tf.constant(True)

res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.

Cuando f1 y f2 devuelven tensores múltiples.

Las dos funciones fn1 y fn2 pueden devolver varios tensores, pero tienen que devolver el mismo número y tipo exacto de salidas.

x = tf.constant(1.)
bool = tf.constant(True)

def fn1():
    return tf.add(x, 1.), x

def fn2():
    return tf.add(x, 10.), x

res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]

Definir y usar las funciones f1 y f2 con parámetros.

Puede pasar parámetros a las funciones en tf.cond () usando lambda y el código es el siguiente.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
  return tf.mul(a, b)

def fn2(a, b):
  return tf.add(a, b)

pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

Entonces puedes llamarlo como bramando:

with tf.Session() as sess:
  print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
  # The result is 2.0
  print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
  # The result is 5.0


Modified text is an extract of the original Stack Overflow Documentation
Licenciado bajo CC BY-SA 3.0
No afiliado a Stack Overflow