Recherche…


Introduction

Tensorflow fait la distinction entre sauvegarder / restaurer les valeurs actuelles de toutes les variables d'un graphe et sauvegarder / restaurer la structure réelle du graphe. Pour restaurer le graphique, vous êtes libre d'utiliser les fonctions de Tensorflow ou d'appeler à nouveau votre morceau de code, ce qui a créé le graphique en premier lieu. Lors de la définition du graphique, vous devez également réfléchir à la manière et à la manière dont les variables / opérations doivent être récupérables une fois le graphique enregistré et restauré.

Remarques

Dans la section du modèle de restauration ci-dessus, si je comprends bien, vous construisez le modèle, puis vous restaurez les variables. Je crois que la reconstruction du modèle n'est pas nécessaire tant que vous ajoutez les tenseurs / espaces réservés pertinents lors de l'enregistrement avec tf.add_to_collection() . Par exemple:

tf.add_to_collection('cost_op', cost_op)

Ensuite, vous pouvez restaurer le graphique enregistré et accéder à cost_op utilisant

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Même si vous tf.add_to_collection() pas tf.add_to_collection() , vous pouvez récupérer vos tenseurs, mais le processus est un peu plus compliqué et vous devrez peut-être faire quelques recherches pour trouver les noms corrects. Par exemple:

dans un script qui construit un graphe de tensorflow, nous définissons un ensemble de tenseurs lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

nous pouvons les rappeler plus tard comme suit:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Il y a plusieurs façons de trouver le nom d'un tenseur - vous pouvez le trouver dans votre graphique sur le tableau des tenseurs, ou vous pouvez le rechercher avec quelque chose comme:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Sauvegarder le modèle

Enregistrer un modèle dans le tensorflow est assez facile.

Disons que vous avez un modèle linéaire avec entrée x et que vous voulez prédire une sortie y . La perte ici est l'erreur quadratique moyenne (MSE). La taille du lot est de 16.

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

Voici l’objet Saver, qui peut avoir plusieurs paramètres (cf. doc ).

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

Enfin, nous formons le modèle dans un tf.Session() , pour 1000 itérations. Nous ne sauvegardons le modèle que toutes les 100 itérations ici.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

Après avoir exécuté ce code, vous devriez voir les 5 derniers points de contrôle dans votre répertoire:

  • model-500 et model-500.meta
  • model-600 et model-600.meta
  • model-700 et model-700.meta
  • model-800 et model-800.meta
  • model-900 et model-900.meta

Notez que dans cet exemple, alors que l' saver enregistre en fait à la fois les valeurs actuelles des variables comme un point de contrôle et la structure du graphique ( *.meta ), aucun soin particulier a été prise WRT comment récupérer par exemple les espaces réservés x et y , une fois la le modèle a été restauré. Par exemple, si la restauration est effectuée ailleurs que dans ce script de formation, il peut être compliqué de récupérer x et y dans le graphique restauré (en particulier dans les modèles plus compliqués). Pour éviter cela, donnez toujours des noms à vos variables / espaces réservés / ops ou pensez à utiliser tf.collections comme indiqué dans l'une des remarques.

Restaurer le modèle

La restauration est également très agréable et facile.

Voici une fonction d'aide pratique:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

Code principal:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
Sous licence CC BY-SA 3.0
Non affilié à Stack Overflow