수색…


매개 변수

매개 변수 세부
기능 python 함수는 numpy 배열 을 입력으로 사용하고 numpy 배열 을 출력으로 반환합니다.
inp Tensors 목록 (입력)
암표 장수 func 의 출력에 대한 텐서 흐름 데이터 유형 목록

기본 예제

tf.py_func(func, inp, Tout) 연산자는 TensorFlow 연산을 생성하여 파이썬 함수를 호출하고 tensors inp 목록에서 func 을 호출합니다.

tf.py_func(func, inp, Tout) 대한 문서 를 참조하십시오.

경고 : tf.py_func() 작업은 CPU에서만 실행됩니다. 분산 TensorFlow를 사용하는 경우 tf.py_func() 작업은 클라이언트 와 동일한 프로세스에서 CPU 장치 배치해야합니다.

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

tf.py_func를 사용하는 이유

tf.py_func() 연산자를 사용하면 TensorFlow 그래프 중간에서 임의의 Python 코드를 실행할 수 있습니다. 이에 해당하는 TensorFlow 연산자가 아직 존재하지 않는 사용자 정의 NumPy 연산자를 래핑하는 경우에 특히 편리합니다. 그래프 내에서 sess.run() 호출을 사용하는 대신 tf.py_func() 추가하는 방법이 있습니다.

이를 수행하는 또 다른 방법은 두 부분으로 그래프를 자르는 것입니다.

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

tf.py_func 사용하면 훨씬 쉽습니다.

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
아래 라이선스 CC BY-SA 3.0
와 제휴하지 않음 Stack Overflow