-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix gradient issues with kernelmatrix_diag and use ChainRulesCore (#208)
Co-authored-by: David Widmann <devmotion@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
ae78b73
commit aa2099e
Showing
26 changed files
with
366 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
## Forward Rules | ||
|
||
# Note that this is type piracy as the derivative should be NaN for x == y. | ||
function ChainRulesCore.frule( | ||
(_, Δx, Δy), d::Distances.Euclidean, x::AbstractVector, y::AbstractVector | ||
) | ||
Δ = x - y | ||
D = sqrt(sum(abs2, Δ)) | ||
if !iszero(D) | ||
Δ ./= D | ||
end | ||
return D, dot(Δ, Δx) - dot(Δ, Δy) | ||
end | ||
|
||
## Reverse Rules Delta | ||
|
||
function ChainRulesCore.rrule(dist::Delta, x::AbstractVector, y::AbstractVector) | ||
d = dist(x, y) | ||
function evaluate_pullback(::Any) | ||
return NO_FIELDS, Zero(), Zero() | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
function pairwise_pullback(::AbstractMatrix) | ||
return NO_FIELDS, NO_FIELDS, Zero(), Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
function pairwise_pullback(::AbstractMatrix) | ||
return NO_FIELDS, NO_FIELDS, Zero() | ||
end | ||
return P, pairwise_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix | ||
) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(::AbstractVector) | ||
return NO_FIELDS, NO_FIELDS, Zero(), Zero() | ||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules DotProduct | ||
|
||
function ChainRulesCore.rrule(dist::DotProduct, x::AbstractVector, y::AbstractVector) | ||
d = dist(x, y) | ||
function evaluate_pullback(Δ::Any) | ||
return NO_FIELDS, Δ .* y, Δ .* x | ||
end | ||
return d, evaluate_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), | ||
d::DotProduct, | ||
X::AbstractMatrix, | ||
Y::AbstractMatrix; | ||
dims=2, | ||
) | ||
P = Distances.pairwise(d, X, Y; dims=dims) | ||
function pairwise_pullback_cols(Δ::AbstractMatrix) | ||
if dims == 1 | ||
return NO_FIELDS, NO_FIELDS, Δ * Y, Δ' * X | ||
else | ||
return NO_FIELDS, NO_FIELDS, Y * Δ', X * Δ | ||
end | ||
end | ||
return P, pairwise_pullback_cols | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2 | ||
) | ||
P = Distances.pairwise(d, X; dims=dims) | ||
function pairwise_pullback_cols(Δ::AbstractMatrix) | ||
if dims == 1 | ||
return NO_FIELDS, NO_FIELDS, 2 * Δ * X | ||
else | ||
return NO_FIELDS, NO_FIELDS, 2 * X * Δ | ||
end | ||
end | ||
return P, pairwise_pullback_cols | ||
end | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix | ||
) | ||
C = Distances.colwise(d, X, Y) | ||
function colwise_pullback(Δ::AbstractVector) | ||
return NO_FIELDS, NO_FIELDS, Δ' .* Y, Δ' .* X | ||
end | ||
return C, colwise_pullback | ||
end | ||
|
||
## Reverse Rules Sinus | ||
|
||
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) | ||
d = x - y | ||
sind = sinpi.(d) | ||
abs2_sind_r = abs2.(sind) ./ s.r | ||
val = sum(abs2_sind_r) | ||
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) | ||
function evaluate_pullback(Δ::Any) | ||
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx | ||
end | ||
return val, evaluate_pullback | ||
end | ||
|
||
## Reverse Rulse SqMahalanobis | ||
|
||
function ChainRulesCore.rrule( | ||
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector | ||
) | ||
d = dist(a, b) | ||
function SqMahalanobis_pullback(Δ::Real) | ||
a_b = a - b | ||
∂qmat = InplaceableThunk( | ||
@thunk((a_b * a_b') * Δ), X̄ -> mul!(X̄, a_b, a_b', true, Δ) | ||
) | ||
∂a = InplaceableThunk( | ||
@thunk((2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, 2 * Δ) | ||
) | ||
∂b = InplaceableThunk( | ||
@thunk((-2 * Δ) * dist.qmat * a_b), X̄ -> mul!(X̄, dist.qmat, a_b, true, -2 * Δ) | ||
) | ||
return Composite{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b | ||
end | ||
return d, SqMahalanobis_pullback | ||
end | ||
|
||
## Reverse Rules for matrix wrappers | ||
|
||
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) | ||
ColVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) | ||
function ColVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) | ||
return error( | ||
"Pullback on AbstractVector{<:AbstractVector}.\n" * | ||
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * | ||
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`", | ||
) | ||
end | ||
return ColVecs(X), ColVecs_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix) | ||
RowVecs_pullback(Δ::Composite) = (NO_FIELDS, Δ.X) | ||
function RowVecs_pullback(::AbstractVector{<:AbstractVector{<:Real}}) | ||
return error( | ||
"Pullback on AbstractVector{<:AbstractVector}.\n" * | ||
"This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n" * | ||
"To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`", | ||
) | ||
end | ||
return RowVecs(X), RowVecs_pullback | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
aa2099e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
aa2099e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/32800
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: