引言

作为一款流行的开源深度学习框架,TensorFlow正被越来越多的人使用。然而,TensorFlow相比于其他框架来说,相对底层,有时候仅仅为了实现一些小想法也需要编写很多代码,又因为TensorFlow庞大的体系架构,如果不能理清之间的关系,很容易导致模型稍复杂,就需要经历冗繁的调试过程。鉴于此,写下这篇博客,用于记录学习TensorFlow的过程中所踩过的坑。之后遇到的问题也都记录于此。

保存和恢复模型

通常训练一段时间后,需要保存的内容有两部分,一部分是当前迭代步骤的模型参数数据,也就是checkpoint, 另一部分是构建的计算图模型了。从tensorflow版本 v0.11.RC0 之后,可以对计算图进行保存和恢复了,这意味着如果你需要在别人预训练好的模型之上进行finetune的话,只要别人提供checkpoint文件和模型文件(.meta)恢复和保存变量和计算图了。关于这部分内容可以参考tensorflow官方API. PS: 由于保存计算图的功能是在v0.11.RC0版本加进去的,所以很多人不知道官方文档里有这个 =。=

另一方面,如果想看看checkpoint文件里变量是怎么保存的,或者想打印checkpoint文件里的tensor值,可以参考inspect_checkpoint

保存模型

save_model.py

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
w1 = tf.Variable(tf.truncated_normal(shape=[10], stddev=0.01), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
w1_, w2_ = sess.run([w1, w2])
print(w1_, w2_)
saver.save(sess, 'my_model')

运行上述文件,你将得到下列文件

  • checkpoint
  • my_model.data-00000-of-00001
  • my_model.index
  • my_model.meta

前三个是数据文件,最后一个是模型文件。

恢复模型

restore_model.py

1
2
3
4
5
6
7
8
9
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('.'))
vars_op = tf.get_collection('vars')
for v in vars_op:
v_ = sess.run(v)
print(v_)

summary写入流程

TensorFlow提供了tensorboard可视化的功能,之前学习如何写summary的时候,都是仿照官方tutorial来走,由于不理解背后的逻辑,导致每次写代码都要查文档,这实在是很僵硬。这里对如何写summary进行一次总结,
加深印象。

总的来说共分为这么几步:

  1. 如果需要记录变量loss,在loss的定义处添加tf.summary.scalar("loss", loss_op)
  2. 定义writer,tf.summary.FileWriter(logging_dir, sess.graph), 需要注意的是在定义writer之后,如果还继续往graph上添加节点,在tensorboard中的图上是不会显示的。
  3. 定义需要fetch的Operationsummary_op = tf.summary.merge_all(), 由于之前在很多地方标记过,这里将所有定义的summary进行合并,然后就可以sess.run() 了。
  4. summary = sess.run(summary_op, feed_dict),形象地记为将生肉summary_op煮成熟肉summary
  5. 将熟肉summary喂给正处于饥饿的人 writer.add_summary(summary, global_step)

    如果把整个过程比喻成做饭给女友吃的话,大概是这样:

    summary

将数据转为tfrecord

tensorflow有自己的数据格式tfrecord,将自己的数据转换为tfrecord首先可以避免io瓶颈,其次能够将数据流和网络结构完美融合,从而能够建立QueueRunner一体化读取并训练。这种方法相比使用placeholder来读入数据,避免了将整个数据集同时读入内存,因此适用于读取大量数据情形。

非序列数据

如果数据本身是非序列数据,可以仿照官方给出的教程

  • 写入tfrecord convert_to_records.py
  • 读取tfrecord fully_connected_reader.py

序列数据

如果数据本身是序列数据,比如视频有很多帧图像,所有的时序数据都是这种类型。针对这种类型,tensorflow提供了tf.train.SequenceExample来序列化并写入tfrecord文件,但是官方文档中并没有这方面教程,经过一番摸索后总结如下。

写入tfrecord

  1. 定义三种特征int64, bytes, float。图像像素可以通过numpy.tostring()的方式转为bytes特征。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
  2. SequenceExample的参数有两个,分别是:contextfeature_lists, 创建SequenceExample的原型如下,其中的feat类型是python列表。

    1
    2
    3
    4
    5
    6
    7
    8
    ex = tf.train.SequenceExample(
    context=tf.train.Features(feature={
    "label": self._int64_feature(idx)
    }),
    feature_lists=tf.train.FeatureLists(feature_list={
    "frames": tf.train.FeatureList(feature=feat)
    })
    )
  3. 调用ex.SerializeToString()并写入writer = tf.python_io.TFRecordWriter(filename)

    1
    2
    3
    writer = tf.python_io.TFRecordWriter(filename)
    writer.write(ex.SerializeToString())

读取tfrecord

读取时首先建立tf.TFRecordReader()对象,然后调用tf.parse_single_sequence_example()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
context_features = {
"label": tf.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
"frames": tf.FixedLenSequenceFeature([], dtype=tf.string)
}
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
serialized=serialized_example,
context_features=context_features,
sequence_features=sequence_features
)
label = tf.cast(context_parsed['label'], tf.int32)
frames = sequence_parsed['frames']
frames = tf.decode_raw(frames, np.uint8)
frames = tf.reshape(frames, (FRAME_COUNT, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL))
return frames, label
def make_batch(datadir, batch_size, min_queue_examples):
if datadir.endswith('/'):
datadir = datadir[:-1]
filenames = glob.glob(datadir + '/trainlist01_0.tfrecord')
print(filenames)
filename_queue = tf.train.string_input_producer(filenames)
frames, label = _read_and_decode(filename_queue)
frames_batch, label_batch = tf.train.shuffle_batch(
[frames, label],
batch_size=batch_size,
num_threads=20,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples
)
frames_batch = tf.to_float(frames_batch)
label_batch = tf.reshape(label_batch, [batch_size])
return frames_batch, label_batch

日志log管理

这一项个人感觉很重要,深度网络需要不断地调参看结果,之前一直是调用print函数的方式,在可能出错的情况下、或查看某个tensor的形状的情况下,直接打印到终端。

然而随着模型的复杂度提升,有时候程序需要跑很久,调参的时间跨度很大,日志也很多,为了便于之后分析,于是就产生了将日志记录到文件的想法,最好是能够同时输出到终端和日志文件中。基于此想法,调研一番可以借用python
自带的logging模块实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import sys
import time
import logging
def logger_fn(name, file, level=logging.DEBUG):
tf_logger = logging.getLogger(name)
tf_logger.setLevel(level)
ch = logging.StreamHandler(sys.stdout)
fh = logging.FileHandler(file, mode='w')
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
fh.setFormatter(formatter)
tf_logger.addHandler(ch)
tf_logger.addHandler(fh)
return tf_logger
logger = logger_fn('tflog', 'training-{}.log'.format(time.asctime()))

然后在其他文件中,将logger这个对象导入,就可以调用logger.info, logger.debug等函数了。需要注意的是,导入对象在所有文件中只会执行一次,因此不需要将记录日志写成单例类的形式,这种方式简单清晰,使用起来也很方便。

打印训练参数

之前介绍了如何同时将输出信息打印到屏幕和日志文件中,这里列出如何将训练参数也输入到文件中。以我个人习惯于将所有参数用一个Config类来表示。那么训练参数就是类的属性,将这些属性以字符串的形式输出即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Config:
train_dir = 'datasets/TFUCF/train'
test_dir = 'datasets/TFUCF/test' # specify test dir.
logging_dir = 'logs'
n_training = 10000
n_test = 2000
n_classes = 10
batch_size = 32
learning_rate = 0.01
epoch = 30
summary_step = 10
print_step = 1
save_step = 1000
decay_step = 1000
decay_rate = 0.95
cell_size = 512
debug = 0
@classmethod
def logthis(cls):
attrs = sorted([a for a in cls.__dict__ if not a.startswith('_') and a != 'logthis'])
dilim = '-'*40
return '\n'.join([dilim, *['{:>20}|{:<20}'.format(a, cls.__dict__[a]) for a in attrs], dilim])

print(Config.logthis())输出结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
----------------------------------------
batch_size|32
cell_size|512
debug|0
decay_rate|0.95
decay_step|1000
epoch|30
learning_rate|0.01
logging_dir|logs
n_classes|10
n_test|2000
n_training|10000
print_step|1
save_step|1000
summary_step|10
test_dir|datasets/TFUCF/test
train_dir|datasets/TFUCF/train
----------------------------------------