Skip to content

Commit

Permalink
Merge pull request #31 from PumasAI/checkinputs
Browse files Browse the repository at this point in the history
Check if chain inputs match
  • Loading branch information
chriselrod committed Feb 23, 2022
2 parents c08aea2 + ce213a4 commit f9398f2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/simple_chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function _input_dims(t::Tuple{L,Vararg}) where {L}
end
chain_input_dims(c::SimpleChain) = _input_dims(c.layers)


_verify_chain(::Tuple{}, _) = nothing
function _verify_chain(layers::Tuple{L,Vararg}, inputdim = _input_dims(layers)) where {L}
l = first(layers)
Expand All @@ -24,6 +25,7 @@ function _verify_chain(layers::Tuple{L,Vararg}, inputdim = _input_dims(layers))
_verify_chain(Base.tail(layers), d)
end


SimpleChain(l::Vararg) = (_verify_chain(l); SimpleChain(l, UInt8[]))
SimpleChain(l::Tuple) = (_verify_chain(l); SimpleChain(l, UInt8[]))
Base.similar(c::SimpleChain) = SimpleChain(c.layers, similar(c.memory))
Expand Down Expand Up @@ -70,7 +72,21 @@ parameter_free(x) = numparam(x) == 0
d2 > length(memory) && resize!(memory, d2)
d
end

matches(x::Integer, y::Integer) = x == y
matches(x::Tuple{Integer,Vararg}, y::Integer) = first(x) == y
matches(x::Integer, y::Tuple{Integer,Vararg}) = x == first(y)
matches(::Tuple{}, ::Tuple) = true
function matches(x::Tuple{X,Vararg}, y::Tuple{Y,Vararg}) where {X,Y}
matches(first(x), first(y)) && matches(Base.tail(x), Base.tail(y))
end
function verify_arg(c, arg)
if !matches(chain_input_dims(c), size(arg))
throw(ArgumentError("Input argument: !matches(chain_input_dims(c), size(arg))"))
end
end
function (c::SimpleChain)(arg, params)
verify_arg(c, arg)
@unpack layers, memory = c
resize_memory!(layers, memory, arg)
unsafe_chain(layers, params, memory, arg)
Expand All @@ -96,6 +112,8 @@ end
init_params!(::Tuple{}, p::Ptr) = nothing
init_params::SimpleChain, ::Type{T} = Float32) where {T} = init_params!(Λ, Vector{T}(undef, numparam(Λ)))



"""
Allowed destruction:
Expand All @@ -109,6 +127,7 @@ It is also allowed to destroy the previous layer's return `B` to produce `B̄` (
Thus, the pullback is not allowed to depend on `C`, as it may have been destroyed in producing `C̄`.
"""
function valgrad!(g, c::SimpleChain, arg, params)
verify_arg(c, arg)
@unpack layers, memory = c
resize_memory!(layers, memory, arg)
unsafe_valgrad!(g, layers, params, memory, arg)
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ SquaredLoss"""

p = SimpleChains.init_params(sc, T);
g = similar(p);
@test_throws ArgumentError sc(rand(T,23,2), p)
@test_throws ArgumentError sc(rand(T,23), p)
@test_throws MethodError sc(Array{T,0}(undef), p)
@test_throws ArgumentError valgrad!(g, sc, rand(T,23,2), p)
@test_throws ArgumentError valgrad!(g, sc, rand(T,23), p)
valgrad!(g, scflp, x, p)
if VERSION < v"1.8-DEV" # FIXME: remove check when Zygote stops segfaulting on 1.8-DEV
@test g == Zygote.gradient(p -> FrontLastPenalty(sc, L2Penalty(2.3), L1Penalty(0.45))(x, p), p)[1]
Expand Down

0 comments on commit f9398f2

Please sign in to comment.