I have a TensorFlow (TF) model that I'd like to restore and retrain some of its parameters. I know that tf.get_operation_by_name('name of the optimizer')
retrieves the original optimizer that was used to train the model before it was stored. However, I don't know how to pass the new list of TF variables that I want the optimizer to retrain!
This example helps illustrate what I want to do:
learning_rate = 0.0001
training_iters = 60000
batch_size = 64
display_step = 20
ImVecDim = 784# The number of elements in a an image vector (flattening a 28x28 2D image)
NumOfClasses = 10
dropout = 0.8
with tf.Session() as sess:
LoadMod = tf.train.import_meta_graph('simple_mnist.ckpt.meta') # This object loads the model
LoadMod.restore(sess, tf.train.latest_checkpoint('./')) # Loading weights and biases and other stuff to the model
g = tf.get_default_graph()
# Variables to be retrained:
wc2 = g.get_tensor_by_name('wc2:0')
bc2 = g.get_tensor_by_name('bc2:0')
wc3 = g.get_tensor_by_name('wc3:0')
bc3 = g.get_tensor_by_name('bc3:0')
wd1 = g.get_tensor_by_name('wd1:0')
bd1 = g.get_tensor_by_name('bd1:0')
wd2 = g.get_tensor_by_name('wd2:0')
bd2 = g.get_tensor_by_name('bd2:0')
out_w = g.get_tensor_by_name('out_w:0')
out_b = g.get_tensor_by_name('out_b:0')
VarToTrain = [wc2,wc3,wd1,wd2,out_w,bc2,bc3,bd1,bd2,out_b]
# Retrieving the optimizer:
Opt = tf.get_operation_by_name('Adam')
# Retraining:
X = g.get_tensor_by_name('ImageIn:0')
Y = g.get_tensor_by_name('LabelIn:0')
KP = g.get_tensor_by_name('KeepProb:0')
accuracy = g.get_tensor_by_name('NetAccuracy:0')
cost = g.get_tensor_by_name('loss:0')
step = 1
while step * batch_size < training_iters:
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
#########################################################################
# Here I want to pass (VarToTrain) to the optimizer (Opt)! #
#########################################################################
if step % display_step == 0:
acc = sess.run(accuracy, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
loss = sess.run(cost, feed_dict={X: batch_xs, Y: batch_ys, KP: 1.})
print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + "{:.6f}".format(
loss) + ", Training Accuracy= " + "{:.5f}".format(acc))
step += 1
feed_dict = {X: mnist.test.images[:256], Y: mnist.test.labels[:256], KP: 1.0}
ModelAccuracy = sess.run(accuracy, feed_dict)
print('Retraining finished'+', Test Accuracy = %f' %ModelAccuracy)
See Question&Answers more detail:os