Sök…


Introduktion

Tensorflöde skiljer mellan att spara / återställa de aktuella värdena för alla variabler i en graf och spara / återställa den faktiska grafstrukturen. För att återställa grafen är det fritt att använda antingen Tensorflows funktioner eller bara ringa din kodkod igen, som byggde grafen i första hand. När du definierar diagrammet bör du också tänka på vilka och hur variabler / ops ska hämtas när grafen har sparats och återställts.

Anmärkningar

I avsnittet om återställning av modellen ovan om jag förstår rätt bygger du modellen och återställer sedan variablerna. Jag anser att det inte är nödvändigt att bygga om modellen så länge du lägger till relevanta tensorer / platshållare när du sparar med tf.add_to_collection() . Till exempel:

tf.add_to_collection('cost_op', cost_op)

Senare kan du återställa den sparade grafen och få tillgång till cost_op med

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')` 
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]

Även om du inte kör tf.add_to_collection() kan du hämta dina tensorer, men processen är lite mer besvärlig och du kanske måste göra lite grävning för att hitta rätt namn på saker. Till exempel:

i ett skript som bygger ett tensorflödesdiagram definierar vi någon uppsättning av lab_squeeze :

...
with tf.variable_scope("inputs"):
    y=tf.convert_to_tensor([[0,1],[1,0]])
    split_labels=tf.split(1,0,x,name='lab_split')
    split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    saver=tf.train.Saver(sess,split_labels)
    saver.save("./checkpoint.chk")
    

vi kan komma ihåg dem senare på följande sätt:

with tf.Session() as sess:
    g=tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')` 
    new_saver.restore(sess, './checkpoint.chk')
    split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']

    split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0') 
    split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")

Det finns ett antal sätt att hitta namnet på en tensor - du kan hitta det i din graf på tensortavlan, eller du kan söka igenom det med något som:

sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph

Spara modellen

Att spara en modell i tensorflow är ganska enkelt.

Låt oss säga att du har en linjär modell med ingång x och vill förutsäga en utgång y . Förlusten här är det genomsnittliga kvadratfelet (MSE). Batchstorleken är 16.

# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])   # output

w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)

res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))

train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

Här kommer Saver-objektet, som kan ha flera parametrar (jfr doc ).

# Define the tf.train.Saver object
# (cf. params section for all the parameters)    
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

Slutligen tränar vi modellen i en tf.Session() för 1000 iterationer. Vi sparar bara modellen varje 100 iterationer här.

# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.initialize_all_variables())

    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)

När du har kört den här koden ska du se de fem senaste kontrollpunkterna i din katalog:

  • model-500 och model-500.meta
  • model-600 och model-600.meta
  • model-700 och model-700.meta
  • model-800 och model-800.meta
  • model-900 och model-900.meta

Observera att i detta exempel, medan saver faktiskt sparar både de aktuella värdena för variablerna som en kontrollpunkt och strukturen för diagrammet ( *.meta ), har ingen särskild försiktighet tagits när det gäller att hämta t.ex. platshållarna x och y när modellen återställdes. Exempelvis om återställningen görs någon annanstans än detta träningsskript, kan det vara besvärligt att hämta x och y från den återställda grafen (särskilt i mer komplicerade modeller). För att undvika det, ge alltid namn på dina variabler / platshållare / ops eller tänk på att använda tf.collections som visas i ett av kommentarerna.

Återställa modellen

Återställa är också ganska trevligt och enkelt.

Här är en praktisk hjälpfunktion:

def restore_vars(saver, sess, chkpt_dir):
    """ Restore saved net, global score and step, and epsilons OR
    create checkpoint directory for later storage. """
    sess.run(tf.initialize_all_variables())

    checkpoint_dir = chkpt_dir 

    if not os.path.exists(checkpoint_dir):
        try:
            print("making checkpoint_dir")
            os.makedirs(checkpoint_dir)
            return False
        except OSError:
            raise

    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ",path)
    if path is None:
        return False
    else:
        saver.restore(sess, path.model_checkpoint_path)
        return True

Huvudkod:

path_to_saved_model = './'
max_steps = 1

# Start a session
with tf.Session() as sess:

    ... define the model here ...

    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    # restore session if there is a saved checkpoint
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ",restored)

    # Now continue training if you so choose

    for step in range(max_steps):

        # do an update on the model (not needed)
        loss_value = sess.run([loss])
        # Now save the model
        saver.save(sess, "./model", global_step=step)


Modified text is an extract of the original Stack Overflow Documentation
Licensierat under CC BY-SA 3.0
Inte anslutet till Stack Overflow