tensorflow
Gebruiken van voorwaarde in de TensorFlow-grafiek met tf.cond
Zoeken…
parameters
Parameter | Details |
---|---|
pred | een TensorFlow-tensor van het type bool |
fn1 | een opvraagbare functie, zonder argument |
FN2 | een opvraagbare functie, zonder argument |
naam | (optioneel) naam voor de bewerking |
Opmerkingen
-
pred
kan niet alleenTrue
ofFalse
, het moet een Tensor zijn - De functie
fn1
enfn2
moet hetzelfde aantal uitgangen retourneren, met dezelfde typen.
Basis voorbeeld
x = tf.constant(1.)
bool = tf.constant(True)
res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
Wanneer f1 en f2 meerdere tensoren retourneren
De twee functies fn1
en fn2
kunnen meerdere tensoren retourneren, maar ze moeten exact hetzelfde aantal en hetzelfde type uitgangen retourneren.
x = tf.constant(1.)
bool = tf.constant(True)
def fn1():
return tf.add(x, 1.), x
def fn2():
return tf.add(x, 10.), x
res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]
definieer en gebruik functies f1 en f2 met parameters
U kunt parameters doorgeven aan de functies in tf.cond () met behulp van lambda en de code is hieronder.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
def fn1(a, b):
return tf.mul(a, b)
def fn2(a, b):
return tf.add(a, b)
pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
Dan kun je het als volgt noemen:
with tf.Session() as sess:
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
# The result is 2.0
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
# The result is 5.0
Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow