サーチ…


前書き

Tensorflowは、グラフ内のすべての変数の現在の値の保存/復元と、実際のグラフ構造の保存/復元を区別します。グラフを復元するには、Tensorflowの関数を使用するか、最初にグラフを作成したコードをもう一度呼び出すだけです。グラフを定義するときには、グラフが保存され復元されたときにどのように変数やオペレーションを取り出せるかを考える必要があります。

備考

上記の復元モデルのセクションで、私が正しく理解すれば、モデルを構築してから変数を復元します。 tf.add_to_collection()を使用して保存するときに関連するテンソル/プレースホルダを追加する限り、モデルを再構築する必要はありません。例えば:

tf.add_to_collection('cost_op', cost_op)

その後、あなたは保存されたグラフを復元し、へのアクセスを得ることができますcost_op使用します

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

tf.add_to_collection()実行しなくても、テンソルを取得することはできますが、プロセスは少し面倒です。物事に適した名前を見つけるためには、いくつかの掘り下げを行う必要があります。例えば:

テンソルlab_squeezeグラフを作成するスクリプトでは、いくつかのテンソルセットlab_squeezeを定義します。

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

後に次のように思い出すことができます:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

テンソルの名前を見つけるにはいくつかの方法があります。テンソルボード上のグラフで見つけることができます。あるいは、次のような方法で検索することができます:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

モデルを保存する

モデルをテンソルフローで保存するのは簡単です。

たとえば、入力x線形モデルがあり、出力yを予測したいとします。ここでの損失は平均二乗誤差(MSE)です。バッチサイズは16です。

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

ここでは複数のパラメータ( cf.doc )を持つことができるSaverオブジェクトがあります。

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

最後に、 tf.Session()1000回の反復についてモデルを訓練します。 100回の繰り返しごとにモデルを保存するだけです。

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

このコードを実行すると、ディレクトリ内の最後の5つのチェックポイントが表示されます。

  • model-500model-500.meta
  • model-600model-600.meta
  • model-700model-700.meta
  • model-800model-800.meta
  • model-900model-900.meta

この例では、 saver実際に変数の現在の値をチェックポイントとグラフの構造( *.meta )として*.metaしていますが、プレースホルダxyを検索する方法については特に注意を払っていませんモデルが復元されました。たとえば、このトレーニングスクリプト以外の場所でリストアを行うと、復元されたグラフ(特に複雑なモデル)からxyを取り出すのは面倒です。これを避けるには、変数/プレースホルダ/オペレーションに常に名前を付けてください。また、 tf.collections使用については、備考の1つに示すように考えてtf.collections

モデルを復元する

復元も非常に簡単で簡単です。

便利なヘルパー関数があります:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

メインコード:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
ライセンスを受けた CC BY-SA 3.0
所属していない Stack Overflow