tensorflow
tf.py_funcを使ってカスタム操作を作成する(CPUのみ)
サーチ…
パラメーター
パラメータ | 詳細 |
---|---|
機能 | python関数は、 numpy配列を入力として受け取り、 numpy配列をその出力として返します |
inp | テンソルのリスト(入力) |
タウト | func の出力のテンソルフローデータ型のリスト |
基本的な例
tf.py_func(func, inp, Tout)
演算子は、Python関数を呼び出し、テンソルinp
リストをfunc
するTensorFlow演算を作成します。
tf.py_func(func, inp, Tout)
のドキュメントを参照してください。
警告 : tf.py_func()
操作はCPU上でのみ実行されます。分散TensorFlowを使用している場合、 tf.py_func()
操作は、クライアントと同じプロセスで CPUデバイスに配置する必要があります。
def func(x):
return 2*x
x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1
tf.py_funcを使用する理由
tf.py_func()
演算子を使用すると、TensorFlowグラフの途中で任意のPythonコードを実行できます。同等のTensorFlow演算子(まだ)が存在しないカスタムNumPy演算子をラップする場合に特に便利です。 tf.py_func()
追加することは、グラフ内でsess.run()
を使用する代わりに使用sess.run()
ます。
これを行うもう1つの方法は、グラフを2つの部分に分けることです。
# Part 1 of the graph
inputs = ... # in the TF graph
# Get the numpy array and apply func
val = sess.run(inputs) # get the value of inputs
output_val = func(val) # numpy array
# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...
# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})
tf.py_func
とはるかに簡単です:
# Part 1 of the graph
inputs = ...
# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]
# Part 2 of the graph
train_op = ...
# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)
Modified text is an extract of the original Stack Overflow Documentation
ライセンスを受けた CC BY-SA 3.0
所属していない Stack Overflow