Equivalent function to Flux.activations to obtain hidden layers output #126

Open
mrazomej opened this issue Feb 5, 2023 · 3 comments

@mrazomej

mrazomej commented Feb 5, 2023

I am trying to train a small autoencoder. The values of the latent-space variables are important for my application, so I am looking for an equivalent of Flux.activations, which lets one save the values of the intermediate hidden layers. Is there such a thing in SimpleChains.jl?
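
For concreteness, this is the Flux behavior I am hoping to reproduce (a minimal sketch; the layer sizes are only illustrative):

using Flux

# Tiny autoencoder; Flux.activations returns every layer's output, so the
# latent variables (the output of layer 2 here) are directly available.
m = Chain(Dense(10 => 4, relu), Dense(4 => 2), Dense(2 => 4, relu), Dense(4 => 10))
x = randn(Float32, 10)
Flux.activations(m, x)  # tuple of intermediate outputs, one per layer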

@mrazomej
Author

mrazomej commented Feb 5, 2023

I haven't figured out how to do this, but I wrote a very hacky function that transfers parameters from a SimpleChains.SimpleChain network to a Flux.Chain one. There are probably much better and cleaner ways to do this, but for now it works like this: one declares a Flux.Chain with the exact same architecture as the trained SimpleChains.SimpleChain network and copies the parameters from one to the other. Here's the function in case it is useful to anyone with similar issues:

@doc raw"""
    `simple_to_flux(param, fluxchain)`

Transfer the parameters from a trained `SimpleChains.jl` network to a `Flux.jl`
network with the same architecture for downstream manipulation.

NOTE: This function is agnostic to the activation functions in the
`SimpleChains.jl` network from which `param` was extracted. Therefore, for this
transfer to make sense, you must make sure that both networks have the same
architecture!

# Arguments
- `param::Vector{Float32}`: List of parameters obtained from a `SimpleChains.jl`
  network.
- `fluxchain::Flux.Chain`: Multi-layer perceptron defined in the `Flux.jl`
  framework.

# Returns
- `Flux.Chain`: Multi-layer perceptron of the same architecture as `fluxchain`
  but with modified parameters dictated by `param`.
"""
function simple_to_flux(param::Vector{Float32}, fluxchain::Flux.Chain)
    # Extract list of parameters from the Flux model. NOTE: `collect` transforms
    # the Flux.Params object into a plain list of arrays (W1, b1, W2, b2, ...)
    param_flux = collect(Flux.params(fluxchain))
    # Initialize object where to transfer parameters
    param_transfer = similar(param_flux)

    # Initialize parameter index counter to keep track of the already used
    # parameters
    idx = 1

    # Loop through list of parameters
    for (i, p) in enumerate(param_flux)
        # Extract the slice of `param` corresponding to this parameter array
        par = param[idx:(idx+length(p)-1)]
        # Save parameter values with the correct shape
        param_transfer[i] = reshape(par, size(p))
        # Update index for next iteration
        idx += length(par)
    end # for

    # Make parameter transfer a Flux.Params object
    param_transfer = Flux.Params(param_transfer)

    # Initialize list to save Flux.Dense layers that will later be converted
    # into a Flux.Chain
    layers_transfer = Vector{Flux.Dense}(undef, length(param_transfer) ÷ 2)

    # Loop through parameters, building the layers one by one
    for (i, p) in enumerate(1:2:length(param_transfer))
        # Generate Flux.Dense layer with the weights and biases from the SimpleChains
        # network, and the activation function from the Flux network
        layers_transfer[i] = Flux.Dense(
            param_transfer[p], param_transfer[p+1], fluxchain[i].σ
        )
    end # for

    # Return Flux.jl multi-layer perceptron
    return Flux.Chain(layers_transfer...)

end # function
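
For reference, this is how I use it (a rough sketch; the layer sizes are only illustrative, and it assumes the SimpleChains parameter vector can be converted to a plain Vector{Float32} and that the parameter layouts line up as described above):

using SimpleChains, Flux

# Two networks with matching architectures (sizes are illustrative)
sc_model = SimpleChain(static(4), TurboDense(tanh, 8), TurboDense(identity, 2))
flux_model = Flux.Chain(Flux.Dense(4 => 8, tanh), Flux.Dense(8 => 2))

# Flat parameter vector (train it with SimpleChains as usual before transferring)
p = SimpleChains.init_params(sc_model)

# Copy the SimpleChains parameters into the Flux model...
flux_copy = simple_to_flux(Vector{Float32}(p), flux_model)

# ...and use Flux.activations to read off the hidden-layer outputs
x = randn(Float32, 4)
Flux.activations(flux_copy, x)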

@chriselrod
Contributor

I really need to get around to adding docs describing the layer interface.

Then I'd just define a layer that forwards all of its inputs, but also calls push!(place_youre_storing, Array(outputfromprevlayer)).
You can then insert that layer wherever you'd like to save activations.
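
In sketch form (untested; RecordingLayer and its store field are just placeholder names, and the remaining parameter-free interface methods can be filled in exactly like the skeleton in the comment below):

# Hypothetical parameter-free layer that forwards its input unchanged but
# keeps a copy of it (i.e. the previous layer's output) in ordinary memory.
struct RecordingLayer
  store::Vector{Any}
end
RecordingLayer() = RecordingLayer(Any[])

function (l::RecordingLayer)(B::AbstractVecOrMat{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
  push!(l.store, Array(B))  # save a copy of the previous layer's output
  B, p, pu                  # forward everything untouched
end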

BTW, feel free to make a PR adding the above as a package extension for when Flux is loaded.
However:

  1. that requires Julia 1.9, and
  2. I need to actually get tests passing again on the SimpleChains main branch (after upgrading to StrideArraysCore 0.4), but I haven't found the time.

@CodeReclaimers
Contributor

CodeReclaimers commented Mar 3, 2023

Here's an attempt at a minimal parameter-free custom layer definition; please let me know if I'm doing anything sketchy here. I have put brief comments on the custom layer functions in case that is useful as a placeholder for a full layer interface description.

using SimpleChains

struct MyDoNothingLayer end

# Plain forward call: return the input unchanged along with the (unadvanced)
# parameter pointer and the workspace pointer.
function (mdnl::MyDoNothingLayer)(B::AbstractVecOrMat{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
  B, p, pu
end

Base.show(io::IO, mdnl::MyDoNothingLayer) = print(io, "Do-nothing layer: ", pointer_from_objref(Ref(mdnl)))

# Zero parameters; the output dimension equals the input dimension.
SimpleChains.numparam(::MyDoNothingLayer, inputdim) = static(0), inputdim

# Mark the layer as having no trainable parameters.
SimpleChains.parameter_free(::MyDoNothingLayer) = true

# Nothing to initialize; pass the parameter pointer and input dimension through.
SimpleChains.init_params!(::MyDoNothingLayer, p, inputdim, rng) = p, inputdim

# No extra forward-pass workspace is requested; the input dimension passes through.
SimpleChains.forward_layer_output_size(::Val{T}, a::MyDoNothingLayer, inputdim) where {T} = static(0), inputdim

# Forward pass during gradient evaluation: no parameters to consume, so the
# gradient pointer, the input, and the remaining pointers all pass through.
function SimpleChains.valgrad_layer!(pg::Ptr{T}, mdnl::MyDoNothingLayer, x, p::Ptr{T}, pu::Ptr{UInt8}) where {T}
  pg, x, p, pu
end

# Backward pass: the layer acts as the identity, so the incoming gradient C̄ is
# passed back unchanged.
function SimpleChains.pullback!(
  __::Ptr{T},
  mdnl::MyDoNothingLayer,
  C̄,
  B,
  p::Ptr{T},
  pu::Ptr{UInt8},
  pu2::Ptr{UInt8}
) where {T}
  C̄, pu2
end



# Below is an example network demo lifted from https://julialang.org/blog/2022/04/simple-chains/
# with a MyDoNothingLayer added.

function f(x)
  N = Base.isqrt(length(x))
  A = reshape(view(x, 1:N*N), (N,N))
  expA = exp(A)
  vec(expA)
end

T = Float32
D = 2 # 2x2 matrices
X = randn(T, D*D, 10_000); # random input matrices
Y = reduce(hcat, map(f, eachcol(X))); # `mapreduce` is not optimized for `hcat`, but `reduce` is

Xtest = randn(T, D*D, 10_000)
Ytest = reduce(hcat, map(f, eachcol(Xtest)))

mlpd = SimpleChain(
  static(4),
  TurboDense(tanh, 32),
  MyDoNothingLayer(),
  TurboDense(tanh, 16),
  TurboDense(identity, 4)
)

p = SimpleChains.init_params(mlpd)
G = SimpleChains.alloc_threaded_grad(mlpd)

mlpdloss = SimpleChains.add_loss(mlpd, SquaredLoss(Y))
mlpdtest = SimpleChains.add_loss(mlpd, SquaredLoss(Ytest))

report = let mtrain = mlpdloss, X = X, Xtest = Xtest, mtest = mlpdtest
  p -> begin
    let train = mtrain(X, p), test = mtest(Xtest, p)
      @info "Loss:" train test
    end
  end
end

report(p)
opt = SimpleChains.ADAM()
for _ in 1:3
  SimpleChains.train_unbatched!(G, p, mlpdloss, X, opt, 10_000)
  report(p)
end
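
As a sanity check (assuming a SimpleChain without an attached loss can be called as model(x, p), and noting that dropping the parameter-free layer leaves the parameter layout unchanged), the do-nothing layer shouldn't affect the network's output:

# Same architecture without the pass-through layer; p is still valid here
mlpd_plain = SimpleChain(
  static(4),
  TurboDense(tanh, 32),
  TurboDense(tanh, 16),
  TurboDense(identity, 4)
)

Array(mlpd(X[:, 1], p)) ≈ Array(mlpd_plain(X[:, 1], p))  # should be true

Swapping the plain forward call for one that also does push!(store, Array(B)), as suggested above, turns this skeleton into the activation recorder the original question asks for.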
