tensorflow
Een model opslaan en herstellen in TensorFlow
Zoeken…
Invoering
Tensorflow maakt onderscheid tussen het opslaan / herstellen van de huidige waarden van alle variabelen in een grafiek en het opslaan / herstellen van de werkelijke grafiekstructuur. Om de grafiek te herstellen, bent u vrij om de functies van Tensorflow te gebruiken of gewoon uw stukje code opnieuw te bellen, dat de grafiek in de eerste plaats heeft gebouwd. Bij het definiëren van de grafiek moet u ook nadenken over welke en hoe variabelen / ops moeten worden opgehaald nadat de grafiek is opgeslagen en hersteld.
Opmerkingen
Als ik het goed begrijp, bouw je in het gedeelte over het herstellen van het model hierboven en herstel je de variabelen. Ik geloof dat het opnieuw opbouwen van het model niet nodig is, zolang je de relevante tensoren / tijdelijke aanduidingen tf.add_to_collection()
wanneer je opslaat met tf.add_to_collection()
. Bijvoorbeeld:
tf.add_to_collection('cost_op', cost_op)
Later kunt u de opgeslagen grafiek herstellen en toegang krijgen tot cost_op
met
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model.meta')`
new_saver.restore(sess, 'model')
cost_op = tf.get_collection('cost_op')[0]
Zelfs als je tf.add_to_collection()
niet uitvoert, kun je je tensoren ophalen, maar het proces is een beetje lastiger en je moet misschien wat graafwerk doen om de juiste namen voor dingen te vinden. Bijvoorbeeld:
in een script dat een tensorflow-grafiek maakt, definiëren we een aantal tensoren lab_squeeze
:
...
with tf.variable_scope("inputs"):
y=tf.convert_to_tensor([[0,1],[1,0]])
split_labels=tf.split(1,0,x,name='lab_split')
split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
saver=tf.train.Saver(sess,split_labels)
saver.save("./checkpoint.chk")
we kunnen ze later als volgt terugroepen:
with tf.Session() as sess:
g=tf.get_default_graph()
new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')`
new_saver.restore(sess, './checkpoint.chk')
split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']
split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0')
split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")
Er zijn een aantal manieren om de naam van een tensor te vinden - je kunt deze vinden in je grafiek op het tensorbord, of je kunt er naar zoeken met zoiets als:
sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph
Het model opslaan
Een model opslaan in tensorflow is vrij eenvoudig.
Stel dat u een lineair model hebt met invoer x
en een uitvoer y
wilt voorspellen. Het verlies hier is de gemiddelde kwadratische fout (MSE). De batchgrootte is 16.
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
Hier komt het Saver-object, dat meerdere parameters kan hebben (zie doc. ).
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
Uiteindelijk trainen we het model in een tf.Session()
, voor 1000
iteraties. We slaan het model hier slechts om de 100
iteraties op.
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
Na het uitvoeren van deze code, zou u de laatste 5 controlepunten in uw map moeten zien:
-
model-500
enmodel-500.meta
-
model-600
enmodel-600.meta
-
model-700
enmodel-700.meta
model-700
.model-700.meta
-
model-800
enmodel-800.meta
-
model-900
enmodel-900.meta
Merk op dat in dit voorbeeld, terwijl de saver
zowel de huidige waarden van de variabelen als een controlepunt en de structuur van de grafiek ( *.meta
) *.meta
, er geen specifieke aandacht is besteed aan het ophalen van bijvoorbeeld de tijdelijke aanduidingen x
en y
zodra de model werd gerestaureerd. Als het herstel bijvoorbeeld ergens anders wordt uitgevoerd dan dit trainingsscript, kan het lastig zijn om x
en y
te halen uit de herstelde grafiek (vooral in meer gecompliceerde modellen). Om dit te voorkomen, geef altijd namen aan uw variabelen / placeholders / ops of tf.collections
om tf.collections
zoals getoond in een van de opmerkingen.
Model herstellen
Restaureren is ook best leuk en gemakkelijk.
Hier is een handige helpfunctie:
def restore_vars(saver, sess, chkpt_dir):
""" Restore saved net, global score and step, and epsilons OR
create checkpoint directory for later storage. """
sess.run(tf.initialize_all_variables())
checkpoint_dir = chkpt_dir
if not os.path.exists(checkpoint_dir):
try:
print("making checkpoint_dir")
os.makedirs(checkpoint_dir)
return False
except OSError:
raise
path = tf.train.get_checkpoint_state(checkpoint_dir)
print("path = ",path)
if path is None:
return False
else:
saver.restore(sess, path.model_checkpoint_path)
return True
Hoofd code:
path_to_saved_model = './'
max_steps = 1
# Start a session
with tf.Session() as sess:
... define the model here ...
print("define the param saver")
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
# restore session if there is a saved checkpoint
print("restoring model")
restored = restore_vars(saver, sess, path_to_saved_model)
print("model restored ",restored)
# Now continue training if you so choose
for step in range(max_steps):
# do an update on the model (not needed)
loss_value = sess.run([loss])
# Now save the model
saver.save(sess, "./model", global_step=step)