tensorflow
tf.cond가있는 TensorFlow 그래프 내부의 if 조건 사용
수색…
매개 변수
매개 변수 | 세부 |
---|---|
pred | 형의 TensorFlow 텐서 bool |
fn1 | 인수가없는 호출 가능 함수 |
fn2 | 인수가없는 호출 가능 함수 |
이름 | (선택 사항) 작업 이름 |
비고
-
pred
는True
또는False
일 수 없으며 Tensor 여야합니다. - 함수
fn1
과fn2
는 동일한 유형의 출력과 동일한 수의 출력을 반환해야합니다.
기본 예제
x = tf.constant(1.)
bool = tf.constant(True)
res = tf.cond(bool, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
f1과 f2가 여러 개의 텐서를 반환하면
두 함수 fn1
과 fn2
는 여러 개의 텐서를 반환 할 수 있지만 정확한 개수와 유형의 출력을 반환해야합니다.
x = tf.constant(1.)
bool = tf.constant(True)
def fn1():
return tf.add(x, 1.), x
def fn2():
return tf.add(x, 10.), x
res1, res2 = tf.cond(bool, fn1, fn2)
# tf.cond returns a list of two tensors
# sess.run([res1, res2]) will return [2., 1.]
매개 변수가있는 함수 f1과 f2 정의 및 사용
람다를 사용하여 tf.cond ()의 함수에 매개 변수를 전달할 수 있으며 코드는 다음과 같습니다.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
def fn1(a, b):
return tf.mul(a, b)
def fn2(a, b):
return tf.add(a, b)
pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
다음과 같이 호출 할 수 있습니다.
with tf.Session() as sess:
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})
# The result is 2.0
print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})
# The result is 5.0
Modified text is an extract of the original Stack Overflow Documentation
아래 라이선스 CC BY-SA 3.0
와 제휴하지 않음 Stack Overflow