```python
##################### YOUR CODE STARTS HERE #####################
# Step 1: calculate the #zeros (please use round())
num_zeros = round(sparsity * num_elements)
# Step 2: calculate the importance of weight
importance = tensor.abs()
# Step 3: calculate the pruning threshold
threshold = torch.kthvalue(importance.view(-1), num_zeros).values
# Step 4: get binary mask (1 for nonzeros, 0 for zeros)
mask = importance > threshold
##################### YOUR CODE ENDS HERE #######################
```
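For reference, this fill-in region lives inside the lab's magnitude-based `fine_grained_prune` routine. Below is a minimal self-contained sketch of the whole thing, with an extra guard for `sparsity = 0`; the wrapper signature and the demo tensor are assumptions here, not the lab scaffold:

```python
import torch

def fine_grained_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero out the smallest-magnitude entries of `tensor` in place; return the mask."""
    num_elements = tensor.numel()
    num_zeros = round(sparsity * num_elements)
    if num_zeros == 0:                   # guard: kthvalue with k=0 would error
        return torch.ones_like(tensor)
    importance = tensor.abs()
    threshold = torch.kthvalue(importance.view(-1), num_zeros).values
    mask = importance > threshold
    tensor.mul_(mask)                    # apply the mask in place
    return mask

w = torch.randn(8, 8)
fine_grained_prune(w, sparsity=0.75)
print((w == 0).float().mean().item())    # ~0.75
```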
```python
sparsity_dict = {
    ##################### YOUR CODE STARTS HERE #####################
    # please modify the sparsity value of each layer
    # please DO NOT modify the key of sparsity_dict
    'backbone.conv0.weight': 0,
    'backbone.conv1.weight': 0.6,
    'backbone.conv2.weight': 0.5,
    'backbone.conv3.weight': 0.5,
    'backbone.conv4.weight': 0.5,
    'backbone.conv5.weight': 0.6,
    'backbone.conv6.weight': 0.6,
    'backbone.conv7.weight': 0.75,
    'classifier.weight': 0
    ##################### YOUR CODE ENDS HERE #######################
}
```
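One quick way to see what these per-layer ratios amount to overall is to apply them and measure the resulting model-level sparsity. A sketch, assuming the lab's `model` and the `fine_grained_prune` helper sketched above:

```python
import copy
import torch

pruned = copy.deepcopy(model)            # keep the original model intact
with torch.no_grad():
    total, nonzero = 0, 0
    for name, param in pruned.named_parameters():
        if sparsity_dict.get(name, 0) > 0:
            fine_grained_prune(param, sparsity_dict[name])
        total += param.numel()
        nonzero += int(param.count_nonzero())
print(f"overall weight sparsity: {1 - nonzero / total:.2%}")
```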
```python
def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """A function to calculate the number of channels to PRESERVE after pruning.

    Note that preserve_rate = 1. - prune_ratio
    """
    ##################### YOUR CODE STARTS HERE #####################
    return int(round(channels * (1. - prune_ratio)))
    ##################### YOUR CODE ENDS HERE #####################


@torch.no_grad()
def channel_prune(model: nn.Module,
                  prune_ratio: Union[List, float]) -> nn.Module:
    """Apply channel pruning to each of the conv layers in the backbone.

    Note that for prune_ratio, we can either provide a floating-point number,
    indicating a uniform pruning rate for all layers, or a list of numbers to
    indicate per-layer pruning rates.
    """
    # sanity check of the provided prune_ratio
    assert isinstance(prune_ratio, (float, list))
    n_conv = len([m for m in model.backbone if isinstance(m, nn.Conv2d)])
    # note that each ratio affects the previous conv's output and the next
    # conv's input, i.e., conv0 - ratio0 - conv1 - ratio1 - ...
    if isinstance(prune_ratio, list):
        assert len(prune_ratio) == n_conv - 1
    else:  # convert float to list: prune all backbone convs with a uniform ratio
        prune_ratio = [prune_ratio] * (n_conv - 1)

    model = copy.deepcopy(model)  # prevent overwriting the original model
    # we only apply pruning to the backbone features
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # apply pruning: we naively keep the first k channels
    assert len(all_convs) == len(all_bns)
    for i_ratio, p_ratio in enumerate(prune_ratio):
        prev_conv = all_convs[i_ratio]
        prev_bn = all_bns[i_ratio]
        next_conv = all_convs[i_ratio + 1]
        original_channels = prev_conv.out_channels  # same as next_conv.in_channels
        n_keep = get_num_channels_to_keep(original_channels, p_ratio)

        # prune the output channels of the previous conv and bn
        prev_conv.weight.set_(prev_conv.weight.detach()[:n_keep])
        prev_bn.weight.set_(prev_bn.weight.detach()[:n_keep])
        prev_bn.bias.set_(prev_bn.bias.detach()[:n_keep])
        prev_bn.running_mean.set_(prev_bn.running_mean.detach()[:n_keep])
        prev_bn.running_var.set_(prev_bn.running_var.detach()[:n_keep])

        # prune the input channels of the next conv (hint: just one line of code)
        ##################### YOUR CODE STARTS HERE #####################
        next_conv.weight.set_(next_conv.weight.detach()[:, :n_keep])
        ##################### YOUR CODE ENDS HERE #####################
    return model
```
It is worth noting, as the scaffolding code above already makes explicit, that "channels" only exist in convolutions, and the pruning here acts on output channels. For example, if a conv kernel originally produces k output channels and is pruned down to l output channels, then the input channels of the next layer's conv kernel must correspondingly shrink from k to l. Moreover, a Conv is usually followed by a Batch Norm, whose weight, bias, running_mean, and running_var must be pruned along with the Conv.
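To make the k -> l shape bookkeeping concrete, here is a standalone toy example with made-up sizes (not part of the lab code):

```python
import torch
import torch.nn as nn

conv_a = nn.Conv2d(16, 32, 3)   # produces k = 32 output channels
bn_a = nn.BatchNorm2d(32)
conv_b = nn.Conv2d(32, 64, 3)   # consumes those 32 channels
l = 24                          # keep l = 24 of the k = 32 channels

conv_a.weight.data = conv_a.weight.data[:l]      # (32,16,3,3) -> (24,16,3,3)
conv_a.bias.data = conv_a.bias.data[:l]
bn_a.weight.data = bn_a.weight.data[:l]          # BN params/stats follow the conv
bn_a.bias.data = bn_a.bias.data[:l]
bn_a.running_mean = bn_a.running_mean[:l]
bn_a.running_var = bn_a.running_var[:l]
conv_b.weight.data = conv_b.weight.data[:, :l]   # (64,32,3,3) -> (64,24,3,3)

x = torch.randn(1, 16, 8, 8)
print(conv_b(bn_a(conv_a(x))).shape)             # torch.Size([1, 64, 4, 4])
```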
```python
# function to sort the channels from important to non-important
def get_input_channel_importance(weight):
    # importances = []
    # # compute the importance for each input channel
    # for i_c in range(weight.shape[1]):
    #     channel_weight = weight.detach()[:, i_c]
    #     ##################### YOUR CODE STARTS HERE #####################
    #     importance = torch.linalg.vector_norm(channel_weight)
    #     ##################### YOUR CODE ENDS HERE #####################
    #     importances.append(importance.view(1))
    # return torch.cat(importances)
    # vectorized equivalent of the loop above: the importance of each input
    # channel is the L2 norm of its (out_channels, kh, kw) weight slice
    return torch.linalg.vector_norm(weight, ord=2, dim=(0, 2, 3))


@torch.no_grad()
def apply_channel_sorting(model):
    model = copy.deepcopy(model)  # do not modify the original model
    # fetch all the conv and bn layers from the backbone
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # iterate through conv layers
    for i_conv in range(len(all_convs) - 1):
        # each channel sorting index, we need to apply it to:
        # - the output dimension of the previous conv
        # - the previous BN layer
        # - the input dimension of the next conv (we compute importance here)
        prev_conv = all_convs[i_conv]
        prev_bn = all_bns[i_conv]
        next_conv = all_convs[i_conv + 1]
        # note that we always compute the importance according to input channels
        importance = get_input_channel_importance(next_conv.weight)
        # sort from large to small
        sort_idx = torch.argsort(importance, descending=True)
        # apply to the previous conv and its following bn
        prev_conv.weight.copy_(
            torch.index_select(prev_conv.weight.detach(), 0, sort_idx))
        for tensor_name in ['weight', 'bias', 'running_mean', 'running_var']:
            tensor_to_apply = getattr(prev_bn, tensor_name)
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx))
        # apply to the next conv input (hint: one line of code)
        ##################### YOUR CODE STARTS HERE #####################
        next_conv.weight.copy_(
            torch.index_select(next_conv.weight.detach(), 1, sort_idx))
        ##################### YOUR CODE ENDS HERE #####################
    return model
```
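Since sorting applies one consistent permutation to the previous conv's outputs, the BN parameters and statistics, and the next conv's inputs, the network function should be unchanged up to floating-point reordering. A quick check, assuming the lab's `model` and CIFAR-10-sized inputs:

```python
import torch

model.eval()                              # use running BN stats deterministically
sorted_model = apply_channel_sorting(model)
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    assert torch.allclose(model(x), sorted_model(x), atol=1e-4)
```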