tensorflow
Usando se la condizione all'interno del grafico TensorFlow con tf.cond
Ricerca…
Parametri
Parametro | Dettagli |
---|---|
pred | un tensore TensorFlow di tipo bool |
fn1 | una funzione callable, senza argomenti |
fn2 | una funzione callable, senza argomenti |
nome | (facoltativo) nome per l'operazione |
Osservazioni
-
pred
non può essere soloTrue
oFalse
, deve essere un Tensore - La funzione
fn1
efn2
dovrebbero restituire lo stesso numero di uscite, con gli stessi tipi.
Esempio di base
x = tf.constant(1.)
bool = tf.constant(True)
res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
Quando f1 e f2 restituiscono più tensioni
Le due funzioni fn1
e fn2
possono restituire più fn2
, ma devono restituire esattamente lo stesso numero e tipi di uscite.
x = tf.constant(1.)
bool = tf.constant(True)
def fn1():
return tf.add(x, 1.), x
def fn2():
return tf.add(x, 10.), x
res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]
definire e utilizzare le funzioni f1 e f2 con i parametri
Puoi passare i parametri alle funzioni in tf.cond () usando lambda e il codice è come muggito.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
def fn1(a, b):
return tf.mul(a, b)
def fn2(a, b):
return tf.add(a, b)
pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
Quindi puoi chiamarlo come muggito:
with tf.Session() as sess:
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
# The result is 2.0
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
# The result is 5.0
Modified text is an extract of the original Stack Overflow Documentation
Autorizzato sotto CC BY-SA 3.0
Non affiliato con Stack Overflow