diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index ec8cbc83..bf0ef99a 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -21,8 +21,6 @@ using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult end end -@inline static_dual_eval(::Type{T}, f::F, x::StaticArray) where {T,F} = f(dualize(T, x)) - # To fix method ambiguity issues: function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N} return ForwardDiff._eigvals(A) @@ -54,12 +52,12 @@ end @inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) - return extract_gradient(T, static_dual_eval(T, f, x), x) + return extract_gradient(T, f(dualize(T, x)), x) end @inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) - return extract_gradient!(T, result, static_dual_eval(T, f, x)) + return extract_gradient!(T, result, f(dualize(T, x))) end # Jacobian @@ -84,7 +82,7 @@ end @inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) - return extract_jacobian(T, static_dual_eval(T, f, x), x) + return extract_jacobian(T, f(dualize(T, x)), x) end function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T @@ -94,7 +92,7 @@ end @inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) - ydual = static_dual_eval(T, f, x) + ydual = f(dualize(T, x)) result = extract_jacobian!(T, result, ydual, length(x)) result = extract_value!(T, result, ydual) return result @@ -102,7 +100,7 @@ end @inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) - ydual = static_dual_eval(T, f, x) + ydual = f(dualize(T, x)) result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x)) result = DiffResults.value!(Base.Fix1(value, T), result, ydual) return result