本文是我看一篇论文代码整理的总结。下面这篇论文(简称 HAT)是持续学习领域的一个经典工作,我曾经以此论文为样板,完整仔细地看过整个项目的代码,从而理解并开始自己实现持续学习的实验的,现将心得整理于此。
这篇总结将详细介绍代码的各个细节,目的是一站式搞懂一个持续学习项目乃至深度学习项目的写法,也提供了一种阅读他人代码的思路,供刚入门该领域的同学参考。我将按照自外到内的顺序,先介绍整个工程的逻辑,到主函数,再到具体函数或类的细节,剥洋葱式地讲解。
这个项目的代码有些地方写的乱或不规范,注意取其精华,掌握思想,并了解代码不合理的地方在哪。
看懂这篇笔记的先修条件是掌握 Python 和 PyTorch,以及 Linux 系统的基本使用,并了解深度学习和持续学习的基础知识,请参考我的相关笔记:
论文信息
Overcoming Catastrophic Forgetting with Hard Attention to the Task
- 会议:ICML 2018
- 作者:西班牙巴塞罗那的大学
- 内容:持续学习模型 HAT,是将 mask 机制加到持续学习的第一篇论文,提出了一个很简单的、每个神经元引入一个任务 mask 的方法,并给出了训练方法,和一个解决模型容量问题的稀疏正则项,让新旧任务 mask 重合。它属于参数隔离方法,之后很多带 mask 机制的持续学习论文以此篇为基础。
论文代码:https://github.com/joansj/hat
工程逻辑
我们从根目录开始,src
文件夹存放的是真正的代码,我们稍后讨论。根目录下的其他文件都是与代码没有直接关系的:
LICENSE
: 一个文本文件(可以看到没有后缀名),打开可以看到,里面的文字是在声明版权,告诉他人可以怎么用此项目、禁止怎么用(否则追究责任)。一般项目都会在根目录放一个这种名为 LICENSE 的文件,里面声明性的文字不需要自己写,去网上查查各种常见的选项(如 CC 4.0、MIT License 等),选一个合适的复制过来就好。在 Github 创建项目时有个选项可以选 license,之后会自动在根目录里创建 LICENSE 文件文本,很方便。在项目代码中大家一般是这样做,在其他例如文章、博客中也有其他方式声明版权,起到相同的作用,例如,可以看看我这个博客每篇文章结尾有一段话:“本文由作者按照 CC BY 4.0 进行授权,转载请注明”,点击就能跳转到 CC 4.0 指示的版权声明文本。readme.md
: 顾名思义就是“请先读我”,它是作者写的对项目的描述性文本,给用户看的。可以在这里写任何想跟用户说的,例如使用说明等信息。在 Github 项目首页也是默认展示这段文本(在 Github 创建项目时可以选择是否 add a README file)。众所周知在电脑上做笔记用 Markdown 格式很方便,现在的代码项目也是用它写(语法请自行学习)readme,而不是 word 之类的文件,因为技术上可以方便地接到网页上显示(例如我的博客每一篇文章都是 Markdown 格式写的)。本项目由于是一篇论文的代码,所以作者主要在这里写了论文的信息,并简要介绍了程序的安装和运行方法。requirements.txt
: 一个文本文件,列举了代码所需要的环境,放在项目根目录中,告诉别人运行此项目需要装什么第三方库。注意这个不一定是手打的或复制的,而是可以通过 Pip 或 Conda 自动生成的,此时具有一定的语法(当然有时候自动生成的太过详细反而不合适,于是只需要手打一些重要的即可。本项目就做的不太好,把其他无关的环境里的包也包含进来了)。以下是相关命令:pip freeze >requirements.txt
或conda list -e >requirements.txt
: 生成环境列表到requirements.txt
;pip install -r requirements.txt
或conda install --file requirements.txt
或conda create --name XXX --file requirements.txt
: 安装 requirements.txt 的环境到当前环境或新建环境。
.gitignore
: 一个文本文件,表示在上传到 Github 时应该忽略哪些文件(语法请自行学习),注意它用.
开头,在 Linux 或 Mac 系统中代表隐藏文件,下同。这些文件通常是临时的、runtime 的或结果性质的非代码、非必要的文件,不需要上传占用空间让别人看到。在 Github 创建项目时可以选择是否 add .gitignore。本项目忽略上传的文件有(其中一些文件在运行之后才会出现):logs
文件夹: 存放保存的实验结果文件(见“处理实验结果”一节)。它是非代码性质的数据文件,无需上传;dat
文件夹: 存放深度学习的数据集。它是非代码性质的数据文件,并且占用大量空间,不能上传;res
文件夹: 存放另外一些实验结果(见“处理实验结果”一节);old
、temp
文件夹: 作用未知,但看名字应该是临时文件;- 其他无关紧要的文件:
src/.idea
文件夹:(使用 IDE PyCharm 的配置文件);src/__pycache__
文件夹(Python 缓存文件);pyc
类型文件(py 文件编译后的二进制文件);.DS_Store
类型文件(苹果电脑的文件系统配置文件);.png
格式的文件;两个脚本文件(src/work.sh
、src/immalpha.sh
),可能是作者写了试的但最终没用。
下面看代码的 src
文件夹。有四个文件起到程序入口的作用:
run.py
: 是最基本的主函数,负责运行一次深度学习实验;run_multi.py
: 是用run.py
改写的,负责运行多次深度学习实验;work.py
:作者写的另一个可以运行多次深度学习实验的程序;run_compression.sh
: 负责运行模型压缩(compression)实验的程序,见论文 4.4 节。它是 Linux 系统的 shell 脚本命令(即命令行中的单个命令组合成的打包命令)。可以看到,这里它包含了多行固定的运行run.py
的命令行命令,当运行run_compression.sh
时,相当于运行了里面写的这些命令。
其他的是具体的模块代码:
approaches
文件夹:定义各种持续学习算法,每个算法是一个 py 文件;dataloaders
文件夹:定义了数据预处理方法,每个数据集是一个 py 文件。由于持续学习数据集一般是现有数据集构造的,这里也定义了如何构造数据集;networks
文件夹:定义了神经网络结构,每个结构是一个 py 文件;plot_results.py
: 可视化实验结果的工具(见“处理实验结果”一节).本项目是先把结果存下来,再在需要时单独对其可视化。可视化与核心业务分离,这个 py 文件就是单独的可视化程序;utils.py
文件:存放各种工具函数,为了不让主要部分的代码过长,如打印函数,计算某个量等。
主程序
run.py
是整个项目的核心,它完成一次深度学习实验的整个流程。
解析命令行参数
一次深度学习实验需要指定很多东西:数据集、网络结构、学习算法、各种超参数,还有一些细节的配置如随机数的种子、输出格式等。这些信息一般是不出现在代码里的,而是作为运行程序时用户指定的参数,即命令行参数。关于如何使用 Python 命令行参数,我在这篇博文有详细讨论。
run.py
定义命令行参数的部分在 9-20 行,解析命令行参数在 29-97 行。可以看到,它定义了如下 7 个命令行参数:
--seed
: 见“随机数种子”一节;--experiment
: 解析时通过 if 语句手动对应选择选项。根据命令行选项,把dataloaders
文件夹中的模块统一解析到名为dataloader
的变量中,在下面统一调用;--approach
: 解析时通过 if 语句手动对应选择选项。根据命令行选项,把approaches
文件夹中的模块统一解析到名为approach
的变量中,在下文统一调用;--nepochs
: 训练轮数,是比较重要的超参数,需要用户手动指定;--lr
: 学习率,是比较重要的超参数,需要用户手动指定;--parameter
: 为其他超参数的预留位置(因为每个 approach 的超参数都可以有所不同),具体是解析成几个、什么超参数,要看具体 approach 的定义;--output
: 指定输出结果文件名路径,见“处理实验结果”一节。
深度学习流程
接下来的代码对应深度学习流程:
- 读取数据集(99-102行):可以看到,所有
dataloaders
中的模块都只有一个get
函数,在这里统一调用,用于得到数据集(包括训练集、验证集、测试集,作者的处理办法是先打包成一个 data 变量,再在训练或测试时抽离出来,见125、154等行),以及每个任务有几个类、输入维度等信息(用于定义网络); - 网络结构初始化(104-107行):可以看到,所有
networks
中的模块都只有一个Net
类,在这里统一实例化为要训练的网络结构。实例化需要确定每个任务有几个类、输入维度等信息,来自上面数据集get
函数的返回值; - 定义学习算法(109-112行):可以看到,所有
approaches
中的模块都只有一个Appr
类,在这里统一实例化为学习算法。粗略阅读其代码可发现,这种Appr
类:- 不仅定义了持续学习的机制(因此实例化时需要传入持续学习有关的超参数,作者的做法是把
args
整个传进去,例如/approaches/hat.py
中27-31行解析了args.parameter
为 lamb 和 smax 两个超参数,用户在传--parameters
时就知道--parameter
代表这两个超参数); - 还把优化器和损失函数一并包进来定义,因此实例化时需要指定优化器的超参数、训练轮数等,这些都在命令行参数
args
里; - 请注意,
Appr
类还把网络也包进来作为实例属性了,从这里开始程序不再出现网络net
变量;
- 不仅定义了持续学习的机制(因此实例化时需要传入持续学习有关的超参数,作者的做法是把
- 训练(148-149行):统一调用
Appr
类的train
方法,它接受上面抽离出来的训练集和验证集,以及第几个任务这个信息。注意不需要传网络,它在Appr
里面,这个训练函数本质上也是在修改更新它; - 测试(152-159行):统一调用
Appr
类的eval
方法,它接受上面抽离出来的测试集,以及第几个任务这个信息。注意这里外层有个 u 循环,是要测试所有任务的。仍然不需要传网络。
Dataloaders
项目在 dataloaders
文件夹定义了数据集、预处理方法和构造持续学习任务的代码,每个数据集是一个 py 文件。每个文件都只定义了一个 get
函数,我们以持续学习经典的、较为简单的 pmnist.py
(Permuted MNIST)为例来讲解。get
函数它返回如下内容:
- 数据集变量为
data
:是一个嵌套字典,即字典的值还是字典- 第一层(11行)为任务,键为任务 ID;
- 有一个额外的键 ‘ncla’ 存放所有任务的类数之和(80行);
- 第二层(34行)为任务的元信息,包括:
- 任务名字 ‘name’:作者命名为 ‘pmnist-任务ID’(35行);
- 类的数量 ‘ncla’:在 Permuted MNIST 中,每个任务类的数量固定为 10(36行);
- 训练、验证、测试数据 ‘train’, ‘valid’, ‘test’;
- 第三层(39行)在数据集 ‘train’, ‘valid’, ‘test’ 里面:
- 输入 ‘x’:一个大 Tensor,事实上本项目输入模型的数据集并不是用的 PyTorch 的 Dataloader,作者是手动划分 batch 的,例如,见
approaches/sgd.py
的 81-82 行; - 标签 ‘y’:一个大 Tensor;
- 输入 ‘x’:一个大 Tensor,事实上本项目输入模型的数据集并不是用的 PyTorch 的 Dataloader,作者是手动划分 batch 的,例如,见
- 第一层(11行)为任务,键为任务 ID;
- 每个任务有几个类
taskcla
(78行):是一个列表,对 Permuted MNIST,它固定是[10,...,10]
; - 输入维度
size
(13行):直接定义为常量[1,28,28]
。
从 get
函数参数可以看到,作者没有为用户提供什么选择,一个 Permuted MNIST 数据集基本是固定的,用户只能设置:
- 控制随机数种子的
seed
和fixed_order
:见“随机数种子”一节; pc_valid
:验证集数据比例。
下面来看 data
变量第三层的数据集是如何一步步构造的:
- 首先通过
torchvision.datasets
将原始的 MNIST 数据集下载到dat
变量中(27-30行),再一步步解析到data
中; - 用原始数据集
dat
构造 batch=1 的 Dataloader(38行,应该是为了方便写循环),逐张图片作 permute 操作(41-43行),添加到data
中。注意此时data
的数据部分 ‘x’,’y’ 现在是列表; - 将此时的列表数据保存下来(20-21、51-52行),以后可以直接读取(55-67行)。究其原因,是前面逐张图片的处理操作太慢了,哪怕保存读取也更节省时间;
- 将列表转换为可以输入到 nn.Module 的 Tensor(48-52行);
- 注意上面只是分了训练集和测试集,还要从训练集 ‘train’ 中划分验证集 ‘valid’(70-73行)。注意作者在 pmnist.py 中验证集是直接复制了训练集,也就是说模型选择是按照训练集上最好的来选的。我不知道是因为懒还是别的原因,但这样是容易过拟合的,越小的数据集更是如此。
其他的数据集大同小异,我简要介绍之,主要关注区别:
mnist2.py
:是 2 个任务的 Split MNIST 数据集。由于不涉及逐张 Permute 操作,作者也没有在中间设计保存读取;cifar.py
:是 10 个任务的 Split CIFAR 数据集,前 5 个任务用 CIFAR10 数据集,每个任务有 2 个类;后 5 个任务用 CIFAR100 数据集,每个任务有 20 个类。对此数据集作者终于随机划分了训练集给验证集(79-90行),比例pc_valid
在get
函数参数由用户指定;mixture.py
:是很多种数据集的混合,有 8 个任务,每个任务是一种数据集,分别是 CIFAR10、CIFAR100、MNIST、SVHN、FashionMNIST、TrafficSigns、Facescrub、notMNIST(不是按顺序,而是固定地随机打乱)。这里有些数据集是torchvision.datasets
没有的,作者在下面定义了相应的数据集类(相当于自定义 Dataset 类)。
下面浅看一下作者是怎么自定义数据集的(在 mixture.py
):
- FashionMNIST:实际上
torchvision.datasets
是有这个数据集的,可能是作者在使用其 API 时遇到了 bug,然后自己重写了一个(249-257行); - TrafficSigns、Facescrub、notMNIST:都继承自
Dataset
类,写法遵从此笔记讲的自定义方法。__init__()
函数中从本地文件读取整个数据集到data
和labels
变量,然后在__getitem__()
直接索引。与 MNIST 类似,download
和train
参数控制下载和选择训练还是测试集。下载操作需要复杂的网络通讯和纠错机制,也是打包在一个download()
函数。通过判断train=True
的条件语句,选择读取训练还是测试数据集。
Networks
项目在 networks
文件夹的代码定义了网络结构,每个模型是一个 py 文件。每个文件都只定义了一个 nn.Module
名为 Net
的类。会写这些 nn.Module
类是深度学习的基础,请参考此笔记。每个深度学习项目都大同小异,大都用到 MLP、AlexNet、ResNet 等网络,写法也差不多。
我们更需要关注的是网络结构是如何适配持续学习的场景或方法的。在我看来有两点:
- 问题一:如何处理持续学习场景中新任务新来的类(输出头);
- 问题二:对 HAT 这种 model-based 的持续学习方法,涉及修改网络结构,怎么改。 我们以较简单的 MLP 为例来看作者是如何处理的,见
mlp.py
和mlp_hat.py
。
对于问题一,作者是事先把所有任务的输出头都定义好(19-21行),前馈时也会输出所有输出头结果的拼接(30-32行);而不是每来一个新任务动态地增加输出头,因为这是在做一个固定的实验,这样写代码比较方便。什么时候、什么任务用那个输出头,这些都定义在持续学习方法 approach
的训练和测试函数中。另外,即使像 Permuted MNIST 这样所有任务类别相同的数据集,也是每个任务给一个自己的输出头,而不是共用相同的输出头。
对于问题二,不可避免地要对每种网络结构衍生出一个修改版本,例如本项目中 mlp.py
衍生出 mlp_hat.py
,alexnet.py
衍生出 alexnet_hat.py
, alexnet_pathnet.py
, alexnet_progressive.py
, alexnet_lfl.py
等,者都是涉及修改网络结构的 model-based 方法,在使用这些 model-based 方法时,要求使用相应的网络结构。
来看一下 mlp_hat.py
,它定义了 HAT 方法的网络结构,即加了 mask 的 MLP。mlp.py
提供了一个 3 个隐藏层的 MLP,每层神经元个数相同;而 mlp_hat.py
提供了 1、2、3 个隐藏层的 MLP(由 __init__
函数的参数 nlayers
指定,观察下文代码它实际上只能取 1、2、3)。以 3 层 MLP 为例:
__init__()
函数中,比普通 MLP 多了三个 efc1、efc2、efc3(20、23、26行),即在每层神经元上的 task embedding,可见它们实现为nn.Embedding
,这个类用于表示一组长度相同的模型参数(称为 embedding),第一个参数num_embeddings
为 embedding 的个数,第二个参数embedding_dim
为 embedding 长度。这个类一般用于词向量表示(一组词库,每个词用一个 embedding 表示),但在这里作者用于表示各个任务的 task embedding 向量,注意每个 efc1、efc2、efc3 各自都是预定义好了所有任务各层的 embedding,而不是单个任务。- 有了 task embedding,通过论文中的公式(1):乘以尺度参数 \(s\) 再过 gate function 即 Sigmoid,得到 mask。这个过程打包成了一个 mask 函数(65-71行);
forward
函数中,比普通 MLP 多了 mask 的步骤:在每个神经元激活后乘以 mask(53、56、59行)。注意,这个 forward 函数不仅接受输入 \(x\),还包括了任务 \(t\)和计算 mask 用的 \(s\)。也就是说,模型是在这里提供接口给训练和测试时确定第几个任务这个信息的,这个Net
类是预定义好了所有任务,然后通过forward
函数区分任务。
此外还有一些细节。有 alexnet_hat.py
和 alexnet_hat_test.py
两个,它们区别在 43-50 行:是否对 task embedding 作归一初始化,与模型压缩实验有关。另外,有些带 hat 的 Net
类写了 get_view_for
函数,它使用 torch 的拉直操作(view)将一个 mask Tensor 拉直,属于 HAT 算法的一个工具函数,写在了模型类里,在 HAT 的 train 函数调用,见下文。
Approaches
项目在 approaches
文件夹的代码定义了各种持续学习算法的代码,每个算法是一个 py 文件。每个文件都只定义了一个名为 Appr
的类,其中都有训练和测试函数 train
、eval
以及定义的优化器 self.optimizer
、损失函数 self.criterion
。我们先来看不加持续学习防遗忘机制的微调算法 sgd.py
,理清楚基本的训练和测试流程的细节。
- 损失函数定义在
self.criterion
中。对于 sgd.py 这种简单的,直接在__init__
函数规定了是nn.CrossEntropyLoss()
,对于hat.py
等,在类中自定义了criterion
函数; - 优化器定义在
self.optimizer
中,是用__init__
传入的优化器超参数lr
等构造的(作者还多写了一层_get_optimizer
,可能是嫌定义优化器的代码太长); train
函数是训练一个任务,的核心工作在 40 行调用train_epoch
函数,定义在 72 行,它的任务是训练任务 t 的一个 epoch。它的流程与普通的深度学习训练过程没什么区别,唯一要注意的是它是在训练任务 t(这个信息当做函数参数传入),体现在 88 行结果截取输出头 t;其他部分代码都在做一件事——动态调整学习率,这是深度学习的训练技巧,我放在后面的章节专门讨论。eval
函数是测试当前模型(训练了任务 t 后)在一个任务上(反映在测试集上)的准确率,注意run.py
有个外循环测试所有任务。这个写的和普通深度学习的测试没什么差别,最后返回测试 loss 和准确率。
接下来看 HAT 算法 hat.py
,它是在 sgd.py
基础上改的,并且要求传入的 model
必须是 _hat
版本的:
criterion
函数(196行)定义了 HAT 论文的 sparse 正则项(公式(5))和分类损失。这个正则项需要用到旧任务<t
mask 的合并,存放在self.mask_pre
。为了这个正则项,损失函数criterion
不仅接受模型输出outputs
和真实标签targets
,还要masks
。注意 if 语句区分了第一个任务的情况;- 核心的
train_epoch
函数调用了加 mask 的 HAT 版forward
函数,并计算了上述criterion
定义的损失,再反向传播计算梯度。在更新之前:- 首先屏蔽不更新旧任务 mask 掉的参数,实现方式是梯度置 0。在 135 行,梯度乘以的
mask_back
就是旧任务mask_pre
的反转(1-x),它在上一步就由mask_pre
计算了出来(97-102行); - 接着应用论文 2.5 节的梯度补偿机制,将梯度乘以了一个补偿因子;
- 更新后再把训练好的 task embedding 统一收缩(clamp)到一个较小范围,这个在论文 2.5 节的最后一段提到。
- 首先屏蔽不更新旧任务 mask 掉的参数,实现方式是梯度置 0。在 135 行,梯度乘以的
train
函数(训练一个任务)最后包含了 task embedding 到 mask 的转换(87-90行)并计算旧任务mask_pre
(91-95行)的过程。可以看到,作者用同样的 float 类型的数据结构同时存放了 task embedding 和二元 mask:训练前是前者,训练后就用此处的代码转换成后者(因为 task embedding 再也不用,没必要存储了);eval
函数:没有太大区别,多了对正则项的统计。
除了上述算法,为了 baseline 的比较,还实现了其他与 HAT 类似的参数隔离方法如 PathNet、Progressive NN 等,也有其他方法如 EWC、IMM、LwF 等。它们的区别就在训练、测试、损失函数、优化器以及用到的相关变量,它们全都可以在一个 Appr
类写明白。
注意,这些方法只有 HAT 有 _test
版本的,它的意思是正式跑的程序,可以看到里面的代码更完善,它在 __init__()
函数定义了 logs
、logpath
(hat_test.py
29-59行,从命令行参数 --parameter
解析),如果定义了它们,从 run.py
的最后一段可以看到,会把测试的详细结果存作处理(见“处理实验结果”一节)。此外,_test
版本还有两个 criterion
函数,它们是一样的,这是作者整理代码时的疏忽,但可以看出来他是想在 _test
版本处理更多信息的。
其他细节
打印调试信息
作者穿插了各种调试信息在代码中。一般是用 print 语句实现,对于复杂的,为了不想让主函数过长,代码打包在了 utils.py
中 print 开头的函数来调用。
我在这里按顺序整理一下作者穿插的调试信息,也可以帮助梳理总结一下上面的内容:
- 23-27行:打印了用户指定的命令行参数,供用户确认;
- 102 行:打印数据集信息和持续学习的任务信息;
- 107 行:打印模型信息;
- 110-111行:打印损失函数和优化器信息;
- 119、176 行:打印训练进程、时间信息;
- 157、166-174 行:打印每一次测试结果、测试汇总结果;
- 162 行:打印结果保存的信息。
除了上面讲述的整个流程,代码中还有很多细节需要我们注意。它们往往是很重要的事情。
处理实验结果
纵观整个代码,作者把以下实验结果保存到了本地文件中:
- 命令行
--output
参数:它定义的路径存放的是测试准确率上三角矩阵acc
(161-163 行),t行u列表示训练完第t个任务时模型在任务u的准确率,这个是持续学习最主要的指标(见持续学习基础知识笔记)。命令行参数如果为空,作者定义了默认的文件路径(21-22 行),可以看到是用--experiment
、--approach
等元信息命名的,用于区分不同实验; Appr
类的logpath
参数:只有hat_test.py
出现,如果在--parameter
包含这一部分,run.py
在 178 行之后把一些结果信息用 pickle 保存下来(pickle 是一个 Python 内置库,可以完整保存、还原任意 Python 变量),在需要画图的时候被plot_results.py
还原调用。
随机数种子
在 run.py
对种子做了全局设定(31-33行),在代码内部也有一些局部随机数变量的种子设定,如 pmnist.py
18 行的任务顺序。
使用 GPU
这种规模的深度学习实验一定是用 GPU 跑的。在代码里:
- 34-35 行检查能不能用 GPU,不能用则强制退出;
- 106 行把模型放到了 GPU 上;
- 142-145、154-155 行把数据集放到了 GPU 上。
实验细节
代码中有一些细节,是深度学习实验经常要做的:
- 要做数据标准化:先手算了均值方差,再在构造数据集时应用
transforms.Normalize
变换,例如pmnist.py
24-30 行;
以下是代码中使用的一些调参技巧:
- 动态调整学习率:每一轮都在验证集上测试(调用
eval
函数)一下 loss,如果连续 lr_patience 个 epoch 验证集 loss 一直不下降,则把学习率调小一点:除以 lr_factor。 - Dropout 层防止过拟合:例如
mlp.py
第 15 行;
这就是一个科研用深度学习项目的全貌。看完了代码,也能感受到其中的不足之处,例如:
- 调用格式不太统一;
- 用户可以指定的东西太少;
- 一些常用的参数藏得太深(例如
--parameter
的解析规则,尤其是appr.logs
),用户必须非常仔细阅读代码才知道怎么用;
当然,科研用的代码以做出实验结果为目的,自己方便能看懂就行,不是产品,不需要呈现给用户,自己需要什么就写什么,挂在网上的目的只是在别人质疑的给他一个参考。这种性质也决定了作者没有必要写的更健全、完美,我们看下来也就理解其代码逻辑和思想即可,无需追究细节。