首页 代码学习笔记:从论文 HAT 学会写持续学习代码
文章
取消

代码学习笔记:从论文 HAT 学会写持续学习代码

本文是我看一篇论文代码整理的总结。下面这篇论文(简称 HAT)是持续学习领域的一个经典工作,我曾经以此论文为样板,完整仔细地看过整个项目的代码,从而理解并开始自己实现持续学习的实验的,现将心得整理于此。

这篇总结将详细介绍代码的各个细节,目的是一站式搞懂一个持续学习项目乃至深度学习项目的写法,也提供了一种阅读他人代码的思路,供刚入门该领域的同学参考。我将按照自外到内的顺序,先介绍整个工程的逻辑,到主函数,再到具体函数或类的细节,剥洋葱式地讲解。

这个项目的代码有些地方写的乱或不规范,注意取其精华,掌握思想,并了解代码不合理的地方在哪。

看懂这篇笔记的先修条件是掌握 Python 和 PyTorch,以及 Linux 系统的基本使用,并了解深度学习和持续学习的基础知识,请参考我的相关笔记:

论文信息

Overcoming Catastrophic Forgetting with Hard Attention to the Task

  • 会议:ICML 2018
  • 作者:西班牙巴塞罗那的大学
  • 内容:持续学习模型 HAT,是将 mask 机制加到持续学习的第一篇论文,提出了一个很简单的、每个神经元引入一个任务 mask 的方法,并给出了训练方法,和一个解决模型容量问题的稀疏正则项,让新旧任务 mask 重合。它属于参数隔离方法,之后很多带 mask 机制的持续学习论文以此篇为基础。

论文代码:https://github.com/joansj/hat


工程逻辑

我们从根目录开始,src 文件夹存放的是真正的代码,我们稍后讨论。根目录下的其他文件都是与代码没有直接关系的:

  • LICENSE: 一个文本文件(可以看到没有后缀名),打开可以看到,里面的文字是在声明版权,告诉他人可以怎么用此项目、禁止怎么用(否则追究责任)。一般项目都会在根目录放一个这种名为 LICENSE 的文件,里面声明性的文字不需要自己写,去网上查查各种常见的选项(如 CC 4.0、MIT License 等),选一个合适的复制过来就好。在 Github 创建项目时有个选项可以选 license,之后会自动在根目录里创建 LICENSE 文件文本,很方便。在项目代码中大家一般是这样做,在其他例如文章、博客中也有其他方式声明版权,起到相同的作用,例如,可以看看我这个博客每篇文章结尾有一段话:“本文由作者按照 CC BY 4.0 进行授权,转载请注明”,点击就能跳转到 CC 4.0 指示的版权声明文本。
  • readme.md: 顾名思义就是“请先读我”,它是作者写的对项目的描述性文本,给用户看的。可以在这里写任何想跟用户说的,例如使用说明等信息。在 Github 项目首页也是默认展示这段文本(在 Github 创建项目时可以选择是否 add a README file)。众所周知在电脑上做笔记用 Markdown 格式很方便,现在的代码项目也是用它写(语法请自行学习)readme,而不是 word 之类的文件,因为技术上可以方便地接到网页上显示(例如我的博客每一篇文章都是 Markdown 格式写的)。本项目由于是一篇论文的代码,所以作者主要在这里写了论文的信息,并简要介绍了程序的安装和运行方法。
  • requirements.txt: 一个文本文件,列举了代码所需要的环境,放在项目根目录中,告诉别人运行此项目需要装什么第三方库。注意这个不一定是手打的或复制的,而是可以通过 Pip 或 Conda 自动生成的,此时具有一定的语法(当然有时候自动生成的太过详细反而不合适,于是只需要手打一些重要的即可。本项目就做的不太好,把其他无关的环境里的包也包含进来了)。以下是相关命令:
    • pip freeze >requirements.txtconda list -e >requirements.txt: 生成环境列表到 requirements.txt
    • pip install -r requirements.txtconda install --file requirements.txtconda create --name XXX --file requirements.txt: 安装 requirements.txt 的环境到当前环境或新建环境。
  • .gitignore: 一个文本文件,表示在上传到 Github 时应该忽略哪些文件(语法请自行学习),注意它用 . 开头,在 Linux 或 Mac 系统中代表隐藏文件,下同。这些文件通常是临时的、runtime 的或结果性质的非代码、非必要的文件,不需要上传占用空间让别人看到。在 Github 创建项目时可以选择是否 add .gitignore。本项目忽略上传的文件有(其中一些文件在运行之后才会出现):
    • logs 文件夹: 存放保存的实验结果文件(见“处理实验结果”一节)。它是非代码性质的数据文件,无需上传;
    • dat 文件夹: 存放深度学习的数据集。它是非代码性质的数据文件,并且占用大量空间,不能上传;
    • res 文件夹: 存放另外一些实验结果(见“处理实验结果”一节);
    • oldtemp 文件夹: 作用未知,但看名字应该是临时文件;
    • 其他无关紧要的文件:src/.idea 文件夹:(使用 IDE PyCharm 的配置文件);src/__pycache__ 文件夹(Python 缓存文件);pyc 类型文件(py 文件编译后的二进制文件);.DS_Store 类型文件(苹果电脑的文件系统配置文件);.png 格式的文件;两个脚本文件(src/work.shsrc/immalpha.sh),可能是作者写了试的但最终没用。

下面看代码的 src 文件夹。有四个文件起到程序入口的作用:

  • run.py: 是最基本的主函数,负责运行一次深度学习实验;
  • run_multi.py: 是用 run.py 改写的,负责运行多次深度学习实验;
  • work.py:作者写的另一个可以运行多次深度学习实验的程序;
  • run_compression.sh: 负责运行模型压缩(compression)实验的程序,见论文 4.4 节。它是 Linux 系统的 shell 脚本命令(即命令行中的单个命令组合成的打包命令)。可以看到,这里它包含了多行固定的运行 run.py 的命令行命令,当运行 run_compression.sh 时,相当于运行了里面写的这些命令。

其他的是具体的模块代码:

  • approaches 文件夹:定义各种持续学习算法,每个算法是一个 py 文件;
  • dataloaders 文件夹:定义了数据预处理方法,每个数据集是一个 py 文件。由于持续学习数据集一般是现有数据集构造的,这里也定义了如何构造数据集;
  • networks 文件夹:定义了神经网络结构,每个结构是一个 py 文件;
  • plot_results.py: 可视化实验结果的工具(见“处理实验结果”一节).本项目是先把结果存下来,再在需要时单独对其可视化。可视化与核心业务分离,这个 py 文件就是单独的可视化程序;
  • utils.py 文件:存放各种工具函数,为了不让主要部分的代码过长,如打印函数,计算某个量等。

主程序

run.py 是整个项目的核心,它完成一次深度学习实验的整个流程。

解析命令行参数

一次深度学习实验需要指定很多东西:数据集、网络结构、学习算法、各种超参数,还有一些细节的配置如随机数的种子、输出格式等。这些信息一般是不出现在代码里的,而是作为运行程序时用户指定的参数,即命令行参数。关于如何使用 Python 命令行参数,我在这篇博文有详细讨论。

run.py 定义命令行参数的部分在 9-20 行,解析命令行参数在 29-97 行。可以看到,它定义了如下 7 个命令行参数:

  • --seed: 见“随机数种子”一节;
  • --experiment: 解析时通过 if 语句手动对应选择选项。根据命令行选项,把 dataloaders 文件夹中的模块统一解析到名为 dataloader 的变量中,在下面统一调用;
  • --approach: 解析时通过 if 语句手动对应选择选项。根据命令行选项,把 approaches 文件夹中的模块统一解析到名为 approach 的变量中,在下文统一调用;
  • --nepochs: 训练轮数,是比较重要的超参数,需要用户手动指定;
  • --lr: 学习率,是比较重要的超参数,需要用户手动指定;
  • --parameter: 为其他超参数的预留位置(因为每个 approach 的超参数都可以有所不同),具体是解析成几个、什么超参数,要看具体 approach 的定义;
  • --output: 指定输出结果文件名路径,见“处理实验结果”一节。

深度学习流程

接下来的代码对应深度学习流程:

  • 读取数据集(99-102行):可以看到,所有 dataloaders 中的模块都只有一个 get 函数,在这里统一调用,用于得到数据集(包括训练集、验证集、测试集,作者的处理办法是先打包成一个 data 变量,再在训练或测试时抽离出来,见125、154等行),以及每个任务有几个类、输入维度等信息(用于定义网络);
  • 网络结构初始化(104-107行):可以看到,所有 networks 中的模块都只有一个 Net 类,在这里统一实例化为要训练的网络结构。实例化需要确定每个任务有几个类、输入维度等信息,来自上面数据集 get 函数的返回值;
  • 定义学习算法(109-112行):可以看到,所有 approaches 中的模块都只有一个 Appr 类,在这里统一实例化为学习算法。粗略阅读其代码可发现,这种 Appr 类:
    • 不仅定义了持续学习的机制(因此实例化时需要传入持续学习有关的超参数,作者的做法是把 args 整个传进去,例如 /approaches/hat.py 中27-31行解析了 args.parameter 为 lamb 和 smax 两个超参数,用户在传 --parameters 时就知道 --parameter 代表这两个超参数);
    • 还把优化器和损失函数一并包进来定义,因此实例化时需要指定优化器的超参数、训练轮数等,这些都在命令行参数 args 里;
    • 请注意,Appr 类还把网络也包进来作为实例属性了,从这里开始程序不再出现网络 net 变量;
  • 训练(148-149行):统一调用 Appr 类的 train 方法,它接受上面抽离出来的训练集和验证集,以及第几个任务这个信息。注意不需要传网络,它在 Appr 里面,这个训练函数本质上也是在修改更新它;
  • 测试(152-159行):统一调用 Appr 类的 eval 方法,它接受上面抽离出来的测试集,以及第几个任务这个信息。注意这里外层有个 u 循环,是要测试所有任务的。仍然不需要传网络。

Dataloaders

项目在 dataloaders 文件夹定义了数据集、预处理方法和构造持续学习任务的代码,每个数据集是一个 py 文件。每个文件都只定义了一个 get 函数,我们以持续学习经典的、较为简单的 pmnist.py(Permuted MNIST)为例来讲解。get 函数它返回如下内容:

  • 数据集变量为 data:是一个嵌套字典,即字典的值还是字典
    • 第一层(11行)为任务,键为任务 ID;
      • 有一个额外的键 ‘ncla’ 存放所有任务的类数之和(80行);
    • 第二层(34行)为任务的元信息,包括:
      • 任务名字 ‘name’:作者命名为 ‘pmnist-任务ID’(35行);
      • 类的数量 ‘ncla’:在 Permuted MNIST 中,每个任务类的数量固定为 10(36行);
      • 训练、验证、测试数据 ‘train’, ‘valid’, ‘test’;
    • 第三层(39行)在数据集 ‘train’, ‘valid’, ‘test’ 里面:
      • 输入 ‘x’:一个大 Tensor,事实上本项目输入模型的数据集并不是用的 PyTorch 的 Dataloader,作者是手动划分 batch 的,例如,见 approaches/sgd.py 的 81-82 行;
      • 标签 ‘y’:一个大 Tensor;
  • 每个任务有几个类 taskcla(78行):是一个列表,对 Permuted MNIST,它固定是 [10,...,10]
  • 输入维度 size(13行):直接定义为常量 [1,28,28]

get 函数参数可以看到,作者没有为用户提供什么选择,一个 Permuted MNIST 数据集基本是固定的,用户只能设置:

  • 控制随机数种子的 seedfixed_order:见“随机数种子”一节;
  • pc_valid:验证集数据比例。

下面来看 data 变量第三层的数据集是如何一步步构造的:

  • 首先通过 torchvision.datasets 将原始的 MNIST 数据集下载到 dat 变量中(27-30行),再一步步解析到 data 中;
  • 用原始数据集 dat 构造 batch=1 的 Dataloader(38行,应该是为了方便写循环),逐张图片作 permute 操作(41-43行),添加到 data 中。注意此时 data 的数据部分 ‘x’,’y’ 现在是列表;
  • 将此时的列表数据保存下来(20-21、51-52行),以后可以直接读取(55-67行)。究其原因,是前面逐张图片的处理操作太慢了,哪怕保存读取也更节省时间;
  • 将列表转换为可以输入到 nn.Module 的 Tensor(48-52行);
  • 注意上面只是分了训练集和测试集,还要从训练集 ‘train’ 中划分验证集 ‘valid’(70-73行)。注意作者在 pmnist.py 中验证集是直接复制了训练集,也就是说模型选择是按照训练集上最好的来选的。我不知道是因为懒还是别的原因,但这样是容易过拟合的,越小的数据集更是如此。

其他的数据集大同小异,我简要介绍之,主要关注区别:

  • mnist2.py:是 2 个任务的 Split MNIST 数据集。由于不涉及逐张 Permute 操作,作者也没有在中间设计保存读取;
  • cifar.py:是 10 个任务的 Split CIFAR 数据集,前 5 个任务用 CIFAR10 数据集,每个任务有 2 个类;后 5 个任务用 CIFAR100 数据集,每个任务有 20 个类。对此数据集作者终于随机划分了训练集给验证集(79-90行),比例 pc_validget 函数参数由用户指定;
  • mixture.py:是很多种数据集的混合,有 8 个任务,每个任务是一种数据集,分别是 CIFAR10、CIFAR100、MNIST、SVHN、FashionMNIST、TrafficSigns、Facescrub、notMNIST(不是按顺序,而是固定地随机打乱)。这里有些数据集是 torchvision.datasets 没有的,作者在下面定义了相应的数据集类(相当于自定义 Dataset 类)。

下面浅看一下作者是怎么自定义数据集的(在 mixture.py):

  • FashionMNIST:实际上 torchvision.datasets 是有这个数据集的,可能是作者在使用其 API 时遇到了 bug,然后自己重写了一个(249-257行);
  • TrafficSigns、Facescrub、notMNIST:都继承自 Dataset 类,写法遵从此笔记讲的自定义方法。__init__() 函数中从本地文件读取整个数据集到 datalabels 变量,然后在 __getitem__() 直接索引。与 MNIST 类似,downloadtrain 参数控制下载和选择训练还是测试集。下载操作需要复杂的网络通讯和纠错机制,也是打包在一个 download() 函数。通过判断 train=True 的条件语句,选择读取训练还是测试数据集。

Networks

项目在 networks 文件夹的代码定义了网络结构,每个模型是一个 py 文件。每个文件都只定义了一个 nn.Module 名为 Net 的类。会写这些 nn.Module 类是深度学习的基础,请参考此笔记。每个深度学习项目都大同小异,大都用到 MLP、AlexNet、ResNet 等网络,写法也差不多。

我们更需要关注的是网络结构是如何适配持续学习的场景或方法的。在我看来有两点:

  • 问题一:如何处理持续学习场景中新任务新来的类(输出头);
  • 问题二:对 HAT 这种 model-based 的持续学习方法,涉及修改网络结构,怎么改。 我们以较简单的 MLP 为例来看作者是如何处理的,见 mlp.pymlp_hat.py

对于问题一,作者是事先把所有任务的输出头都定义好(19-21行),前馈时也会输出所有输出头结果的拼接(30-32行);而不是每来一个新任务动态地增加输出头,因为这是在做一个固定的实验,这样写代码比较方便。什么时候、什么任务用那个输出头,这些都定义在持续学习方法 approach 的训练和测试函数中。另外,即使像 Permuted MNIST 这样所有任务类别相同的数据集,也是每个任务给一个自己的输出头,而不是共用相同的输出头。

对于问题二,不可避免地要对每种网络结构衍生出一个修改版本,例如本项目中 mlp.py 衍生出 mlp_hat.pyalexnet.py 衍生出 alexnet_hat.py, alexnet_pathnet.py, alexnet_progressive.py, alexnet_lfl.py 等,者都是涉及修改网络结构的 model-based 方法,在使用这些 model-based 方法时,要求使用相应的网络结构。

来看一下 mlp_hat.py,它定义了 HAT 方法的网络结构,即加了 mask 的 MLP。mlp.py 提供了一个 3 个隐藏层的 MLP,每层神经元个数相同;而 mlp_hat.py 提供了 1、2、3 个隐藏层的 MLP(由 __init__ 函数的参数 nlayers 指定,观察下文代码它实际上只能取 1、2、3)。以 3 层 MLP 为例:

  • __init__() 函数中,比普通 MLP 多了三个 efc1、efc2、efc3(20、23、26行),即在每层神经元上的 task embedding,可见它们实现为 nn.Embedding,这个类用于表示一组长度相同的模型参数(称为 embedding),第一个参数 num_embeddings 为 embedding 的个数,第二个参数 embedding_dim 为 embedding 长度。这个类一般用于词向量表示(一组词库,每个词用一个 embedding 表示),但在这里作者用于表示各个任务的 task embedding 向量,注意每个 efc1、efc2、efc3 各自都是预定义好了所有任务各层的 embedding,而不是单个任务。
  • 有了 task embedding,通过论文中的公式(1):乘以尺度参数 \(s\) 再过 gate function 即 Sigmoid,得到 mask。这个过程打包成了一个 mask 函数(65-71行);
  • forward 函数中,比普通 MLP 多了 mask 的步骤:在每个神经元激活后乘以 mask(53、56、59行)。注意,这个 forward 函数不仅接受输入 \(x\),还包括了任务 \(t\)和计算 mask 用的 \(s\)。也就是说,模型是在这里提供接口给训练和测试时确定第几个任务这个信息的,这个 Net 类是预定义好了所有任务,然后通过 forward 函数区分任务。

此外还有一些细节。有 alexnet_hat.pyalexnet_hat_test.py 两个,它们区别在 43-50 行:是否对 task embedding 作归一初始化,与模型压缩实验有关。另外,有些带 hat 的 Net 类写了 get_view_for 函数,它使用 torch 的拉直操作(view)将一个 mask Tensor 拉直,属于 HAT 算法的一个工具函数,写在了模型类里,在 HAT 的 train 函数调用,见下文。

Approaches

项目在 approaches 文件夹的代码定义了各种持续学习算法的代码,每个算法是一个 py 文件。每个文件都只定义了一个名为 Appr 的类,其中都有训练和测试函数 traineval以及定义的优化器 self.optimizer、损失函数 self.criterion。我们先来看不加持续学习防遗忘机制的微调算法 sgd.py,理清楚基本的训练和测试流程的细节。

  • 损失函数定义在 self.criterion 中。对于 sgd.py 这种简单的,直接在 __init__ 函数规定了是 nn.CrossEntropyLoss(),对于 hat.py 等,在类中自定义了 criterion 函数;
  • 优化器定义在 self.optimizer 中,是用 __init__ 传入的优化器超参数 lr 等构造的(作者还多写了一层 _get_optimizer,可能是嫌定义优化器的代码太长);
  • train 函数是训练一个任务,的核心工作在 40 行调用 train_epoch 函数,定义在 72 行,它的任务是训练任务 t 的一个 epoch。它的流程与普通的深度学习训练过程没什么区别,唯一要注意的是它是在训练任务 t(这个信息当做函数参数传入),体现在 88 行结果截取输出头 t;其他部分代码都在做一件事——动态调整学习率,这是深度学习的训练技巧,我放在后面的章节专门讨论。
  • eval 函数是测试当前模型(训练了任务 t 后)在一个任务上(反映在测试集上)的准确率,注意 run.py 有个外循环测试所有任务。这个写的和普通深度学习的测试没什么差别,最后返回测试 loss 和准确率。

接下来看 HAT 算法 hat.py,它是在 sgd.py 基础上改的,并且要求传入的 model 必须是 _hat 版本的:

  • criterion 函数(196行)定义了 HAT 论文的 sparse 正则项(公式(5))和分类损失。这个正则项需要用到旧任务 <t mask 的合并,存放在 self.mask_pre。为了这个正则项,损失函数 criterion 不仅接受模型输出 outputs 和真实标签 targets,还要 masks。注意 if 语句区分了第一个任务的情况;
  • 核心的 train_epoch 函数调用了加 mask 的 HAT 版 forward 函数,并计算了上述 criterion 定义的损失,再反向传播计算梯度。在更新之前:
    • 首先屏蔽不更新旧任务 mask 掉的参数,实现方式是梯度置 0。在 135 行,梯度乘以的 mask_back 就是旧任务 mask_pre 的反转(1-x),它在上一步就由 mask_pre 计算了出来(97-102行);
    • 接着应用论文 2.5 节的梯度补偿机制,将梯度乘以了一个补偿因子;
    • 更新后再把训练好的 task embedding 统一收缩(clamp)到一个较小范围,这个在论文 2.5 节的最后一段提到。
  • train 函数(训练一个任务)最后包含了 task embedding 到 mask 的转换(87-90行)并计算旧任务 mask_pre(91-95行)的过程。可以看到,作者用同样的 float 类型的数据结构同时存放了 task embedding 和二元 mask:训练前是前者,训练后就用此处的代码转换成后者(因为 task embedding 再也不用,没必要存储了);
  • eval 函数:没有太大区别,多了对正则项的统计。

除了上述算法,为了 baseline 的比较,还实现了其他与 HAT 类似的参数隔离方法如 PathNet、Progressive NN 等,也有其他方法如 EWC、IMM、LwF 等。它们的区别就在训练、测试、损失函数、优化器以及用到的相关变量,它们全都可以在一个 Appr 类写明白。

注意,这些方法只有 HAT 有 _test 版本的,它的意思是正式跑的程序,可以看到里面的代码更完善,它在 __init__() 函数定义了 logslogpathhat_test.py 29-59行,从命令行参数 --parameter 解析),如果定义了它们,从 run.py 的最后一段可以看到,会把测试的详细结果存作处理(见“处理实验结果”一节)。此外,_test 版本还有两个 criterion 函数,它们是一样的,这是作者整理代码时的疏忽,但可以看出来他是想在 _test 版本处理更多信息的。

其他细节

打印调试信息

作者穿插了各种调试信息在代码中。一般是用 print 语句实现,对于复杂的,为了不想让主函数过长,代码打包在了 utils.py 中 print 开头的函数来调用。

我在这里按顺序整理一下作者穿插的调试信息,也可以帮助梳理总结一下上面的内容:

  • 23-27行:打印了用户指定的命令行参数,供用户确认;
  • 102 行:打印数据集信息和持续学习的任务信息;
  • 107 行:打印模型信息;
  • 110-111行:打印损失函数和优化器信息;
  • 119、176 行:打印训练进程、时间信息;
  • 157、166-174 行:打印每一次测试结果、测试汇总结果;
  • 162 行:打印结果保存的信息。

除了上面讲述的整个流程,代码中还有很多细节需要我们注意。它们往往是很重要的事情。

处理实验结果

纵观整个代码,作者把以下实验结果保存到了本地文件中:

  • 命令行 --output 参数:它定义的路径存放的是测试准确率上三角矩阵 acc(161-163 行),t行u列表示训练完第t个任务时模型在任务u的准确率,这个是持续学习最主要的指标(见持续学习基础知识笔记)。命令行参数如果为空,作者定义了默认的文件路径(21-22 行),可以看到是用--experiment--approach等元信息命名的,用于区分不同实验;
  • Appr 类的 logpath 参数:只有 hat_test.py 出现,如果在 --parameter 包含这一部分,run.py 在 178 行之后把一些结果信息用 pickle 保存下来(pickle 是一个 Python 内置库,可以完整保存、还原任意 Python 变量),在需要画图的时候被 plot_results.py 还原调用。

随机数种子

run.py 对种子做了全局设定(31-33行),在代码内部也有一些局部随机数变量的种子设定,如 pmnist.py 18 行的任务顺序。

使用 GPU

这种规模的深度学习实验一定是用 GPU 跑的。在代码里:

  • 34-35 行检查能不能用 GPU,不能用则强制退出;
  • 106 行把模型放到了 GPU 上;
  • 142-145、154-155 行把数据集放到了 GPU 上。

实验细节

代码中有一些细节,是深度学习实验经常要做的:

  • 要做数据标准化:先手算了均值方差,再在构造数据集时应用 transforms.Normalize 变换,例如 pmnist.py 24-30 行;

以下是代码中使用的一些调参技巧:

  • 动态调整学习率:每一轮都在验证集上测试(调用 eval 函数)一下 loss,如果连续 lr_patience 个 epoch 验证集 loss 一直不下降,则把学习率调小一点:除以 lr_factor。
  • Dropout 层防止过拟合:例如 mlp.py 第 15 行;


这就是一个科研用深度学习项目的全貌。看完了代码,也能感受到其中的不足之处,例如:

  • 调用格式不太统一;
  • 用户可以指定的东西太少;
  • 一些常用的参数藏得太深(例如 --parameter 的解析规则,尤其是 appr.logs),用户必须非常仔细阅读代码才知道怎么用;

当然,科研用的代码以做出实验结果为目的,自己方便能看懂就行,不是产品,不需要呈现给用户,自己需要什么就写什么,挂在网上的目的只是在别人质疑的给他一个参考。这种性质也决定了作者没有必要写的更健全、完美,我们看下来也就理解其代码逻辑和思想即可,无需追究细节。

本文由作者按照 CC BY 4.0 进行授权,转载请注明

论文笔记:Queried Unlabeled Data Improves and Robustifies Class-Incremental Learning

快慢网络式持续学习