tensorflow
tf.condでTensorFlowグラフのif条件を使用する
サーチ…
パラメーター
パラメータ | 詳細 |
---|---|
前 | bool 型のTensorFlowテンソル |
fn1 | 引数を持たない呼び出し可能な関数 |
fn2 | 引数を持たない呼び出し可能な関数 |
名 | (オプション)操作の名前 |
備考
-
pred
はTrue
かFalse
であることはできません、それはテンソルである必要があります - 関数
fn1
とfn2
は、同じタイプの出力を同じ数だけ返す必要があります。
基本的な例
x = tf.constant(1.)
bool = tf.constant(True)
res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
f1とf2が複数のテンソルを返すとき
2つの関数fn1
とfn2
は複数のテンソルを返すことができますが、それらは同じ数とタイプの出力を返す必要があります。
x = tf.constant(1.)
bool = tf.constant(True)
def fn1():
return tf.add(x, 1.), x
def fn2():
return tf.add(x, 10.), x
res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]
パラメータf1とf2を定義して使用する
ラムダを使用してtf.cond()の関数にパラメータを渡すことができ、コードは次のようになります。
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
def fn1(a, b):
return tf.mul(a, b)
def fn2(a, b):
return tf.add(a, b)
pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
次に、それを次のように呼び出すことができます:
with tf.Session() as sess:
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
# The result is 2.0
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
# The result is 5.0
Modified text is an extract of the original Stack Overflow Documentation
ライセンスを受けた CC BY-SA 3.0
所属していない Stack Overflow