Suche…


Einführung

Tensorflow unterscheidet zwischen dem Speichern / Wiederherstellen der aktuellen Werte aller Variablen in einem Diagramm und dem Speichern / Wiederherstellen der tatsächlichen Diagrammstruktur. Um die Grafik wiederherzustellen, können Sie entweder die Tensorflow-Funktionen verwenden oder einfach Ihren Code erneut aufrufen, der die Grafik ursprünglich erstellt hat. Bei der Definition des Graphen sollten Sie auch darüber nachdenken, welche Variablen und Operationen abrufbar sind, nachdem der Graph gespeichert und wiederhergestellt wurde.

Bemerkungen

Wenn ich im Abschnitt zum Wiederherstellen des Modells richtig verstehe, erstellen Sie das Modell und stellen die Variablen wieder her. Ich bin der Meinung, dass ein Neuaufbau des Modells nicht erforderlich ist, solange Sie beim Speichern mit tf.add_to_collection() die entsprechenden Tensoren / Platzhalter tf.add_to_collection() . Zum Beispiel:

tf.add_to_collection('cost_op', cost_op)

Später können Sie dann das gespeicherte Diagramm wiederherstellen und mit cost_op zugreifen

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Selbst wenn Sie tf.add_to_collection() nicht ausführen, können Sie Ihre Tensoren abrufen, der Vorgang ist jedoch etwas umständlicher und Sie müssen möglicherweise etwas suchen, um die richtigen Namen für die Dinge zu finden. Zum Beispiel:

In einem Skript, das ein Tensorflow-Diagramm erstellt, definieren wir eine Reihe von Tensoren lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

Wir können sie später wie folgt abrufen:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Es gibt verschiedene Möglichkeiten, den Namen eines Tensors zu finden - Sie können ihn in Ihrem Graphen auf der Tensorplatine finden, oder Sie können mit etwas wie dem folgenden suchen:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Speichern des Modells

Das Speichern eines Modells im Tensorflow ist ziemlich einfach.

Nehmen wir an, Sie haben ein lineares Modell mit Eingabe x und möchten eine Ausgabe y vorhersagen. Der Verlust ist hier der mittlere quadratische Fehler (MSE). Die Losgröße beträgt 16.

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

Hier kommt das Saver-Objekt, das mehrere Parameter haben kann (vgl. Doc ).

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

Zum Schluss trainieren wir das Modell in einer tf.Session() für 1000 Iterationen. Wir speichern das Modell hier nur alle 100 Iterationen.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

Nachdem Sie diesen Code ausgeführt haben, sollten Sie die letzten 5 Prüfpunkte in Ihrem Verzeichnis sehen:

  • model-500 und model-500.meta
  • model-600 und model-600.meta
  • model-700 und model-700.meta
  • model-800 und model-800.meta
  • model-900 und model-900.meta

Beachten Sie, dass in diesem Beispiel der saver sowohl die aktuellen Werte der Variablen als Prüfpunkt als auch die Struktur des Diagramms ( *.meta ) *.meta , es wurde jedoch keine besondere Sorgfalt darauf verwendet, z. B. die Platzhalter x und y einmal abzurufen Modell wurde restauriert. Wenn die Wiederherstellung beispielsweise an einem anderen Ort als diesem Trainingsskript durchgeführt wird, kann es umständlich sein, x und y aus dem wiederhergestellten Diagramm abzurufen (insbesondere bei komplizierteren Modellen). Um dies zu vermeiden, sollten Sie Ihren Variablen / Platzhaltern / tf.collections immer Namen geben oder die Verwendung von tf.collections in tf.collections wie in einer der Anmerkungen gezeigt.

Modell wiederherstellen

Das Wiederherstellen ist auch ganz nett.

Hier ist eine praktische Hilfsfunktion:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

Haupt code:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
Lizenziert unter CC BY-SA 3.0
Nicht angeschlossen an Stack Overflow