サーチ…


パラメーター

パラメータ詳細
機能 python関数は、 numpy配列を入力として受け取り、 numpy配列をその出力として返します
inp テンソルのリスト(入力)
タウト funcの出力のテンソルフローデータ型のリスト

基本的な例

tf.py_func(func, inp, Tout)演算子は、Python関数を呼び出し、テンソルinpリストをfuncするTensorFlow演算を作成します。

tf.py_func(func, inp, Tout)ドキュメントを参照してください。

警告tf.py_func()操作はCPU上でのみ実行されます。分散TensorFlowを使用している場合、 tf.py_func()操作は、クライアントと同じプロセスで CPUデバイス配置する必要があります。

def func(x):
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

tf.py_funcを使用する理由

tf.py_func()演算子を使用すると、TensorFlowグラフの途中で任意のPythonコードを実行できます。同等のTensorFlow演算子(まだ)が存在しないカスタムNumPy演算子をラップする場合に特に便利です。 tf.py_func()追加することは、グラフ内でsess.run()を使用する代わりに使用sess.run()ます。

これを行うもう1つの方法は、グラフを2つの部分に分けることです。

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

tf.py_funcとはるかに簡単です:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)


Modified text is an extract of the original Stack Overflow Documentation
ライセンスを受けた CC BY-SA 3.0
所属していない Stack Overflow