Knowledge Distillation, Cross-Entropy Edition: Hands-On Code

张开发
2026/4/30 6:26:21 · 15 min read


Overall, the code below compares the model output `student_logits` against the true answers `labels` and computes a "how wrong is the classification" loss value, named `ce_loss`:

```python
ce_loss = F.cross_entropy(student_logits, labels)
```

Which raises a few questions.

1. Why cross-entropy?

Because it is the most common loss function for classification tasks. Say the model has to decide whether an image shows a cat, a dog, or a car. The model does not answer "cat" directly; it outputs a score for each class:

```python
student_logits = [2.1, 0.3, -1.2]
```

These scores express the model's preference for each class. Cross-entropy measures, in probability terms, how far the model's predicted class distribution is from the true class. If the model is very confident in the correct class, the cross-entropy is small; if it is very confident in a wrong class, the cross-entropy is large.

2. What does cross-entropy do?

Its job is to tell the model how badly it is wrong. Example: suppose the true label is "cat".

Model A predicts the distribution cat: 0.90, dog: 0.08, car: 0.02. The cross-entropy loss is small, because the correct class gets high probability.

Model B predicts cat: 0.20, dog: 0.70, car: 0.10. The cross-entropy loss is larger, because the model believes "dog" more.

During training, the neural network uses backpropagation to drive this loss down, i.e. the model learns to give the correct class an ever higher score.

3. What is `ce_loss` and what is it for?

`ce_loss` is just a variable name; it usually stands for cross entropy loss. It is normally a scalar, e.g. `tensor(0.7321)`. It has two main uses:

```python
ce_loss.backward()
optimizer.step()
```

`ce_loss.backward()` computes gradients, telling each parameter which direction it should move. `optimizer.step()` then updates the model parameters according to those gradients. So `ce_loss` is one of the core quantities in training: the model learns by minimizing it.

4. Where is `F` defined, and roughly what does it contain?

`F` here usually comes from PyTorch:

```python
import torch.nn.functional as F
```

`F` is not a function but a module; its full name is `torch.nn.functional`. It contains many common neural-network functions, for example `F.relu()`, `F.softmax()`, `F.cross_entropy()`, `F.mse_loss()`, `F.dropout()`, `F.max_pool2d()`, `F.one_hot()`. These functions are "stateless": they only compute, and keep no trainable parameters of their own. For instance, `F.relu(x)` simply turns values below 0 into 0, whereas a layer such as `nn.Linear(...)` stores weight parameters.

5. What do `student_logits` and `labels` stand for, and why define these two arguments?

`student_logits` is the student model's raw output scores. The name has two parts: student means the student model; logits means raw classification scores that have not yet gone through softmax. For example, a batch with 2 samples and 3 classes per sample:

```python
student_logits = torch.tensor([[2.1, 0.3, -1.2],
                               [0.1, 1.5, 0.4]])
```

The shape is typically `[batch_size, num_classes]`.

`labels` holds the true class labels:

```python
labels = torch.tensor([0, 1])
```

meaning the first sample's true class is class 0 and the second sample's true class is class 1.

These two arguments exist so that the loss function knows what the model predicted and what the true answer is. Only with both can it compute how wrong the model is.
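To make the link between logits, softmax and cross-entropy concrete, here is a small sketch (my own illustration, not part of the original article) that recomputes `F.cross_entropy` by hand via `log_softmax` plus negative log-likelihood, using the example tensors from question 5:

```python
import torch
import torch.nn.functional as F

# Example batch from question 5: 2 samples, 3 classes
student_logits = torch.tensor([[2.1, 0.3, -1.2],
                               [0.1, 1.5, 0.4]])
labels = torch.tensor([0, 1])

# Built-in cross-entropy (averages over the batch by default)
ce_loss = F.cross_entropy(student_logits, labels)

# Manual version: log-softmax, then pick the log-probability
# of each sample's true class and negate the mean
log_probs = F.log_softmax(student_logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(ce_loss.item(), manual.item())
assert torch.allclose(ce_loss, manual)
```

This also shows why `F.cross_entropy` takes raw logits, not probabilities: it applies the (numerically stabler) log-softmax internally.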
6. What is this whole line of code for?

The line

```python
ce_loss = F.cross_entropy(student_logits, labels)
```

compares the student model's output `student_logits` with the true labels `labels`, computes the classification loss, and stores it in the variable `ce_loss`. You can read it as: ce_loss = the gap between the model's prediction and the reference answer. In knowledge-distillation code it usually represents the loss where the student model learns directly from the true labels. For example, the total loss might be:

```python
loss = alpha * ce_loss + beta * distill_loss
```

where `ce_loss` makes the student learn from the true labels, and `distill_loss` makes the student learn from the teacher model.

Attached is a mini knowledge-distillation implementation:

```python
import argparse
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms


class TeacherCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


class StudentCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_loaders(data_dir, batch_size, dataset_name, seed):
    if dataset_name == "digits":
        # Small sklearn digits dataset, upscaled to 28x28 so the
        # same CNNs work for both datasets
        digits = load_digits()
        images = torch.tensor(digits.images, dtype=torch.float32).unsqueeze(1) / 16.0
        images = F.interpolate(images, size=(28, 28), mode="bilinear", align_corners=False)
        images = (images - 0.5) / 0.5
        labels = torch.tensor(digits.target, dtype=torch.long)
        dataset = TensorDataset(images, labels)
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        generator = torch.Generator().manual_seed(seed)
        train_set, test_set = random_split(dataset, [train_size, test_size], generator=generator)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
        return train_loader, test_loader

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    test_set = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, test_loader


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    loss_total = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss_total += F.cross_entropy(logits, labels).item() * images.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return loss_total / total, correct / total


def train_supervised(model, train_loader, test_loader, device, epochs, lr, name):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        test_loss, test_acc = evaluate(model, test_loader, device)
        train_loss = running_loss / len(train_loader.dataset)
        print(
            f"{name} epoch {epoch}: "
            f"train_loss={train_loss:.4f} test_loss={test_loss:.4f} test_acc={test_acc:.4f}"
        )


def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    # Hard-label loss: student vs. true labels
    ce_loss = F.cross_entropy(student_logits, labels)
    # Soft-label loss: student vs. temperature-softened teacher distribution
    soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    kd_loss = F.kl_div(soft_student_log_probs, soft_teacher_probs, reduction="batchmean")
    # temperature**2 keeps gradient magnitudes comparable across temperatures
    return alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kd_loss


def train_distilled(student, teacher, train_loader, test_loader, device,
                    epochs, lr, temperature, alpha):
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    teacher.to(device)
    student.to(device)
    teacher.eval()
    for epoch in range(1, epochs + 1):
        student.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            student_logits = student(images)
            with torch.no_grad():
                teacher_logits = teacher(images)
            loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        test_loss, test_acc = evaluate(student, test_loader, device)
        train_loss = running_loss / len(train_loader.dataset)
        print(
            f"student_kd epoch {epoch}: train_loss={train_loss:.4f} test_loss={test_loss:.4f} "
            f"test_acc={test_acc:.4f} temperature={temperature} alpha={alpha}"
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=Path, default=Path("data"))
    parser.add_argument("--dataset", choices=["digits", "mnist"], default="digits")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--epochs-teacher", type=int, default=3)
    parser.add_argument("--epochs-student", type=int, default=3)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--temperature", type=float, default=4.0)
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Weight for hard-label cross entropy.")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}")

    train_loader, test_loader = build_loaders(args.data_dir, args.batch_size,
                                              args.dataset, args.seed)

    teacher = TeacherCNN()
    student_baseline = StudentCNN()
    student_kd = StudentCNN()
    print(f"teacher params={count_params(teacher):,}")
    print(f"student params={count_params(student_baseline):,}")

    print("\n=== Train teacher ===")
    train_supervised(teacher, train_loader, test_loader, device,
                     args.epochs_teacher, args.lr, "teacher")

    print("\n=== Train student baseline ===")
    train_supervised(student_baseline, train_loader, test_loader, device,
                     args.epochs_student, args.lr, "student_baseline")

    print("\n=== Train student with knowledge distillation ===")
    train_distilled(
        student_kd, teacher, train_loader, test_loader, device,
        args.epochs_student, args.lr, args.temperature, args.alpha,
    )

    teacher_loss, teacher_acc = evaluate(teacher, test_loader, device)
    baseline_loss, baseline_acc = evaluate(student_baseline, test_loader, device)
    kd_loss, kd_acc = evaluate(student_kd, test_loader, device)

    print("\n=== Final result ===")
    print(f"teacher: loss={teacher_loss:.4f} acc={teacher_acc:.4f} "
          f"params={count_params(teacher):,}")
    print(f"student_baseline: loss={baseline_loss:.4f} acc={baseline_acc:.4f} "
          f"params={count_params(student_baseline):,}")
    print(f"student_kd: loss={kd_loss:.4f} acc={kd_acc:.4f} "
          f"params={count_params(student_kd):,}")


if __name__ == "__main__":
    main()
```

To run directly:

```
python train_mnist_kd.py --epochs-teacher 3 --epochs-student 3
```

To run after downloading the MNIST dataset:

```
python train_mnist_kd.py --dataset mnist --epochs-teacher 3 --epochs-student 3
```
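As a side note on the `temperature` parameter used in `distillation_loss`: dividing the logits by a temperature before softmax flattens the teacher's distribution, which is what lets the student see how the teacher ranks the wrong classes. A small sketch (my own illustration, with made-up example logits):

```python
import torch
import torch.nn.functional as F

# A teacher that is very confident in class 0
teacher_logits = torch.tensor([[8.0, 2.0, 1.0]])

for t in [1.0, 4.0]:
    probs = F.softmax(teacher_logits / t, dim=1)
    print(f"T={t}: {probs.squeeze().tolist()}")

# At T=1 the distribution is nearly one-hot, so the KL term in
# distillation_loss would teach the student little beyond the hard label.
# At T=4 the non-target classes get visibly more probability mass, so the
# relative ordering of wrong classes (the "dark knowledge") is exposed.
```

This is also why the KD term is scaled by `temperature ** 2` in the code above: the gradients through the softened softmax shrink roughly as 1/T², and the scaling keeps the hard-label and soft-label terms on a comparable footing.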
