【6.3】Saving and Restoring Networks in TensorFlow

1. Saving the model

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

# Prepare the inputs: placeholders and the feed_dict that supplies their values
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")  # name the op so it can be looked up after the model is restored
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
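
The save call above writes several files, and the restore step refers to two of them by name. The following is a minimal sketch of the expected file names, assuming the default TensorFlow 1.x checkpoint format with a single data shard. The helper expected_checkpoint_files is hypothetical (it is not part of TensorFlow) and only illustrates the naming convention.

```python
def expected_checkpoint_files(prefix, global_step=None, num_shards=1):
    """List the file names one Saver.save() call is expected to produce.

    Hypothetical helper for illustration only; it assumes the default
    TF 1.x checkpoint layout and does not touch the file system.
    """
    # global_step, when given, is appended to the prefix as "-<step>"
    base = prefix if global_step is None else "%s-%d" % (prefix, global_step)
    data_files = [
        "%s.data-%05d-of-%05d" % (base, i, num_shards) for i in range(num_shards)
    ]
    # .meta holds the graph structure, .index and .data-* hold the variable values,
    # and "checkpoint" records the most recent checkpoint prefix in that directory
    return [base + ".meta", base + ".index"] + data_files + ["checkpoint"]


files = expected_checkpoint_files("my_test_model", global_step=1000)
for name in files:
    print(name)
```

The .meta file is what tf.train.import_meta_graph reads, and the checkpoint file is what tf.train.latest_checkpoint consults to find the prefix my_test_model-1000.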

2. Restoring the model

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, look up the saved placeholders by name and
# build a feed_dict that feeds them new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  # fetch the saved tensor by the name given at save time

print(sess.run(op_to_restore, feed_dict))
# This will print 60.0, which is (13 + 17) * 2
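
The strings passed to get_tensor_by_name follow the pattern "<op_name>:<output_index>", which is why the op saved as "op_to_restore" is fetched as "op_to_restore:0". The sketch below splits such a name into its two parts. The helper split_tensor_name is hypothetical and written in plain Python only to make the convention explicit.

```python
def split_tensor_name(tensor_name):
    """Split "<op_name>:<output_index>" into (op_name, output_index).

    Hypothetical helper for illustration; it performs no graph lookup.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(
            "expected '<op_name>:<output_index>', got %r" % tensor_name
        )
    return op_name, int(index)


print(split_tensor_name("op_to_restore:0"))
print(split_tensor_name("w1:0"))
```

A name without the ":0" suffix, such as "bias", is an operation name rather than a tensor name. It is the form graph.get_operation_by_name expects, while get_tensor_by_name needs the suffixed form.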

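3. Errors

Error 1: No variables to save

Cause: tf.train.Saver() found no variables in the current graph. This happens when the model is not saved and restored as shown above, for example when the restoring script creates a new tf.train.Saver() on an empty graph instead of using the saver returned by tf.train.import_meta_graph.

Error 2: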

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'init_all_tables': Operation was explicitly assigned to /job:worker/task:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:GPU:2, /job:localhost/replica:0/task:0/device:GPU:3, /job:localhost/replica:0/task:0/device:GPU:4 ]. Make sure the device specification refers to a valid device.
     [[Node: init_all_tables = NoOp[_device="/job:worker/task:0"]()]]

The official explanation (https://www.tensorflow.org/api_guides/python/meta_graph):

Import a graph with preset devices.

Sometimes an exported meta graph is from a training environment that the importer doesn't have. For example, the model might have been trained on GPUs, or in a distributed environment with replicas. When importing such models, it's useful to be able to clear the device settings in the graph so that we can run it on locally available devices. This can be achieved by calling import_meta_graph with the clear_devices option set to True.

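Put simply, the model was created on other nodes, so the graph still carries their device assignments, and these must be cleared before the graph can run locally.

Solution:

Pass clear_devices=True when importing the meta graph: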

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta',
      clear_devices=True)
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  ...

4. My use case

Saving the model

saver = tf.train.Saver()
model_path = '%s/model-%s-total.ckpt' % (MODEL_SAVE_PATH,epoch)
save_path = saver.save(sess,model_path)

Restoring the model

model_path = '%s/model-%s-total.ckpt.meta' % (MODEL_SAVE_PATH,epoch)
model_path_2 = '%s/model-%s-total.ckpt' % (MODEL_SAVE_PATH,epoch)      
with tf.Session(config=config) as sess:
    # Initialize all variables
    init = tf.global_variables_initializer()
    sess.run(init)

    # Import the saved graph structure, then restore the weights into it
    saver = tf.train.import_meta_graph(model_path, clear_devices=True)
    saver.restore(sess, model_path_2)
    graph = tf.get_default_graph()

    # Look up the corresponding input tensors by name and feed them
    feed_dict.update(
                {graph.get_tensor_by_name("adj_mats_%s_%s_%s/shape:0" % (kk[0], kk[1], ii)):
                     adj_mats_orig[kk[0], kk[1]][ii].shape,
                 graph.get_tensor_by_name("adj_mats_%s_%s_%s/values:0" % (kk[0], kk[1], ii)):
                     adj_mats_orig[kk[0], kk[1]][ii].data,
                 graph.get_tensor_by_name("adj_mats_%s_%s_%s/indices:0" % (kk[0], kk[1], ii)):
                     preprocessing.sparse_to_tuple(adj_mats_orig[kk[0], kk[1]][ii])[0]
                 })


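    # op_to_restore is assumed to have been looked up beforehand, e.g. with graph.get_tensor_by_name as in section 2
    rec = sess.run([op_to_restore], feed_dict=feed_dict)[0]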
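
The three tensor names fed above ("/shape:0", "/values:0", "/indices:0") look like the components of a sparse placeholder. This is an assumption about how the original graph was built: in TensorFlow 1.x, tf.sparse_placeholder(name=...) creates three underlying placeholders with exactly these suffixes. The sketch below builds the three names from one base name. The helper sparse_placeholder_feed_names is hypothetical, and the indices 0, 1, 2 are made-up stand-ins for kk[0], kk[1], ii.

```python
def sparse_placeholder_feed_names(base_name):
    """Return the tensor names of the three components of a sparse placeholder.

    Hypothetical helper; assumes the placeholder was created with
    tf.sparse_placeholder(name=base_name) in TensorFlow 1.x.
    """
    return {
        part: "%s/%s:0" % (base_name, part)
        for part in ("indices", "values", "shape")
    }


# Made-up example indices standing in for kk[0], kk[1], ii
names = sparse_placeholder_feed_names("adj_mats_%s_%s_%s" % (0, 1, 2))
for part in ("indices", "values", "shape"):
    print(part, "->", names[part])
```

Each of the resulting names can be passed to graph.get_tensor_by_name to obtain the key for the feed_dict, paired with the matching piece of the sparse matrix (coordinates, non-zero values, dense shape).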
