我有一个使用迭代器训练我的网络的模型;遵循Google现在推荐的新数据集API管道模型.

我读了tfrecord文件,将数据提供给网络,训练得很好,一切顺利,我在训练结束时保存了我的模型,所以我可以在以后运行推理.代码的简化版本如下:

""" Training and saving """

training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
  training_iterator = Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)
  next_training_element = training_iterator.get_next()
  training_init_op = training_iterator.make_initializer(training_dataset)

def train(num_epochs):
  # compute for the number of epochs
  for e in range(1,num_epochs+1):
    session.run(training_init_op) #initializing iterator here
    while True:
      try:
        images,labels = session.run(next_training_element)
        session.run(optimizer,Feed_dict={x: images,y_true: labels})
      except tf.errors.OutOfRangeError:
        saver_name = './saved_models/ucf-model'
        print("Finished Training Epoch {}".format(e))
        break



    """ Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_Meta_graph(r'saved_models\ucf-model.Meta')
saver.restore(session,tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()

# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser,num_threads=4,output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)

testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset

with tf.Session() as session:
  session.run(testing_init_op)
  while True:
    try:
      images,labels = session.run(next_testing_element)
      accuracy = session.run(accuracy,Feed_dict={x: test_images,y_true: test_labels}) #error here,x,y_true not defined
    except tf.errors.OutOfRangeError:
      break

我的问题主要是我恢复模型.如何将测试数据提供给网络?

>当我使用testing_iterator = graph.get_operation_by_name(“iterators / Iterator”),next_testing_element = graph.get_operation_by_name(“iterators / IteratorGetNext”)恢复我的Iterator时,出现以下错误:
GetNext()失败,因为迭代器尚未初始化.确保在获取下一个元素之前已为此迭代器运行初始化程序操作.
>所以我尝试使用以下方法初始化我的数据集:testing_init_op = testing_iterator.make_initializer(testing_dataset)).我收到此错误:AttributeError:’Operation’对象没有属性’make_initializer’

另一个问题是,由于正在使用迭代器,因此不需要在training_model中使用占位符,因为迭代器直接将数据提供给图形.但是这样,当我将数据提供给“准确度”操作时,如何在第3行到最后一行恢复我的Feed_dict键?

编辑:如果有人可以建议在迭代器和网络输入之间添加占位符的方法,那么我可以尝试通过评估“准确性”张量来运行图形,同时将数据提供给占位符并完全忽略迭代器.

解决方法

还原已保存的元图时,可以使用name恢复初始化操作,然后再次使用它来初始化输入管道以进行推理.

也就是说,在创建图表时,您可以这样做

dataset_init_op = iterator.make_initializer(dataset,name='dataset_init')

然后执行以下操作恢复此操作:

dataset_init_op = graph.get_operation_by_name('dataset_init')

这是一个自包含的代码片段,用于比较恢复前后的随机初始化模型的结果.

保存迭代器

np.random.seed(42)
data = np.random.random([4,4])
X = tf.placeholder(dtype=tf.float32,shape=[4,4],name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types,dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the operation
dataset_init_op = iterator.make_initializer(dataset,name='dataset_init')

w = np.random.random([1,4])
W = tf.Variable(w,name='W',dtype=tf.float32)
output = tf.multiply(W,dataset_next_op,name='output')     
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op,Feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        saver.save(sess,'tmp/',global_step=1002)
    break

然后您可以恢复相同的模型进行推理,如下所示:

恢复已保存的迭代器

np.random.seed(42)
data = np.random.random([4,4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_Meta_graph('tmp/-1002.Meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess,ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')

X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op,Feed_dict={X:data})
while True:
try:
    print(sess.run(output))
except tf.errors.OutOfRangeError:
    break

python – 恢复使用迭代器的Tensorflow模型的更多相关文章

  1. three.js模拟实现太阳系行星体系功能

    这篇文章主要介绍了three.js模拟实现太阳系行星体系功能,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下

  2. HTML5 Web缓存和运用程序缓存(cookie,session)

    这篇文章主要介绍了HTML5 Web缓存和运用程序缓存(cookie,session),小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  3. HTML5页面无缝闪开的问题及解决方案

    这篇文章主要介绍了HTML5页面无缝闪开方案,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  4. iOS Swift上弃用后Twitter.sharedInstance().session()?. userName的替代方案

    解决方法如果您仍在寻找解决方案,请参阅以下内容:

  5. 使用Fabric SDK iOS访问Twitter用户时间线

    我试图在这个问题上挣扎两天.我正在使用FabricSDK和Rest工具包,试图为Twitter使用不同的RestAPIWeb服务.我可以使用具有authTokenSecret,authToken和其他值的会话对象的TWTRLogInButton成功登录.当我尝试获取用户时间线时,我总是得到失败的响应,作为:{“errors”:[{“code”:215,“message”:“BadAuthentic

  6. ios – 如何从Apple Watch调用iPhone上定义的方法

    有没有办法从Watchkit扩展中调用iPhone上的类中定义的方法?根据我的理解,目前在Watchkit和iPhone之间进行本地通信的方法之一是使用NSUserDefaults,但还有其他方法吗?

  7. ios – 为什么,将nil作为参数从Objc C发送到swift类初始化器,用新对象替换nil参数

    除非属性本身被声明为nonnull:

  8. ios – 在Swift中对MKCircle进行子类化

    我想通过添加另一个String属性来继承MKCircle,我们称之为“代码”.这个属性不是可选的和常量的,所以我必须从初始化器设置它,对吧?有没有办法定义一个单一的便利初始化器,在这种情况下需要3个参数?本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请发送邮件至dio@foxmail.com举报,一经查实,本站将立刻删除。

  9. ios – 如何将视频从AVAssetExportSession保存到相机胶卷?

    在此先感谢您的帮助.解决方法只需使用session.outputURL=…

  10. ios – 使用AVCaptureSession sessionPreset = AVCaptureSessionPresetPhoto拉伸捕获的照片

    解决方法所以我解决了我的问题.这是我现在使用的代码,它工作正常:…重要的输出imagaView:一些额外的信息:相机图层必须是全屏,并且outputimageView也必须是.我希望这些对某些人来说也是有用的信息.

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部