令人困惑的TensorFlow (II)
消息来源:baojiabao.com 作者: 发布时间:2026-05-15

选自
Jacobbuckman
作者
:
Jacob Buckman
机器之心编译
参与:高璇、王淑婷
六月底,机器之心发布了“令人困惑的 TensorFlow!”,讲述了初上手 TensorFlow 时会遇到的麻烦。日前,作者更新博客,对该文写了续篇,
主要讲述了保存和载入 TensorFlow 模型以及上下文管理器的一些问题。
命名和域
命名变数和张量
正如我们在第一部分讨论的,每次调用 tf.get_variable() 时,都需要为变数赋予一个新的唯一名称。实际上,图中的每个张量也需要一个唯一的名称。可以通过张量、操作和变数的 .name 属性访问该名称。绝大多数情况下,名称会自动创建;例如,一个常量节点会以 Const 命名,当创建更多常量节点时,其名称将是 Const_1,Const_2 等。还可以通过 name=的属性设置节点名称,列举后缀仍会自动添加:
代码:
import as 0. 1. 2. "cool_const" 3. "cool_const" ) print
a = tf.constant(
b = tf.constant(
c = tf.constant(
d = tf.constant(
输出:
0 0 0 0Const:
虽然节点命名并非必要,但在调试时非常有用。当 Tensorflow 代码崩溃时,error trace 将指向一个特定的操作。如果有很多同类型的操作,那么很难确定是哪一个出了问题。而通过明确命名每个节点,可以获得信息详细的 error trace,并更快地识别问题。
使用范围
随着图形越来越复杂,手动命名所有内容变得愈加困难。Tensorflow 提供 tf.variable_scope 对象,它通过将图形细分为更小的组块,使图形更易梳理。通过将一段图形创建代码封装在 with tf.variable_scope(scope_name):语句中,创建的所有节点名称都将自动以 scope_name 字元串作为前缀。此外,这些作用域堆栈,在另一个范围内创建的作用域会简单地将前缀链接在一起,用斜杠分隔。
代码:
import as 0. 1. with "first_scope" 2. "cool_const" "coef" 2. with "second_scope" "coef" 3. 1.
a = tf.constant(
b = tf.constant(
c = a + b
d = tf.constant(
coef1 = tf.get_variable(
e = coef1 * d
coef2 = tf.get_variable(
f = tf.constant(
g = coef2 * f
输出:
0 0 0 0 0 0 0 0 0Const:
first_scope/add:
first_scope/second_scope/mul:
first_scope/coef:
first_scope/second_scope/coef:
我们能够使用代码 coef 创建两个名称相同的变数。这是因为作用域可以将名称转换为 first_scope/coef:0 和 first_scope/second_scope/coef:0,它们是不同的。
保存和载入
训练好的神经网络包括两个基本组成部分:
已经学习过某些任务优化的网络权重
说明如何利用权重获得结果的网络图
Tensorflow 将这两个组件分开,但很明显它们需要紧密匹配。如果没有图结构进行说明,那权重也无用,而带有随机权重的图也效果也不好。事实上,即使仅交换两个权重矩阵也可能完全破坏模型。这通常会让 Tensorflow 初学者感觉很挫败。使用预先训练好的模型作为神经网络的一个组成部分不失为加速训练的好方法,但是也有可能搞砸一切。
保存模型
当只有单个模型时,Tensorflow 用于保存和载入的内置工具使用很方便:只需创建一个 tf.train.Saver()。类似于 tf.train.Optimizer,tf.train.Saver 本身并不是一个节点,而是在已有图形上执行有用功能的更高级类别。你可能已经预料到 tf.train 的“有用功能”了,即保存和载入模型。
代码:
import as "a" "b"
a = tf.get_variable(
b = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess,
"./tftcp.model"
)输出
四个新文件:
-00000 -00001checkpoint
tftcp.model.data
tftcp.model.index
tftcp.model.meta
具体内容分析如下:
首先:当我们只保存一个模型时,为什么会输出四个文件?重建模型所需的信息被分散到它们当中。如果想复制或者备份模型,需要有四个文件(前缀为文件名)。下面简述答案:
tftcp.model.data-00000-of-00001 包含模型权重(上述第一个要点)。它可能这里最大的文件。
tftcp.model.meta 是模型的网络结构(上述第二个要点)。它包含重建图形所需的所有信息。
tftcp.model.index 是连接前两点的索引结构。用于在数据文件中找到对应节点的参数。
checkpoint 实际上不需要重建模型,但如果在整个训练过程中保存了多个版本的模型,那它会跟踪所有内容。
其次,我为什么一定要为该示例创建 tf.Session 和 tf.global_variables_initializer 呢?
因为,如果要保存一个模型,我们需要保存相关的内容。计算存于图中,但数值存于会话中。tf.train.Saver 可以通过指向图表的全局指针访问网络结构。但当我们保存变数的值(即网络权重)时,我们需要访问 tf.Session 来确定这些值;这就是为什么 sess 作为 save 函数的第一个参数传入。此外,尝试保存未初始化的变数会引发错误,因为尝试访问未初始化变数的值总是会引发错误。因此,我们需要一个会话和一个初始化程序(或等价的 tf.assign)。
载入模型
既然我们已经保存了模型,现在重新载入它。第一步是重新创建变数:我们希望变数的名称、形状和类型都与保存时一致。第二步是创建与之前一样的 tf.train.Saver,并调用 restore 函数。
代码:
import as "a" "b"
a = tf.get_variable(
b = tf.get_variable(
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess,
"./tftcp.model"
)sess.run([a,b])
输出:
1.3106428 0.6413864[
在运行之前,我们不需要初始化 a 或 b!这是因为 restore 运算将值从文件移动到会话的变数中。由于会话不再包含任何空值变数,因此不再需要初始化。(如果不小心,会适得其反:还原后运行 init 会使随机初始化的值覆盖载入的值。)
选择变数
当一个 tf.train.Saver 程序初始化后,它会查看当前图形并获取变数列表;这是 saver“关心”的永久存储的变数列表。我们可以用._var_list 属性来检查:
代码:
import as "a" "b" "c" print
a = tf.get_variable(
b = tf.get_variable(
saver = tf.train.Saver()
c = tf.get_variable(
输出:
[
因为在创建 saver 时 c 还没有出现,所以它并没有成为函数的一部分。一般来说,你要在创建 saver 之前确保已经创建了所有的变数。
当然,在某些特定的情况下,可能只需保存变数的一个子集。当创建 var_list 以期望它跟踪可用变数子集时,tf.train.Saver 允许传递 var_list。
代码:
import as "a" "b" "c" print
a = tf.get_variable(
b = tf.get_variable(
c = tf.get_variable(
saver = tf.train.Saver(var_list=[a,b])
输出:
[
载入修正模型
上面例子中涵盖的模型载入方案类似于物理中的“真空中无摩擦的完美球体”(perfect sphere in frictionless vacuum)场景。只要你使用自己的代码保存和载入模型,且不擅自更改二者,实现保存和载入轻而易举。但很多情况下,并不会有如此完美的场景。在这些情况下,我们需要多加思量。
让我们通过几个场景来说明这些问题。首先,如果我们想保存一个完整的模型,但只想载入其中的一部分怎么办?(在下面代码示例中,我依次运行两个脚本。)
代码:
import as "a" "b" "./tftcp.model"
a = tf.get_variable(
b = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess,
import as "a" "./tftcp.model"
a = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.restore(sess,
sess.run(a)
输出:
1.1700551
OK。当我们在相反的场景里,就会出现失败的状况:我们希望将一个模型作为大型模型的组件载入。
代码:
import as "a" "./tftcp.model"
a = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess,
import
tensorflowas
tfa = tf.get_variable(
"a"
, [])d = tf.get_variable(
"d"
, [])init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.restore(sess,
"./tftcp.model"
)输出:
not in "/job:localhost/replica:0/task:0/device:CPU:0"Key d
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device=
我们只想载入 a,却忽略了新变数 b。我们犯了一个错误,却抱怨 d 没有出现在 checkpoint 中。
第三种情况是,我们想将一个模型的参数载入到另一个模型的计算图中。这也会引发一个错误,原因很明显:Tensorflow 不知道把载入的所有参数放置在何处。幸好有个方法可以给它点提示。
还记得 var_list 吗?或者更准确来说是“var_list_or_dictionary_mapping_names_to_vars”,但这个名字有点拗口,所以他们使用第一个。
保存模型是 Tensorflow 要求使用全局唯一变数名的关键原因之一。在保存-模型-文件中,每个保存变数的名称都与其形状和值有关。将其载入到新的计算图中与将想要载入的变数的原始名称映射到当前模型的变数中一样简单。示例如下:
代码:
import as "a" "./tftcp.model"
a = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess,
import
tensorflowas
tfd = tf.get_variable(
"d"
, [])init = tf.global_variables_initializer()
saver = tf.train.Saver(var_list={
"a"
: d})sess = tf.Session()
sess.run(init)
saver.restore(sess,
"./tftcp.model"
)sess.run(d)
输出:
-0.9303965
这是一种关键机制,通过这个机制,可以将没有相同计算图的模型组合在一起。例如,你可能从网上获得了一个预训练好的语言模型,希望重用词嵌入。或者你可能在两次训练之间改变了模型的参数化,想让这个新版本在旧版本的基础上继续前进;但你又不想重新训练整个过程。在这两种情况下,你只需手动创建一个字典,将旧变数名称映射到新变数即可。
需要注意的是:你要明确地知道正在载入的参数是如何使用的。如果可以,你应该使用原作者用来构建模型的确切代码,以确保计算图的组件与训练时看起来一样。如果需要复现模型,务必记住,无论多微小的更改,都可能严重损害预训练网络的性能。所以始终要将复现结果和原来的结果进行对比。
模型检查
如果想载入的模型来源于网络或由自己创建(两个月前),那你很可能不知道原始变数是如何命名的。要检查保存的模型,需要使用官方 Tensorflow 库的一些工具。
链接:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/framework/python/framework/checkpoint_utils.py
代码:
import as "a" "b" 10 20 "c" "./tftcp.model" print "./tftcp.model"
a = tf.get_variable(
b = tf.get_variable(
c = tf.get_variable(
init = tf.global_variables_initializer()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
saver.save(sess,
输出:
"a" "b" 10 20 "c"[(
利用这些工具(结合原始代码库一起使用)通常可以找到你想要的变数名称。
结论
希望本文能帮你了解关于保存和载入 Tensorflow 模型的基础知识。还有其他一些高级技巧,比如自动 checkpoint 和保存/恢复元图,可能会在以后的文章中提到;但是根据我的经验,这些并不常用,特别是对于初学者来说。
原文链接:https://jacobbuckman.com/post/tensorflow-the-confusing-parts-2/
本文为机器之心编译,
转载请联系本公众号获得授权
。?------------------------------------------------
加入机器之心(全职记者 / 实习生):hr@jiqizhixin.com
投稿或寻求报道:
content
@jiqizhixin.com广告 & 商务合作:bd@jiqizhixin.com
相关文章
B站怎么炸崩了哔哩哔哩服务器今日怎么又炸挂了?技术团队公开早先原因2023-03-06 19:05:55
苹果iPhoneXS/XR手机电池容量续航最强?答案揭晓2023-02-19 15:09:54
华为荣耀两款机型起内讧:荣耀Play官方价格同价同配该如何选?2023-02-17 23:21:27
google谷歌原生系统Pixel3 XL/4/5/6 pro手机价格:刘海屏设计顶配版曾卖6900元2023-02-17 18:58:09
科大讯飞同传同声翻译软件造假 浮夸不能只罚酒三杯2023-02-17 18:46:15
华为mate20pro系列手机首发上市日期价格,屏幕和电池参数配置对比2023-02-17 18:42:49
小米MAX4手机上市日期首发价格 骁龙720打造大屏标准2023-02-17 18:37:22
武汉弘芯遣散!结局是总投资1280亿项目烂尾 光刻机抵押换钱2023-02-16 15:53:18
谷歌GoogleDrive网云盘下载改名“GoogleOne” 容量提升价格优惠2023-02-16 13:34:45
巴斯夫将裁员6000人 众化工巨头裁员潮再度引发关注2023-02-13 16:49:06
人手不足 韵达快递客服回应大量包裹派送异常没有收到2023-02-07 15:25:20
资本微念与李子柒销声匿迹谁赢? 微念公司退出子柒文化股东2023-02-02 09:24:38
三星GalaxyS8 S9 S10系统恢复出厂设置一直卡在正在检查更新怎么办2023-01-24 10:10:02
华为Mate50 RS保时捷最新款顶级手机2022多少钱?1.2万元售价外观图片吊打iPhone142023-01-06 20:27:09
芯片常见的CPU芯片封装方式 QFP和QFN封装的区别?2022-12-02 17:25:17
华为暂缓招聘停止社招了吗?官方回应来了2022-11-19 11:53:50
热血江湖手游:长枪铁甲 刚猛热血 正派枪客全攻略技能介绍大全2022-11-16 16:59:09
东京把玩了尼康微单相机Z7 尼康Z7现在卖多少钱?2022-10-22 15:21:55
苹果iPhone手机灵动岛大热:安卓灵动岛App应用下载安装量超100万次2022-10-03 22:13:45
苹果美版iPhone可以在中国保修 从哪看怎么查询iPhone的生产日期?2022-09-22 10:00:07










