pytorch官方的剪枝工具分为非结构化和结构化剪枝两种,非结构化剪枝会随机地把一些权重参数变为0,结构化剪枝则将某个维度某些通道随机变成0,但这套工具不会真正输出剪枝后的模型,只是将模型变稀疏了,只有用某些特殊前向库,才能加快模型运行速度。最近找到一个库,能根据一定策略找到权重中作用较小的部分,用index表示,并保留对应的模型。
首先构造输入,前向运行一次模型,得到模型对应的计算图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))
对于计算图中的各种带权重的层,根据指定策略(目前支持ln,l1,l2等)比较权重中各个数值,找到第k小的数对应的index,记为将要删除的部分,
strategy = tp.strategy.L1Strategy()
# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9, ...]
得到index后,对依赖该层的其他层,递归使用对应的prune函数进行剪枝,得到每一层的剪枝计划。
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
在剪枝计划执行时,根据index改变model中对应层的定义,使得model中的channel数变少。
# plune plane exec source code
def exec(self, dry_run=False):
num_pruned = 0
for dep, idxs in self._plans: # idxs were computed by specified strategy
_, n = dep(idxs, dry_run=dry_run)
num_pruned += n
return num_pruned
以卷积层为例,执行剪枝计划时,根据strategy提供的idx,对weight和bias进行修剪:
class ConvPruning(BasePruningFunction):
@staticmethod
def prune_params(layer: nn.Module, idxs: Sequence[int]) -> nn.Module:
keep_idxs = list(set(range(layer.out_channels)) - set(idxs))
layer.out_channels = layer.out_channels-len(idxs)
if not layer.transposed:
layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
else:
layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
if layer.bias is not None:
layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
return layer
@staticmethod
def calc_nparams_to_prune(layer: nn.Module, idxs: Sequence[int]) -> int:
nparams_to_prune = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0)
return nparams_to_prune
如此,剪枝后model.forward时,运行的卷积层就是剪枝过的版本啦。
- CIFAR10上ResNet18剪枝
- 注意事项:
- 剪枝是对遍历网络的每一个单独的层进行剪枝
- 剪枝时虽然有一定的策略,但不能保证每个剪掉这些层之后损失就是最小的
- 每层剪枝的比例可以不同,可以考虑人工将网络划分为几个部分,每个部分可以设置不同的剪枝比例,但具体应该设置多少,可以从剪枝后模型的执行结果进行评估,可以考虑写一个遍历算法,或者写个启发式搜索来找到最佳比例;
- 剪枝后应当在新的模型上继续fine tune一定epoch,以得到最适合此网络结构的权重