
Pruning methods/library #13

Closed

darsnack opened this issue Oct 26, 2020 · 15 comments
Labels: enhancement (New feature or request), interface (Interface design or implementation issue), needs-discussion (Community input wanted), parity:pytorch (Needed for feature parity with PyTorch)

Comments

@darsnack
Member

I'm going to start looking at creating a pruning package with Flux. Will update this soon with more details.

@darsnack darsnack self-assigned this Oct 26, 2020
@darsnack darsnack added enhancement New feature or request interface Interface design or implementation issue needs-discussion Community input wanted parity:pytorch Needed for feature parity with PyTorch labels Oct 26, 2020
@DrChainsaw

For structured pruning, this is pretty powerful as it handles size alignment between layers in an optimal fashion (through JuMP and Cbc): https://github.com/DrChainsaw/NaiveNASflux.jl

There are only two built-in methods as of now, but the general API is to just provide a function which, given a vertex (i.e. a layer) of the model, returns the value per neuron. See the first example.
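A minimal sketch of that pattern (using the same `neuron_value`/`nout_org` calls that appear in the longer example further down this thread; exact names are version-dependent):

```julia
using NaiveNASflux

# A metric is just: vertex (layer) -> one value per output neuron.
# Layers without parameters return `missing` from `neuron_value`,
# so this hypothetical metric falls back to zeros for them.
function my_metric(v)
    val = neuron_value(v)
    return ismissing(val) ? zeros(nout_org(v)) : val
end
```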

@darsnack
Member Author

Looks like this is a substantial framework for activation-based pruning where the mask is not known a priori. Is my understanding correct?

@DrChainsaw

I would say the framework itself is something like a more general version of network morphisms, and the pruning metrics themselves are not really core functionality.

When it comes to selecting which neurons to keep in case you decrease the size of at least one layer (e.g. the pruning use case), you can supply any metric as a vector, and the core functionality will try to maximize it for the whole network given the constraints that things need to fit together. I guess this can be thought of as creating the mask.

ActivationContribution happened to be a suitable wrapper for an implementation of one metric, but if you, for example, choose not to wrap layers in it, the default metric will just use the absolute magnitude of the weights instead.
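A hedged sketch of the wrapped vs. unwrapped choice described above. The vertex-construction calls (`inputvertex`, `mutable` with a `layerfun` keyword) follow the NaiveNASflux README of this era and should be treated as assumptions:

```julia
using NaiveNASflux, Flux

iv = inputvertex("in", 3, FluxDense())

# Wrapped: neuron values come from an activation-based contribution metric.
v1 = mutable(Dense(3, 4, relu), iv; layerfun=ActivationContribution)

# Not wrapped: the default metric falls back to the absolute magnitude
# of the weights along the activation dimension.
v2 = mutable(Dense(4, 2), v1)

graph = CompGraph(iv, v2)
```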

@darsnack
Member Author

Thanks for the clarification. I'll spend some time playing around with the package to learn more about it.

@DrChainsaw

Awesome! Please post issues if you find anything wonky or difficult to comprehend!

@DrChainsaw

Link to the pruning example from the ONNX issue: #10 (comment)

@DrChainsaw

DrChainsaw commented Nov 1, 2020

Here is a take on a very simple “Package” for structured pruning using NaiveNASflux.

I have made some updates to NaiveNASlib so it is possible to just "mark" neurons to be pruned by giving them a negative value. Note that a negative value by itself is not a guarantee for pruning, as things like residual connections could make the net value of keeping a negative-valued neuron positive, due to how elementwise addition ties neurons from multiple layers together.

It has two main functions:

  1. `prune(model, fraction)`, which prunes approximately `fraction` of the neurons by removing the corresponding parameters
  2. `pruneuntil(model, accept; depth=6)`, which simply uses `prune` to find the largest amount of pruning for which `accept(pruned_model)` returns true, in `depth` attempts.
Module with some extra code for experiments

```
(PruningExample) pkg> add Statistics MLDatasets NaiveNASflux Flux https://github.com/DrChainsaw/ONNXmutable.jl
```

```julia
module PruningExample

# Flux is needed explicitly for DataLoader, onehotbatch, onecold and params below
using ONNXmutable, NaiveNASflux, Flux, Statistics
import MLDatasets

export prune, pruneuntil

function pruning_metric(v, offs)
    val = neuron_value(v) # neuron_value defaults to magnitude of parameters along activation dimension
    ismissing(val) && return zeros(nout_org(v)) # Layers with no parameters return missing by default
    return val .- offs
end

function prune(g::CompGraph, fraction)
    @assert 0 < fraction < 1 "Fraction of neurons to prune must be between 0 and 1"
    gnew = copy(g)
    # We don't want to change the number of outputs from the model, so exclude all layers for which a change in number of neurons leads to a change in model output size
    pvs = prunable(gnew)

    # Find the cutoff value below which the `fraction` smallest-valued neurons fall
    allvals = mapreduce(neuron_value, vcat, pvs) |> skipmissing |> collect
    cutoff = partialsort(allvals, round(Int, fraction*length(allvals)))

    # Prune the model
    Δoutputs(OutSelectRelaxed() |> ApplyAfter, gnew, v -> v in pvs ? pruning_metric(v, cutoff) : fill(100, nout(v)))
    return gnew
end

prunable(g::CompGraph) = mapreduce(prunable, vcat, g.outputs)
function prunable(v, ok = false)
    vs = mapreduce(vcat, inputs(v); init=AbstractVertex[]) do vi
        prunable(vi, ok || isabsorb(v))
    end
    ok ? unique(vcat(v, vs)) : unique(vs)
end

isabsorb(v) = isabsorb(trait(v))
isabsorb(t::DecoratingTrait) = isabsorb(base(t))
isabsorb(t::SizeAbsorb) = true
isabsorb(t::MutationSizeTrait) = false 

function pruneuntil(g::CompGraph, accept; depth = 6)
    # Binary search how much we can prune and still meet the acceptance criterion
    step = 1
    fraction = 0.5
    gaccepted = g
    faccepted = 0.0
    while step < 2^depth
        @info "Prune $fraction of parameters"
        g′ = prune(g, fraction)
        step *= 2
        if accept(g′)
            faccepted = fraction
            gaccepted = g′
            fraction += fraction / step
        else
            fraction -= fraction / step
        end
    end
    return gaccepted
end


# Auxiliary stuff to run the experiment
export resnet, faccept, nparams
const resnetfile = Ref{String}("")
function resnet() 
   if !isfile(resnetfile[])
        # I couldn't find any SOTA ONNX models for CIFAR10 online.
        # This is my not very successful attempt at replicating these experiments: https://github.com/davidcpage/cifar10-fast/blob/master/experiments.ipynb
        # Test accuracy is around 92% iirc
        resnetfile[] = download("https://github.com/DrChainsaw/NaiveGAExperiments/raw/master/lamarckism/pretrained/resnet.onnx")
   end
   return CompGraph(resnetfile[])
end

function cifar10accuracy(model, batchsize=16; nbatches=cld(10000, batchsize))
    x,y = MLDatasets.CIFAR10.testdata()
    itr = Flux.Data.DataLoader((x, Flux.onehotbatch(y, sort(unique(y)))); batchsize);
    xm = mean(x) |> Float32
    xs = std(x; mean=xm) |> Float32
    mean(Iterators.take(itr, nbatches)) do (xb, yb)
        xb_std = @. (Float32(xb) - xm) / xs
        sum(Flux.onecold(model(xb_std)) .== Flux.onecold(yb))
    end / batchsize
end

function faccept(model)
    # I don't have a GPU on this computer so in this example I'll just use a small subset of the test set
    acc = cifar10accuracy(model, 32; nbatches=10)
    @info "\taccuracy: $acc"
    return acc > 0.9
end

nparams(m) = mapreduce(prod ∘ size, +, params(m))

end
```

Example which prunes an imported (quite poor tbh) CIFAR10 model as much as possible while staying above 90% accuracy on the test set:

```julia
julia> using PruningExample

julia> f = resnet();

julia> nparams(f)
6575370

julia> f′ = pruneuntil(f, faccept);
[ Info: Prune 0.5 of parameters
[ Info:         accuracy: 0.478125
[ Info: Prune 0.25 of parameters
[ Info:         accuracy: 0.9125
[ Info: Prune 0.3125 of parameters
[ Info:         accuracy: 0.890625
[ Info: Prune 0.2734375 of parameters
[ Info:         accuracy: 0.909375
[ Info: Prune 0.29052734375 of parameters
[ Info:         accuracy: 0.909375
[ Info: Prune 0.2996063232421875 of parameters
[ Info:         accuracy: 0.89375

julia> nparams(f′) # Close enough to 29% fewer parameters 
4926056
```

Note how this is one slippery slope on the dark path to NAS, as one immediately starts thinking things like "what if I retrain the model just a little after pruning" and then "maybe I should try to increase the size and see if things get better, or why not add some more layers and remove the ones which perform poorly and...". Disclaimer: I toy with NAS in my spare time as a hobby and I'm not wasting anyone's money (except my own electricity bill) on it.

@darsnack
Member Author

darsnack commented Jul 7, 2021

Currently being worked on in MaskedArrays.jl and FluxPrune.jl.

@DrChainsaw

Looks very cool!

Just to expose my ignorance on the subject: What is the end goal after masking the parameters?

For example, it is not clear to me if there are any benefits to having some fraction of parameters masked. Or should one convert them to sparse arrays once the amount of masking is above some threshold where it becomes beneficial? Does that give benefits on a GPU? Or does one just learn that a smaller model works, then build and retrain the smaller model? That doesn't seem doable with unstructured pruning, or does it?

@darsnack
Member Author

Yeah that part is currently missing, but once you have finished pruning, you have a "freeze" step that turns all the masked arrays into a compressed form. For unstructured that could be a sparse array. For structured that could be "re-building" the model after dropping channels.

Probably, I will make the "freeze" step call MaskedArrays.freeze which just turns each masked array into the corresponding Array with zeros. I might include some common helpers, but I will leave the decision of going to sparse formats, etc. to the user. The reason is that AFAIK there is no standard next step. How to most effectively take advantage of zeros is highly hardware dependent. So, it is up to the person pruning a model to decide how best to take advantage of the pruned result.
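A hypothetical sketch of that flow. Only `MaskedArrays.freeze` turning masked arrays into plain `Array`s with zeros is stated above; the stand-in pruning and the `fmap`-based sparsification are my own illustration of one user-side choice:

```julia
using Flux, SparseArrays

# Stand-in for a model whose small weights were already zeroed out;
# in the real flow this would be the result of MaskedArrays.freeze.
model = Dense(4, 3)
model.weight .*= abs.(model.weight) .> 0.5  # crude stand-in for pruning

# User-side decision: exploit the zeros, e.g. with a sparse matrix
# format for unstructured pruning of dense weights.
sparsify(W::AbstractMatrix) = sparse(W)
sparsify(x) = x
sparse_model = Flux.fmap(sparsify, model)
```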

@DrChainsaw

Makes sense to me, guess I wasn't too far off base then.

> For structured that could be "re-building" the model after dropping channels.

Just a word of warning to keep you from going insane: this is harder than it seems, as the change propagates to the next layer, and if that involves things like concatenations and/or elementwise operations, things get out of hand quickly. This is basically what I tried to do with NaiveNASlib after having some success with a simple version in another project, and it ended with me throwing in the towel and reaching for JuMP when trying to make use of it in a more or less unrestricted NAS setting.

@darsnack
Member Author

Yeah, the structured pruning literature special-cases the handling of skip connections for this reason. They don't provide a general algorithm for all kinds of operations, and I don't intend to come up with one as a first pass. For now, I am just going to implement what's already out there, and hopefully throw a useful error when a user does some pruning that would result in mismatched dimensions.
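A small illustration (mine, not from either package) of the mismatched-dimension failure mode under discussion, using Flux's `SkipConnection`:

```julia
using Flux

inner = Conv((3, 3), 8 => 8, pad=1)
block = SkipConnection(inner, +)
x = randn(Float32, 16, 16, 8, 1)
block(x)  # fine: both the conv branch and the identity carry 8 channels

# Naively dropping 2 output channels of the conv alone breaks the `+`:
pruned = SkipConnection(Conv((3, 3), 8 => 6, pad=1), +)
# pruned(x)  # DimensionMismatch: 6-channel branch vs 8-channel identity
```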

@DrChainsaw

I haven't seen the MIP approach in the literature, and I must say I am pretty happy with the result.

I guess the drawback is that there is still a large set of operations which can't use any of the formulations NaiveNASlib ships, and then one must write the constraints for each such operation (e.g. depthwise/grouped convolutions). I don't think this is much easier to handle with any other approach either (except maybe the linearization parts).

@darsnack
Member Author

> I haven't seen the MIP approach in the literature, and I must say I am pretty happy with the result.

Yes, I always thought this was a really cool showcase of the power of Julia!

@darsnack
Member Author

darsnack commented Aug 9, 2022

The initial sketch of the library is FluxPrune.jl. I am going to close this issue in favor of opening specific issues over on that repo. I also updated FluxML/Flux.jl#1431 to reflect this.

@darsnack darsnack closed this as completed Aug 9, 2022