tensorflow
Speichern und Wiederherstellen eines Modells in TensorFlow
Suche…
Einführung
Tensorflow unterscheidet zwischen dem Speichern / Wiederherstellen der aktuellen Werte aller Variablen in einem Diagramm und dem Speichern / Wiederherstellen der tatsächlichen Diagrammstruktur. Um die Grafik wiederherzustellen, können Sie entweder die Tensorflow-Funktionen verwenden oder einfach Ihren Code erneut aufrufen, der die Grafik ursprünglich erstellt hat. Bei der Definition des Graphen sollten Sie auch darüber nachdenken, welche Variablen und Operationen abrufbar sind, nachdem der Graph gespeichert und wiederhergestellt wurde.
Bemerkungen
Wenn ich im Abschnitt zum Wiederherstellen des Modells richtig verstehe, erstellen Sie das Modell und stellen die Variablen wieder her. Ich bin der Meinung, dass ein Neuaufbau des Modells nicht erforderlich ist, solange Sie beim Speichern mit tf.add_to_collection()
die entsprechenden Tensoren / Platzhalter tf.add_to_collection()
. Zum Beispiel:
tf.add_to_collection('cost_op', cost_op)
Später können Sie dann das gespeicherte Diagramm wiederherstellen und mit cost_op
zugreifen
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model.meta')`
new_saver.restore(sess, 'model')
cost_op = tf.get_collection('cost_op')[0]
Selbst wenn Sie tf.add_to_collection()
nicht ausführen, können Sie Ihre Tensoren abrufen, der Vorgang ist jedoch etwas umständlicher und Sie müssen möglicherweise etwas suchen, um die richtigen Namen für die Dinge zu finden. Zum Beispiel:
In einem Skript, das ein Tensorflow-Diagramm erstellt, definieren wir eine Reihe von Tensoren lab_squeeze
:
...
with tf.variable_scope("inputs"):
y=tf.convert_to_tensor([[0,1],[1,0]])
split_labels=tf.split(1,0,x,name='lab_split')
split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
saver=tf.train.Saver(sess,split_labels)
saver.save("./checkpoint.chk")
Wir können sie später wie folgt abrufen:
with tf.Session() as sess:
g=tf.get_default_graph()
new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')`
new_saver.restore(sess, './checkpoint.chk')
split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']
split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0')
split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")
Es gibt verschiedene Möglichkeiten, den Namen eines Tensors zu finden - Sie können ihn in Ihrem Graphen auf der Tensorplatine finden, oder Sie können mit etwas wie dem folgenden suchen:
sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph
Speichern des Modells
Das Speichern eines Modells im Tensorflow ist ziemlich einfach.
Nehmen wir an, Sie haben ein lineares Modell mit Eingabe x
und möchten eine Ausgabe y
vorhersagen. Der Verlust ist hier der mittlere quadratische Fehler (MSE). Die Losgröße beträgt 16.
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
Hier kommt das Saver-Objekt, das mehrere Parameter haben kann (vgl. Doc ).
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
Zum Schluss trainieren wir das Modell in einer tf.Session()
für 1000
Iterationen. Wir speichern das Modell hier nur alle 100
Iterationen.
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
Nachdem Sie diesen Code ausgeführt haben, sollten Sie die letzten 5 Prüfpunkte in Ihrem Verzeichnis sehen:
-
model-500
undmodel-500.meta
-
model-600
undmodel-600.meta
-
model-700
undmodel-700.meta
-
model-800
undmodel-800.meta
-
model-900
undmodel-900.meta
Beachten Sie, dass in diesem Beispiel der saver
sowohl die aktuellen Werte der Variablen als Prüfpunkt als auch die Struktur des Diagramms ( *.meta
) *.meta
, es wurde jedoch keine besondere Sorgfalt darauf verwendet, z. B. die Platzhalter x
und y
einmal abzurufen Modell wurde restauriert. Wenn die Wiederherstellung beispielsweise an einem anderen Ort als diesem Trainingsskript durchgeführt wird, kann es umständlich sein, x
und y
aus dem wiederhergestellten Diagramm abzurufen (insbesondere bei komplizierteren Modellen). Um dies zu vermeiden, sollten Sie Ihren Variablen / Platzhaltern / tf.collections
immer Namen geben oder die Verwendung von tf.collections
in tf.collections
wie in einer der Anmerkungen gezeigt.
Modell wiederherstellen
Das Wiederherstellen ist auch ganz nett.
Hier ist eine praktische Hilfsfunktion:
def restore_vars(saver, sess, chkpt_dir):
""" Restore saved net, global score and step, and epsilons OR
create checkpoint directory for later storage. """
sess.run(tf.initialize_all_variables())
checkpoint_dir = chkpt_dir
if not os.path.exists(checkpoint_dir):
try:
print("making checkpoint_dir")
os.makedirs(checkpoint_dir)
return False
except OSError:
raise
path = tf.train.get_checkpoint_state(checkpoint_dir)
print("path = ",path)
if path is None:
return False
else:
saver.restore(sess, path.model_checkpoint_path)
return True
Haupt code:
path_to_saved_model = './'
max_steps = 1
# Start a session
with tf.Session() as sess:
... define the model here ...
print("define the param saver")
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
# restore session if there is a saved checkpoint
print("restoring model")
restored = restore_vars(saver, sess, path_to_saved_model)
print("model restored ",restored)
# Now continue training if you so choose
for step in range(max_steps):
# do an update on the model (not needed)
loss_value = sess.run([loss])
# Now save the model
saver.save(sess, "./model", global_step=step)