Skip to content

Commit

Permalink
Merge pull request #25 from PumasAI/showmethods
Browse files Browse the repository at this point in the history
Show Methods
  • Loading branch information
chriselrod committed Feb 22, 2022
2 parents d11d5d6 + bb112e6 commit 503d8f6
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 134 deletions.
6 changes: 4 additions & 2 deletions src/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ init_params!(::Activation, p) = p

output_size(::Val{T}, a::Activation, s) where {T} = align(prod(s)*(2sizeof(T))), s

# Display an `Activation` layer by naming the function it wraps.
function Base.show(io::IO, a::Activation)
  print(io, "Activation layer applying: ", a.f)
end

function (a::Activation)(x::AbstractArray{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
f = a.f
C = PtrArray(reinterpret(Ptr{T}, pu), size(x))
Expand All @@ -32,8 +34,8 @@ function valgrad_layer!(pg::Ptr{T}, a::Activation, x, p::Ptr{T}, pu::Ptr{UInt8})
end
pg, C, p, pu
end
@inline pullback_param!(pg::Ptr, ::Activation, C̄, B, p::Ptr, pu::Ptr{UInt8}) = nothing
function pullback!(pg::Ptr{T}, a::Activation, C̄, B, p::Ptr{T}, pu::Ptr{UInt8}, pu2::Ptr{UInt8}) where {T}
@inline pullback_param!(__::Ptr, ::Activation, C̄, B, p::Ptr, pu::Ptr{UInt8}) = nothing
function pullback!(__::Ptr{T}, a::Activation, C̄, B, p::Ptr{T}, pu::Ptr{UInt8}, pu2::Ptr{UInt8}) where {T}
∂C = PtrArray(reinterpret(Ptr{T}, pu), size(C̄))
@turbo for i eachindex(∂C)
C̄[i] *= ∂C[i]
Expand Down
10 changes: 10 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ end
# Convenience constructors filling in the type parameters from the arguments;
# when the bias flag `B` is omitted it defaults to `true` (bias included).
TurboDense{B}(f::F, t::Tuple{I1,I2}) where {F,I1,I2,B} = TurboDense{B,Tuple{I1,I2},F}(f, t)
TurboDense(f::F, t::Tuple{I1,I2}) where {F,I1,I2} = TurboDense{true,Tuple{I1,I2},F}(f, t)

# Display a `TurboDense` layer: its dimensions, whether it carries a bias
# (type parameter `B`), and — on a following line — its activation, unless
# the activation is `identity`.
function Base.show(io::IO, td::TurboDense{B}) where {B}
  w = B ? "with" : "without"
  print(io, "TurboDense $(td.dims) $w bias.")
  td.f === identity && return nothing
  println(io)
  return show(io, Activation(td.f))
end


input_dims(d::TurboDense) = getfield(d.dims, 1)
function numparam(d::TurboDense{false})
id, od = getfield(d,:dims)
Expand Down
6 changes: 6 additions & 0 deletions src/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Dropout(x::T, rng = local_rng()) where {T <: Union{Float32,Float64}} = Dropout(B
# RNG used to draw the dropout mask: with no stored RNG (`rng::Nothing`),
# fall back to the thread-local default.
getrng(d::Dropout{Nothing}) = local_rng()
# Otherwise, use the RNG carried by the layer itself.
getrng(d::Dropout{<:VectorizedRNG.AbstractRNG}) = getfield(d, :rng)

# Display a `Dropout` layer with its drop probability.
# `d.p` stores the probability as a `UInt32` fraction of `typemax(UInt32)`
# (cf. `gradval` below), so dividing by 0xffffffff recovers a value in [0, 1].
# Fix: this previously printed `d.rng` — the layer's RNG (or `nothing`) —
# instead of the stored probability `d.p`.
function Base.show(io::IO, d::Dropout)
  print(io, "Dropout(p=$(Float64(d.p)/0xffffffff))")
end


# Gradient scale factor compensating for dropped units:
# typemax(UInt32) / (typemax(UInt32) - p), where `d.p` stores the drop
# probability as a UInt32 fraction of typemax(UInt32).
gradval(::Val{T}, d::Dropout) where {T} = T(0xffffffff) / (T(0xffffffff) - d.p)
# Dropout has no trainable parameters.
numparam(::Dropout) = 0
parameter_free(::Dropout) = true
Expand All @@ -31,6 +36,7 @@ function (d::Dropout)(B::AbstractVecOrMat{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
B, p, pu # inference
end


getpcmp(::StaticInt{W}, ::StaticInt{W}, x) where {W} = x
getpcmp(::StaticInt{W}, ::StaticInt{WU}, x) where {W,WU} = getpcmp(StaticInt(W), StaticInt(WU), x, Static.gt(StaticInt(W), StaticInt(WU)))
function getpcmp(::StaticInt{W}, ::StaticInt{WU}, x::UInt32, ::True) where {W,WU}
Expand Down
6 changes: 5 additions & 1 deletion src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ init_params!(::AbstractLoss, p) = p

# Append a `SquaredLoss` targeting `y` to the chain.
squared_loss(chn::SimpleChain, y) = add_loss(chn, SquaredLoss(y))

# Display a `SquaredLoss` by its type name.
function Base.show(io::IO, ::SquaredLoss)
  print(io, "SquaredLoss")
end

function chain_valgrad!(pg, arg::AbstractArray{T}, layers::Tuple{SquaredLoss}, p::Ptr, pu::Ptr{UInt8}) where {T}
y = getfield(getfield(layers, 1), :y)
g = PtrArray(stridedpointer(Base.unsafe_convert(Ptr{T}, pu), bytestrideindex(arg)), size(arg), StrideArraysCore.val_dense_dims(arg))
Expand Down Expand Up @@ -62,8 +64,10 @@ target(sl::AbsoluteLoss) = getfield(sl, :y)

# Append an `AbsoluteLoss` targeting `y` to the chain.
absolute_loss(chn::SimpleChain, y) = add_loss(chn, AbsoluteLoss(y))

# Display an `AbsoluteLoss` by its type name.
function Base.show(io::IO, ::AbsoluteLoss)
  print(io, "AbsoluteLoss")
end


function chain_valgrad!(pg, arg::AbstractArray{T}, layers::Tuple{AbsoluteLoss}, p::Ptr, pu::Ptr{UInt8}) where {T}
function chain_valgrad!(__, arg::AbstractArray{T}, layers::Tuple{AbsoluteLoss}, _::Ptr, pu::Ptr{UInt8}) where {T}
y = getfield(getfield(layers, 1), :y)
g = PtrArray(stridedpointer(Base.unsafe_convert(Ptr{T}, pu), bytestrideindex(arg)), size(arg), StrideArraysCore.val_dense_dims(arg))
s = zero(eltype(g))
Expand Down
23 changes: 23 additions & 0 deletions src/penalty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ end
# Base.FastMath.add_fast(unsafe_valgrad!(g, getchain(Λ), arg, params), apply_penalty!(g, Λ, params))
# end

# Print the chain a penalty is attached to, if any.
# A `nothing` chain means the penalty is free-standing: print nothing.
_penalty_applied_to_sc(_::IO, ::Nothing) = nothing
function _penalty_applied_to_sc(io::IO, sc::SimpleChain)
  println(io, " applied to:")
  return show(io, sc)
end
# Display an `AbstractPenalty`: its bare type name, its coefficient (when it
# has one), and the chain it is attached to (when attached).
function Base.show(io::IO, p::AbstractPenalty)
  # `string(Base.typename(T))` looks like "typename(L1Penalty)"; slice off
  # the leading "typename(" (9 characters) and the trailing ")".
  print(io, string(Base.typename(typeof(p)))[begin+9:end-1])
  λ = getλ(p)
  # Fix: interpolate the coefficient's value — it was missing from the
  # output, which printed a literal " (λ=)".
  λ === nothing || print(io, " (λ=$(λ))")
  _penalty_applied_to_sc(io, getchain(p))
end


# Allow `@unpack layers, memory = penalty` to reach through to the wrapped chain.
UnPack.unpack(c::AbstractPenalty{<:SimpleChain}, ::Val{:layers}) = getfield(getchain(c), :layers)
UnPack.unpack(c::AbstractPenalty{<:SimpleChain}, ::Val{:memory}) = getfield(getchain(c), :memory)

Expand All @@ -43,6 +56,8 @@ apply_penalty!(_, ::NoPenalty, __) = Static.Zero()
# Calling a `NoPenalty` on a chain attaches it to that chain.
(::NoPenalty)(chn::SimpleChain) = NoPenalty(chn)
# A plain `SimpleChain` has no penalty; report it wrapped in `NoPenalty`.
getpenalty(sc::SimpleChain) = NoPenalty(sc)
# An `AbstractPenalty` is already a penalty; return it unchanged.
# Fix: restores the argument `(Λ` dropped from this line.
getpenalty(Λ::AbstractPenalty) = Λ
# A `NoPenalty` carries no regularization coefficient, so report `nothing`.
function getλ(::NoPenalty)
  return nothing
end


struct L1Penalty{NN,T} <: AbstractPenalty{NN}
chn::NN
Expand All @@ -52,6 +67,7 @@ getchain(p::L1Penalty) = getfield(p,:chn)
# Construct an `L1Penalty` from a bare coefficient, with no chain attached.
# Fix: restores the parameter `(λ` dropped from this line.
L1Penalty(λ::Number) = L1Penalty(nothing, λ)
# Rewrap another penalty's chain with an L1 coefficient.
L1Penalty(p::AbstractPenalty, λ) = L1Penalty(getchain(p), λ)
# Calling the penalty on a chain attaches it to that chain.
(p::L1Penalty)(chn::SimpleChain) = L1Penalty(chn, p.λ)
# The L1 regularization coefficient.
# Fix: restores the field name `:λ` dropped from this line.
getλ(p::L1Penalty) = getfield(p, :λ)

@inline function apply_penalty::L1Penalty{NN,T2}, p::AbstractVector{T3}) where {T2,T3,NN}
l = zero(T3)
Expand Down Expand Up @@ -81,6 +97,7 @@ end
# The chain an `L2Penalty` wraps (`nothing` when free-standing).
getchain(p::L2Penalty) = getfield(p,:chn)
# Construct an `L2Penalty` from a bare coefficient, with no chain attached.
L2Penalty(λ) = L2Penalty(nothing, λ)
# Rewrap another penalty's chain with an L2 coefficient.
L2Penalty(p::AbstractPenalty, λ) = L2Penalty(getchain(p), λ)
# The L2 regularization coefficient.
# Fix: restores the field name `:λ` dropped from this line.
getλ(p::L2Penalty) = getfield(p, :λ)
(p::L2Penalty)(chn::SimpleChain) = L2Penalty(chn, p.λ)

@inline function apply_penalty::L2Penalty{NN,T2}, p::AbstractVector{T3}) where {T2,T3,NN}
Expand Down Expand Up @@ -120,6 +137,12 @@ FrontLastPenalty(λ₁, λ₂) = FrontLastPenalty(nothing, λ₁, λ₂)
# Rewrap another penalty's chain with front/last coefficients.
FrontLastPenalty(p::AbstractPenalty, λ₁, λ₂) = FrontLastPenalty(getchain(p), λ₁, λ₂)
# Calling the penalty on a chain attaches it to that chain.
(p::FrontLastPenalty)(chn::SimpleChain) = FrontLastPenalty(chn, p.front, p.last)

# Display a `FrontLastPenalty`: the penalty applied to every layer except the
# last, the penalty applied to the last layer, and the chain (when attached).
function Base.show(io::IO, p::FrontLastPenalty)
  print(io, "Penalty on all but last layer: ")
  show(io, p.front)
  print(io, "\nPenalty on last layer: ")
  show(io, p.last)
  return _penalty_applied_to_sc(io, getchain(p))
end


function split_front_last(c::SimpleChain)
l = c.layers
Expand Down
12 changes: 12 additions & 0 deletions src/simple_chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ SimpleChain(l::Vararg) = SimpleChain(l, UInt8[])
# Construct a chain from a tuple of layers with a fresh, empty byte buffer.
SimpleChain(l::Tuple) = SimpleChain(l, UInt8[])
# A similar chain shares the layers but gets its own memory buffer.
Base.similar(c::SimpleChain) = SimpleChain(c.layers, similar(c.memory))

# Recursively print each element of a layer-tuple, one per line.
# Base case: an empty tuple prints nothing.
_show(::IO, ::Tuple{}) = nothing
function _show(io::IO, layers::Tuple{L,Vararg}) where {L}
  println(io)
  show(io, first(layers))
  return _show(io, Base.tail(layers))
end
# Display a `SimpleChain` as a header line followed by one line per layer.
function Base.show(io::IO, sc::SimpleChain)
  print(io, "SimpleChain with the following layers:")
  return _show(io, getfield(sc, :layers))
end


"""
Base.front(c::SimpleChain)
Expand Down
Loading

0 comments on commit 503d8f6

Please sign in to comment.