diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 40ce8c058..c7da2729d 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -5,6 +5,9 @@ Abstract type defining a transformation of the input. """ abstract type Transform end +# We introduce our own _map for Transform so that we can work around +# https://github.com/FluxML/Zygote.jl/issues/646 and define our own pullback +# (see zygoterules.jl) Base.map(t::Transform, x::AbstractVector) = _map(t, x) _map(t::Transform, x::AbstractVector) = t.(x)