Skip to content

Commit

Permalink
Reactivate commented out tests (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Mar 27, 2021
1 parent 3e28705 commit 488abef
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 37 deletions.
3 changes: 2 additions & 1 deletion src/util/zygote_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ function Zygote._pullback(
ctx::AContext, ::Type{Symmetric}, X::StridedMatrix{<:Real}, uplo=:U,
)
function Symmetric_pullback(Δ)
return nothing, _symmetric_back(Δ, uplo), nothing
ΔX = Δ === nothing ? nothing : _symmetric_back(Δ, uplo)
return nothing, ΔX, nothing
end
return Symmetric(X, uplo), Symmetric_pullback
end
Expand Down
1 change: 0 additions & 1 deletion test/models/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ println("lgssm:")
max_primal_allocs=10,
max_forward_allocs=35,
max_backward_allocs=50,
# check_allocs=false,
check_allocs=storage.val isa SArrayStorage,
)
end
Expand Down
69 changes: 34 additions & 35 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,20 @@ function test_interface(
x_val = rand(rng, x)
y = conditional_rand(rng, conditional, x_val)

# @testset "rand" begin
# @test length(y) == dim_out(conditional)
# args = (conditional, x_val)
# @code_warntype conditional_rand(y, args...)
# check_infers && @inferred conditional_rand(rng, args...)
# if check_adjoints
# adjoint_test(
# (f, x) -> conditional_rand(MersenneTwister(123456), f, x), args;
# check_infers=check_infers, kwargs...,
# )
# end
# if check_allocs
# check_adjoint_allocations(conditional_rand, (rng, args...); kwargs...)
# end
# end
@testset "rand" begin
@test length(y) == dim_out(conditional)
args = (conditional, x_val)
check_infers && @inferred conditional_rand(rng, args...)
if check_adjoints
adjoint_test(
(f, x) -> conditional_rand(MersenneTwister(123456), f, x), args;
check_infers=check_infers, kwargs...,
)
end
if check_allocs
check_adjoint_allocations(conditional_rand, (rng, args...); kwargs...)
end
end

@testset "predict" begin
@test predict(x, conditional) isa Gaussian
Expand All @@ -450,26 +449,26 @@ function test_interface(
@test cov(pred_marg) isa Diagonal
end

# @testset "posterior_and_lml" begin
# args = (x, conditional, y)
# @test posterior_and_lml(args...) isa Tuple{Gaussian, Real}
# check_infers && @inferred posterior_and_lml(args...)
# if check_adjoints
# (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...))
# ∂args = map(rand_tangent, args)
# adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args)
# adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args)
# adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args)
# adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args)
# end
# if check_allocs
# (Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...))
# check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...)
# check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...)
# check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...)
# check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...)
# end
# end
@testset "posterior_and_lml" begin
args = (x, conditional, y)
@test posterior_and_lml(args...) isa Tuple{Gaussian, Real}
check_infers && @inferred posterior_and_lml(args...)
if check_adjoints
(Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...))
∂args = map(rand_tangent, args)
adjoint_test(posterior_and_lml, (Δx, Δlml), args, ∂args)
adjoint_test(posterior_and_lml, (Δx, nothing), args, ∂args)
adjoint_test(posterior_and_lml, (nothing, Δlml), args, ∂args)
adjoint_test(posterior_and_lml, (nothing, nothing), args, ∂args)
end
if check_allocs
(Δx, Δlml) = rand_zygote_tangent(posterior_and_lml(args...))
check_adjoint_allocations(posterior_and_lml, (Δx, Δlml), args; kwargs...)
check_adjoint_allocations(posterior_and_lml, (nothing, Δlml), args; kwargs...)
check_adjoint_allocations(posterior_and_lml, (Δx, nothing), args; kwargs...)
check_adjoint_allocations(posterior_and_lml, (nothing, nothing), args; kwargs...)
end
end
end

"""
Expand Down

2 comments on commit 488abef

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32949

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 488abef0e2e1e46955125b2be247c05a215672c9
git push origin v0.5.0

Please sign in to comment.