算法1 类中心知识蒸馏算法 | 第一阶段:微调教师网络 | 输入:教师网络模型 ,预训练参数,训练样本 | 计算标准交叉熵损失函数 | 反向传播更新 的参数,直到损失函数收敛 | 输出:教师网络模型 | 第二阶段:通过类中心蒸馏训练学生网络 | 输入:教师网络模型 ,学生网络模型 ,训练样本 | 初始化:学生网络参数 和超参数 | 按标签 整理训练集 每批次从随机类中 抽取随机的样本 | 根据式(7)计算总体损失函数 | 反向传播更新学生网络的参数 ,直到损失函数收敛 | 输出:学生网络模型 ,参数 | 第三阶段:测试学生网络 | 输入:学生网络模型 ,测试样本 | 输出:预测结果 | 按标签 整理训练集 每批次从随机类中 抽取随机的样本 |
|