Skip to content

Latest commit

 

History

History
69 lines (55 loc) · 3.42 KB

torch_prune.md

File metadata and controls

69 lines (55 loc) · 3.42 KB

Pytorch剪枝

pytorch官方的剪枝工具分为非结构化和结构化剪枝两种,非结构化剪枝会随机地把一些权重参数变为0,结构化剪枝则将某个维度某些通道随机变成0,但这套工具不会真正输出剪枝后的模型,只是将模型变稀疏了,只有用某些特殊前向库,才能加快模型运行速度。最近找到一个库,能根据一定策略找到权重中作用较小的部分,用index表示,并保留对应的模型。

原理

首先构造输入,前向运行一次模型,得到模型对应的计算图

DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

对于计算图中的各种带权重的层,根据指定策略(目前支持ln,l1,l2等)比较权重中各个数值,找到第k小的数对应的index,记为将要删除的部分,

strategy = tp.strategy.L1Strategy() 
# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9, ...]

得到index后,对依赖该层的其他层,递归使用对应的prune函数进行剪枝,得到每一层的剪枝计划。

pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )

在剪枝计划执行时,根据index改变model中对应层的定义,使得model中的channel数变少。

# plune plane exec source code
def exec(self, dry_run=False):
    num_pruned = 0
    for dep, idxs in self._plans: # idxs were computed by specified strategy
        _, n = dep(idxs, dry_run=dry_run)
        num_pruned += n
    return num_pruned

以卷积层为例,执行剪枝计划时,根据strategy提供的idx,对weight和bias进行修剪:

class ConvPruning(BasePruningFunction):
    @staticmethod
    def prune_params(layer: nn.Module, idxs: Sequence[int]) -> nn.Module: 
        keep_idxs = list(set(range(layer.out_channels)) - set(idxs))
        layer.out_channels = layer.out_channels-len(idxs)
        if not layer.transposed:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
        else:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
        if layer.bias is not None:
            layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        return layer
    
    @staticmethod
    def calc_nparams_to_prune(layer: nn.Module, idxs: Sequence[int]) -> int: 
        nparams_to_prune = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0)
        return nparams_to_prune

如此,剪枝后model.forward时,运行的卷积层就是剪枝过的版本啦。

样例

  • CIFAR10上ResNet18剪枝
  • 注意事项:
    • 剪枝是对遍历网络的每一个单独的层进行剪枝
    • 剪枝时虽然有一定的策略,但不能保证每个剪掉这些层之后损失就是最小的
    • 每层剪枝的比例可以不同,可以考虑人工将网络划分为几个部分,每个部分可以设置不同的剪枝比例,但具体应该设置多少,可以从剪枝后模型的执行结果进行评估,可以考虑写一个遍历算法,或者写个启发式搜索来找到最佳比例;
    • 剪枝后应当在新的模型上继续fine tune一定epoch,以得到最适合此网络结构的权重