tensorflow
Enregistrer et restaurer un modèle dans TensorFlow
Recherche…
Introduction
Tensorflow fait la distinction entre sauvegarder / restaurer les valeurs actuelles de toutes les variables d'un graphe et sauvegarder / restaurer la structure réelle du graphe. Pour restaurer le graphique, vous êtes libre d'utiliser les fonctions de Tensorflow ou d'appeler à nouveau votre morceau de code, ce qui a créé le graphique en premier lieu. Lors de la définition du graphique, vous devez également réfléchir à la manière et à la manière dont les variables / opérations doivent être récupérables une fois le graphique enregistré et restauré.
Remarques
Dans la section du modèle de restauration ci-dessus, si je comprends bien, vous construisez le modèle, puis vous restaurez les variables. Je crois que la reconstruction du modèle n'est pas nécessaire tant que vous ajoutez les tenseurs / espaces réservés pertinents lors de l'enregistrement avec tf.add_to_collection()
. Par exemple:
tf.add_to_collection('cost_op', cost_op)
Ensuite, vous pouvez restaurer le graphique enregistré et accéder à cost_op
utilisant
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model.meta')`
new_saver.restore(sess, 'model')
cost_op = tf.get_collection('cost_op')[0]
Même si vous tf.add_to_collection()
pas tf.add_to_collection()
, vous pouvez récupérer vos tenseurs, mais le processus est un peu plus compliqué et vous devrez peut-être faire quelques recherches pour trouver les noms corrects. Par exemple:
dans un script qui construit un graphe de tensorflow, nous définissons un ensemble de tenseurs lab_squeeze
:
...
with tf.variable_scope("inputs"):
y=tf.convert_to_tensor([[0,1],[1,0]])
split_labels=tf.split(1,0,x,name='lab_split')
split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
saver=tf.train.Saver(sess,split_labels)
saver.save("./checkpoint.chk")
nous pouvons les rappeler plus tard comme suit:
with tf.Session() as sess:
g=tf.get_default_graph()
new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')`
new_saver.restore(sess, './checkpoint.chk')
split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']
split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0')
split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")
Il y a plusieurs façons de trouver le nom d'un tenseur - vous pouvez le trouver dans votre graphique sur le tableau des tenseurs, ou vous pouvez le rechercher avec quelque chose comme:
sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph
Sauvegarder le modèle
Enregistrer un modèle dans le tensorflow est assez facile.
Disons que vous avez un modèle linéaire avec entrée x
et que vous voulez prédire une sortie y
. La perte ici est l'erreur quadratique moyenne (MSE). La taille du lot est de 16.
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
Voici l’objet Saver, qui peut avoir plusieurs paramètres (cf. doc ).
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
Enfin, nous formons le modèle dans un tf.Session()
, pour 1000
itérations. Nous ne sauvegardons le modèle que toutes les 100
itérations ici.
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
Après avoir exécuté ce code, vous devriez voir les 5 derniers points de contrôle dans votre répertoire:
-
model-500
etmodel-500.meta
-
model-600
etmodel-600.meta
-
model-700
etmodel-700.meta
-
model-800
etmodel-800.meta
-
model-900
etmodel-900.meta
Notez que dans cet exemple, alors que l' saver
enregistre en fait à la fois les valeurs actuelles des variables comme un point de contrôle et la structure du graphique ( *.meta
), aucun soin particulier a été prise WRT comment récupérer par exemple les espaces réservés x
et y
, une fois la le modèle a été restauré. Par exemple, si la restauration est effectuée ailleurs que dans ce script de formation, il peut être compliqué de récupérer x
et y
dans le graphique restauré (en particulier dans les modèles plus compliqués). Pour éviter cela, donnez toujours des noms à vos variables / espaces réservés / ops ou pensez à utiliser tf.collections
comme indiqué dans l'une des remarques.
Restaurer le modèle
La restauration est également très agréable et facile.
Voici une fonction d'aide pratique:
def restore_vars(saver, sess, chkpt_dir):
""" Restore saved net, global score and step, and epsilons OR
create checkpoint directory for later storage. """
sess.run(tf.initialize_all_variables())
checkpoint_dir = chkpt_dir
if not os.path.exists(checkpoint_dir):
try:
print("making checkpoint_dir")
os.makedirs(checkpoint_dir)
return False
except OSError:
raise
path = tf.train.get_checkpoint_state(checkpoint_dir)
print("path = ",path)
if path is None:
return False
else:
saver.restore(sess, path.model_checkpoint_path)
return True
Code principal:
path_to_saved_model = './'
max_steps = 1
# Start a session
with tf.Session() as sess:
... define the model here ...
print("define the param saver")
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
# restore session if there is a saved checkpoint
print("restoring model")
restored = restore_vars(saver, sess, path_to_saved_model)
print("model restored ",restored)
# Now continue training if you so choose
for step in range(max_steps):
# do an update on the model (not needed)
loss_value = sess.run([loss])
# Now save the model
saver.save(sess, "./model", global_step=step)