Zoeken…


Invoering

Tensorflow maakt onderscheid tussen het opslaan / herstellen van de huidige waarden van alle variabelen in een grafiek en het opslaan / herstellen van de werkelijke grafiekstructuur. Om de grafiek te herstellen, bent u vrij om de functies van Tensorflow te gebruiken of gewoon uw stukje code opnieuw te bellen, dat de grafiek in de eerste plaats heeft gebouwd. Bij het definiëren van de grafiek moet u ook nadenken over welke en hoe variabelen / ops moeten worden opgehaald nadat de grafiek is opgeslagen en hersteld.

Opmerkingen

Als ik het goed begrijp, bouw je in het gedeelte over het herstellen van het model hierboven en herstel je de variabelen. Ik geloof dat het opnieuw opbouwen van het model niet nodig is, zolang je de relevante tensoren / tijdelijke aanduidingen tf.add_to_collection() wanneer je opslaat met tf.add_to_collection() . Bijvoorbeeld:

tf.add_to_collection('cost_op', cost_op)

Later kunt u de opgeslagen grafiek herstellen en toegang krijgen tot cost_op met

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Zelfs als je tf.add_to_collection() niet uitvoert, kun je je tensoren ophalen, maar het proces is een beetje lastiger en je moet misschien wat graafwerk doen om de juiste namen voor dingen te vinden. Bijvoorbeeld:

in een script dat een tensorflow-grafiek maakt, definiëren we een aantal tensoren lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

we kunnen ze later als volgt terugroepen:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Er zijn een aantal manieren om de naam van een tensor te vinden - je kunt deze vinden in je grafiek op het tensorbord, of je kunt er naar zoeken met zoiets als:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Het model opslaan

Een model opslaan in tensorflow is vrij eenvoudig.

Stel dat u een lineair model hebt met invoer x en een uitvoer y wilt voorspellen. Het verlies hier is de gemiddelde kwadratische fout (MSE). De batchgrootte is 16.

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

Hier komt het Saver-object, dat meerdere parameters kan hebben (zie doc. ).

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

Uiteindelijk trainen we het model in een tf.Session() , voor 1000 iteraties. We slaan het model hier slechts om de 100 iteraties op.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

Na het uitvoeren van deze code, zou u de laatste 5 controlepunten in uw map moeten zien:

  • model-500 en model-500.meta
  • model-600 en model-600.meta
  • model-700 en model-700.meta model-700 . model-700.meta
  • model-800 en model-800.meta
  • model-900 en model-900.meta

Merk op dat in dit voorbeeld, terwijl de saver zowel de huidige waarden van de variabelen als een controlepunt en de structuur van de grafiek ( *.meta ) *.meta , er geen specifieke aandacht is besteed aan het ophalen van bijvoorbeeld de tijdelijke aanduidingen x en y zodra de model werd gerestaureerd. Als het herstel bijvoorbeeld ergens anders wordt uitgevoerd dan dit trainingsscript, kan het lastig zijn om x en y te halen uit de herstelde grafiek (vooral in meer gecompliceerde modellen). Om dit te voorkomen, geef altijd namen aan uw variabelen / placeholders / ops of tf.collections om tf.collections zoals getoond in een van de opmerkingen.

Model herstellen

Restaureren is ook best leuk en gemakkelijk.

Hier is een handige helpfunctie:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

Hoofd code:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
Licentie onder CC BY-SA 3.0
Niet aangesloten bij Stack Overflow