tensorflow
Сохранение и восстановление модели в TensorFlow
Поиск…
Вступление
Tensorflow различает сохранение / восстановление текущих значений всех переменных в графе и сохранение / восстановление фактической структуры графика. Чтобы восстановить график, вы можете использовать либо функции Tensorflow, либо снова называть свой фрагмент кода, что построил граф в первую очередь. При определении графика вы также должны подумать о том, какие переменные / операторы должны быть восстановлены после того, как график будет сохранен и восстановлен.
замечания
В восстановительной части модели выше, если я правильно понимаю, вы строите модель, а затем восстанавливаете переменные. Я считаю, что восстановление модели не требуется, если вы добавляете соответствующие тензоры / заполнители при сохранении с помощью tf.add_to_collection()
. Например:
tf.add_to_collection('cost_op', cost_op)
Затем вы можете восстановить сохраненный график и получить доступ к cost_op
используя
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('model.meta')`
new_saver.restore(sess, 'model')
cost_op = tf.get_collection('cost_op')[0]
Даже если вы не запустите tf.add_to_collection()
, вы можете получить свои тензоры, но процесс немного более громоздкий, и вам, возможно, придется кое-что сделать, чтобы найти правильные имена для вещей. Например:
в скрипте, который строит график тензорного потока, мы определяем некоторый набор тензоров lab_squeeze
:
...
with tf.variable_scope("inputs"):
y=tf.convert_to_tensor([[0,1],[1,0]])
split_labels=tf.split(1,0,x,name='lab_split')
split_labels=[tf.squeeze(i,name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
saver=tf.train.Saver(sess,split_labels)
saver.save("./checkpoint.chk")
мы можем вспомнить их позже:
with tf.Session() as sess:
g=tf.get_default_graph()
new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')`
new_saver.restore(sess, './checkpoint.chk')
split_labels=['inputs/lab_squeeze:0','inputs/lab_squeeze_1:0','inputs/lab_squeeze_2:0']
split_label_0=g.get_tensor_by_name('inputs/lab_squeeze:0')
split_label_1=g.get_tensor_by_name("inputs/lab_squeeze_1:0")
Существует несколько способов найти имя тензора - вы можете найти его в своем графике на тензорной доске, или вы можете найти для него что-то вроде:
sess=tf.Session()
g=tf.get_default_graph()
...
x=g.get_collection_keys()
[i.name for j in x for i in g.get_collection(j)] # will list out most, if not all, tensors on the graph
Сохранение модели
Сохранение модели в тензорном потоке довольно просто.
Предположим, у вас есть линейная модель с входом x
и вы хотите предсказать выход y
. Потеря здесь представляет собой среднеквадратичную ошибку (MSE). Размер партии - 16.
# Define the model
x = tf.placeholder(tf.float32, [16, 10]) # input
y = tf.placeholder(tf.float32, [16, 1]) # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_sum(tf.square(res - y))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
Здесь находится объект Saver, который может иметь несколько параметров (см. Doc ).
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
Наконец, мы tf.Session()
модель в tf.Session()
, на 1000
итераций. Мы сохраняем модель только каждые 100
итераций.
# Start a session
max_steps = 1000
with tf.Session() as sess:
# initialize the variables
sess.run(tf.initialize_all_variables())
for step in range(max_steps):
feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)} # dummy input
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, "./model", global_step=step)
После запуска этого кода вы должны увидеть последние 5 контрольных точек в своем каталоге:
-
model-500
иmodel-500.meta
-
model-600
иmodel-600.meta
-
model-700
иmodel-700.meta
-
model-800
иmodel-800.meta
-
model-900
иmodel-900.meta
Обратите внимание, что в этом примере, в то время как saver
фактически сохраняет как текущие значения переменных в качестве контрольной точки, так и структуру графика ( *.meta
), никакой особой осторожности не было принято, чтобы получить, например, заполнители x
и y
после модель была восстановлена. Например, если восстановление выполняется в любом месте, кроме этого учебного сценария, может быть громоздким восстановить x
и y
из восстановленного графика (особенно в более сложных моделях). Чтобы этого избежать, всегда tf.collections
имена своим переменным / заполнителям / операциям или думайте об использовании tf.collections
как показано в одном из замечаний.
Восстановление модели
Восстановление также довольно приятно и легко.
Вот удобная вспомогательная функция:
def restore_vars(saver, sess, chkpt_dir):
""" Restore saved net, global score and step, and epsilons OR
create checkpoint directory for later storage. """
sess.run(tf.initialize_all_variables())
checkpoint_dir = chkpt_dir
if not os.path.exists(checkpoint_dir):
try:
print("making checkpoint_dir")
os.makedirs(checkpoint_dir)
return False
except OSError:
raise
path = tf.train.get_checkpoint_state(checkpoint_dir)
print("path = ",path)
if path is None:
return False
else:
saver.restore(sess, path.model_checkpoint_path)
return True
Основной код:
path_to_saved_model = './'
max_steps = 1
# Start a session
with tf.Session() as sess:
... define the model here ...
print("define the param saver")
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
# restore session if there is a saved checkpoint
print("restoring model")
restored = restore_vars(saver, sess, path_to_saved_model)
print("model restored ",restored)
# Now continue training if you so choose
for step in range(max_steps):
# do an update on the model (not needed)
loss_value = sess.run([loss])
# Now save the model
saver.save(sess, "./model", global_step=step)