diff --git a/Project.toml b/Project.toml index cfb96db..9ae0213 100644 --- a/Project.toml +++ b/Project.toml @@ -2,6 +2,9 @@ name = "ProximalOperators" uuid = "a725b495-10eb-56fe-b38b-717eba820537" version = "0.16.1" +[workspace] +projects = ["docs", "test", "benchmark"] + [deps] IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -15,7 +18,7 @@ TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" IterativeSolvers = "0.8 - 0.9" LinearAlgebra = "1.4" OSQP = "0.3 - 0.8" -ProximalCore = "0.1" +ProximalCore = "0.2.0" SparseArrays = "1.4" SuiteSparse = "1.4" TSVD = "0.3 - 0.4" diff --git a/src/ProximalOperators.jl b/src/ProximalOperators.jl index 7b9575a..bd3fc6e 100644 --- a/src/ProximalOperators.jl +++ b/src/ProximalOperators.jl @@ -4,15 +4,27 @@ module ProximalOperators using LinearAlgebra import ProximalCore: prox, prox!, gradient, gradient! -import ProximalCore: is_convex, is_generalized_quadratic +import ProximalCore: + is_convex, + is_strongly_convex, + is_generalized_quadratic, + is_proximable, + is_separable, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, + is_smooth, + is_locally_smooth, + is_support -const RealOrComplex{R <: Real} = Union{R, Complex{R}} -const HermOrSym{T, S} = Union{Hermitian{T, S}, Symmetric{T, S}} -const RealBasedArray{R} = AbstractArray{C, N} where {C <: RealOrComplex{R}, N} -const TupleOfArrays{R} = Tuple{RealBasedArray{R}, Vararg{RealBasedArray{R}}} -const ArrayOrTuple{R} = Union{RealBasedArray{R}, TupleOfArrays{R}} -const TransposeOrAdjoint{M} = Union{Transpose{C,M} where C, Adjoint{C,M} where C} -const Maybe{T} = Union{T, Nothing} +const RealOrComplex{R<:Real} = Union{R,Complex{R}} +const HermOrSym{T,S} = Union{Hermitian{T,S},Symmetric{T,S}} +const RealBasedArray{R} = AbstractArray{C,N} where {C<:RealOrComplex{R},N} +const TupleOfArrays{R} = Tuple{RealBasedArray{R},Vararg{RealBasedArray{R}}} +const ArrayOrTuple{R} = Union{RealBasedArray{R},TupleOfArrays{R}} +const TransposeOrAdjoint{M} = Union{Transpose{C,M} where C,Adjoint{C,M} where C} +const Maybe{T} = Union{T,Nothing} export prox, prox!, gradient, gradient! @@ -23,7 +35,6 @@ include("utilities/linops.jl") include("utilities/symmetricpacked.jl") include("utilities/uniformarrays.jl") include("utilities/normdiff.jl") -include("utilities/traits.jl") # Basic functions diff --git a/src/calculus/conjugate.jl b/src/calculus/conjugate.jl index cfb1138..23313fa 100644 --- a/src/calculus/conjugate.jl +++ b/src/calculus/conjugate.jl @@ -20,14 +20,14 @@ struct Conjugate{T} end end -is_prox_accurate(::Type{Conjugate{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{Conjugate{T}}) where T = is_proximable(T) is_convex(::Type{Conjugate{T}}) where T = true -is_cone(::Type{Conjugate{T}}) where T = is_cone(T) && is_convex(T) +is_cone_indicator(::Type{Conjugate{T}}) where T = is_cone_indicator(T) && is_convex(T) is_smooth(::Type{Conjugate{T}}) where T = is_strongly_convex(T) is_strongly_convex(::Type{Conjugate{T}}) where T = is_smooth(T) is_generalized_quadratic(::Type{Conjugate{T}}) where T = is_generalized_quadratic(T) -is_set(::Type{Conjugate{T}}) where T = is_convex(T) && is_support(T) -is_positively_homogeneous(::Type{Conjugate{T}}) where T = is_convex(T) && is_set(T) +is_set_indicator(::Type{Conjugate{T}}) where T = is_convex(T) && is_support(T) +is_positively_homogeneous(::Type{Conjugate{T}}) where T = is_convex(T) && is_set_indicator(T) Conjugate(f::T) where T = Conjugate{T}(f) @@ -37,7 +37,7 @@ Conjugate(f::T) where T = Conjugate{T}(f) function prox!(y, g::Conjugate, x, gamma) # Moreau identity v = prox!(y, g.f, x/gamma, 1/gamma) - if is_set(g) + if is_set_indicator(g) v = real(eltype(x))(0) else v = real(dot(x, y)) - gamma * real(dot(y, y)) - v @@ -50,7 +50,7 @@ end function prox_naive(g::Conjugate, x, gamma) y, v = prox_naive(g.f, x/gamma, 1/gamma) - return x - gamma * y, if is_set(g) real(eltype(x))(0) else real(dot(x, y)) - gamma * real(dot(y, y)) - v end + return x - gamma * y, if is_set_indicator(g) real(eltype(x))(0) else real(dot(x, y)) - gamma * real(dot(y, y)) - v end end # TODO: hard-code conjugation rules? E.g. precompose/epicompose diff --git a/src/calculus/distL2.jl b/src/calculus/distL2.jl index c80fd79..552e131 100644 --- a/src/calculus/distL2.jl +++ b/src/calculus/distL2.jl @@ -14,7 +14,7 @@ struct DistL2{R, T} ind::T lambda::R function DistL2{R, T}(ind::T, lambda::R) where {R, T} - if !is_set(ind) + if !is_set_indicator(ind) error("`ind` must be a convex set") end if lambda <= 0 @@ -25,7 +25,7 @@ struct DistL2{R, T} end end -is_prox_accurate(::Type{DistL2{R, T}}) where {R, T} = is_prox_accurate(T) +is_proximable(::Type{DistL2{R, T}}) where {R, T} = is_proximable(T) is_convex(::Type{DistL2{R, T}}) where {R, T} = is_convex(T) DistL2(ind::T, lambda::R=1) where {R, T} = DistL2{R, T}(ind, lambda) diff --git a/src/calculus/pointwiseMinimum.jl b/src/calculus/pointwiseMinimum.jl index c00dd78..71ad866 100644 --- a/src/calculus/pointwiseMinimum.jl +++ b/src/calculus/pointwiseMinimum.jl @@ -17,8 +17,8 @@ PointwiseMinimum(fs...) = PointwiseMinimum{typeof(fs)}(fs) component_types(::Type{PointwiseMinimum{T}}) where T = fieldtypes(T) -@generated is_set(::Type{T}) where T <: PointwiseMinimum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: PointwiseMinimum = return all(is_cone, component_types(T)) ? :(true) : :(false) +@generated is_set_indicator(::Type{T}) where T <: PointwiseMinimum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) +@generated is_cone_indicator(::Type{T}) where T <: PointwiseMinimum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) function (g::PointwiseMinimum{T})(x) where T return minimum(f(x) for f in g.fs) diff --git a/src/calculus/postcompose.jl b/src/calculus/postcompose.jl index ec87066..30cc889 100644 --- a/src/calculus/postcompose.jl +++ b/src/calculus/postcompose.jl @@ -23,14 +23,15 @@ struct Postcompose{T, R, S} end end -is_prox_accurate(::Type{<:Postcompose{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Postcompose{T}}) where T = is_proximable(T) is_separable(::Type{<:Postcompose{T}}) where T = is_separable(T) is_convex(::Type{<:Postcompose{T}}) where T = is_convex(T) -is_set(::Type{<:Postcompose{T}}) where T = is_set(T) -is_singleton(::Type{<:Postcompose{T}}) where T = is_singleton(T) -is_cone(::Type{<:Postcompose{T}}) where T = is_cone(T) -is_affine(::Type{<:Postcompose{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Postcompose{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Postcompose{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Postcompose{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Postcompose{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Postcompose{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Postcompose{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Postcompose{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Postcompose{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/precompose.jl b/src/calculus/precompose.jl index 6a3a277..a104ceb 100644 --- a/src/calculus/precompose.jl +++ b/src/calculus/precompose.jl @@ -37,13 +37,14 @@ struct Precompose{T, M, U, V} end end -is_prox_accurate(::Type{<:Precompose{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Precompose{T}}) where T = is_proximable(T) is_convex(::Type{<:Precompose{T}}) where T = is_convex(T) -is_set(::Type{<:Precompose{T}}) where T = is_set(T) -is_singleton(::Type{<:Precompose{T}}) where T = is_singleton(T) -is_cone(::Type{<:Precompose{T}}) where T = is_cone(T) -is_affine(::Type{<:Precompose{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Precompose{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Precompose{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Precompose{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Precompose{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Precompose{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Precompose{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Precompose{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Precompose{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/precomposeDiagonal.jl b/src/calculus/precomposeDiagonal.jl index f33d2cf..0787458 100644 --- a/src/calculus/precomposeDiagonal.jl +++ b/src/calculus/precomposeDiagonal.jl @@ -32,13 +32,14 @@ struct PrecomposeDiagonal{T, R, S} end is_separable(::Type{<:PrecomposeDiagonal{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:PrecomposeDiagonal{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:PrecomposeDiagonal{T}}) where T = is_proximable(T) is_convex(::Type{<:PrecomposeDiagonal{T}}) where T = is_convex(T) -is_set(::Type{<:PrecomposeDiagonal{T}}) where T = is_set(T) -is_singleton(::Type{<:PrecomposeDiagonal{T}}) where T = is_singleton(T) -is_cone(::Type{<:PrecomposeDiagonal{T}}) where T = is_cone(T) -is_affine(::Type{<:PrecomposeDiagonal{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:PrecomposeDiagonal{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:PrecomposeDiagonal{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:PrecomposeDiagonal{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:PrecomposeDiagonal{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:PrecomposeDiagonal{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/regularize.jl b/src/calculus/regularize.jl index 826a984..3d266db 100644 --- a/src/calculus/regularize.jl +++ b/src/calculus/regularize.jl @@ -25,9 +25,10 @@ struct Regularize{T, S, A} end is_separable(::Type{<:Regularize{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Regularize{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Regularize{T}}) where T = is_proximable(T) is_convex(::Type{<:Regularize{T}}) where T = is_convex(T) is_smooth(::Type{<:Regularize{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Regularize{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Regularize{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Regularize}) = true diff --git a/src/calculus/separableSum.jl b/src/calculus/separableSum.jl index 5142cf4..e8999e2 100644 --- a/src/calculus/separableSum.jl +++ b/src/calculus/separableSum.jl @@ -29,15 +29,16 @@ SeparableSum(fs::Vararg) = SeparableSum((fs...,)) component_types(::Type{SeparableSum{T}}) where T = fieldtypes(T) -@generated is_prox_accurate(::Type{T}) where T <: SeparableSum = return all(is_prox_accurate, component_types(T)) ? :(true) : :(false) -@generated is_convex(::Type{T}) where T <: SeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: SeparableSum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: SeparableSum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: SeparableSum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: SeparableSum = return all(is_affine, component_types(T)) ? :(true) : :(false) -@generated is_smooth(::Type{T}) where T <: SeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) -@generated is_generalized_quadratic(::Type{T}) where T <: SeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) -@generated is_strongly_convex(::Type{T}) where T <: SeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) +@generated is_proximable(::Type{T}) where T <: SeparableSum = return all(is_proximable, component_types(T)) ? true : false +@generated is_convex(::Type{T}) where T <: SeparableSum = return all(is_convex, component_types(T)) ? true : false +@generated is_set_indicator(::Type{T}) where T <: SeparableSum = return all(is_set_indicator, component_types(T)) ? true : false +@generated is_singleton_indicator(::Type{T}) where T <: SeparableSum = return all(is_singleton_indicator, component_types(T)) ? true : false +@generated is_cone_indicator(::Type{T}) where T <: SeparableSum = return all(is_cone_indicator, component_types(T)) ? true : false +@generated is_affine_indicator(::Type{T}) where T <: SeparableSum = return all(is_affine_indicator, component_types(T)) ? true : false +@generated is_smooth(::Type{T}) where T <: SeparableSum = return all(is_smooth, component_types(T)) ? true : false +@generated is_locally_smooth(::Type{T}) where T <: SeparableSum = return all(is_locally_smooth, component_types(T)) ? true : false +@generated is_generalized_quadratic(::Type{T}) where T <: SeparableSum = return all(is_generalized_quadratic, component_types(T)) ? true : false +@generated is_strongly_convex(::Type{T}) where T <: SeparableSum = return all(is_strongly_convex, component_types(T)) ? true : false (g::SeparableSum)(xs::Tuple) = sum(f(x) for (f, x) in zip(g.fs, xs)) diff --git a/src/calculus/slicedSeparableSum.jl b/src/calculus/slicedSeparableSum.jl index 7a40fab..aec56af 100644 --- a/src/calculus/slicedSeparableSum.jl +++ b/src/calculus/slicedSeparableSum.jl @@ -79,13 +79,14 @@ end component_types(::Type{SlicedSeparableSum{S, T, N}}) where {S, T, N} = Tuple(A.parameters[1] for A in fieldtypes(S)) -@generated is_prox_accurate(::Type{T}) where T <: SlicedSeparableSum = return all(is_prox_accurate, component_types(T)) ? :(true) : :(false) +@generated is_proximable(::Type{T}) where T <: SlicedSeparableSum = return all(is_proximable, component_types(T)) ? :(true) : :(false) @generated is_convex(::Type{T}) where T <: SlicedSeparableSum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: SlicedSeparableSum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: SlicedSeparableSum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: SlicedSeparableSum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: SlicedSeparableSum = return all(is_affine, component_types(T)) ? :(true) : :(false) +@generated is_set_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_set_indicator, component_types(T)) ? :(true) : :(false) +@generated is_singleton_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_singleton_indicator, component_types(T)) ? :(true) : :(false) +@generated is_cone_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_cone_indicator, component_types(T)) ? :(true) : :(false) +@generated is_affine_indicator(::Type{T}) where T <: SlicedSeparableSum = return all(is_affine_indicator, component_types(T)) ? :(true) : :(false) @generated is_smooth(::Type{T}) where T <: SlicedSeparableSum = return all(is_smooth, component_types(T)) ? :(true) : :(false) +@generated is_locally_smooth(::Type{T}) where T <: SlicedSeparableSum = return all(is_locally_smooth, component_types(T)) ? :(true) : :(false) @generated is_generalized_quadratic(::Type{T}) where T <: SlicedSeparableSum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) @generated is_strongly_convex(::Type{T}) where T <: SlicedSeparableSum = return all(is_strongly_convex, component_types(T)) ? :(true) : :(false) diff --git a/src/calculus/sum.jl b/src/calculus/sum.jl index 534725f..155f997 100644 --- a/src/calculus/sum.jl +++ b/src/calculus/sum.jl @@ -17,16 +17,17 @@ Sum(fs::Vararg) = Sum((fs...,)) component_types(::Type{Sum{T}}) where T = fieldtypes(T) -# note: is_prox_accurate false because prox in general doesn't exist? -is_prox_accurate(::Type{<:Sum}) = false -@generated is_convex(::Type{T}) where T <: Sum = return all(is_convex, component_types(T)) ? :(true) : :(false) -@generated is_set(::Type{T}) where T <: Sum = return all(is_set, component_types(T)) ? :(true) : :(false) -@generated is_singleton(::Type{T}) where T <: Sum = return all(is_singleton, component_types(T)) ? :(true) : :(false) -@generated is_cone(::Type{T}) where T <: Sum = return all(is_cone, component_types(T)) ? :(true) : :(false) -@generated is_affine(::Type{T}) where T <: Sum = return all(is_affine, component_types(T)) ? :(true) : :(false) -@generated is_smooth(::Type{T}) where T <: Sum = return all(is_smooth, component_types(T)) ? :(true) : :(false) -@generated is_generalized_quadratic(::Type{T}) where T <: Sum = return all(is_generalized_quadratic, component_types(T)) ? :(true) : :(false) -@generated is_strongly_convex(::Type{T}) where T <: Sum = return (all(is_convex, component_types(T)) && any(is_strongly_convex, component_types(T))) ? :(true) : :(false) +# note: is_proximable false because prox in general doesn't exist? +is_proximable(::Type{<:Sum}) = false +@generated is_convex(::Type{T}) where T <: Sum = return all(is_convex, component_types(T)) ? true : false +@generated is_set_indicator(::Type{T}) where T <: Sum = return all(is_set_indicator, component_types(T)) ? true : false +@generated is_singleton_indicator(::Type{T}) where T <: Sum = return all(is_singleton_indicator, component_types(T)) ? true : false +@generated is_cone_indicator(::Type{T}) where T <: Sum = return all(is_cone_indicator, component_types(T)) ? true : false +@generated is_affine_indicator(::Type{T}) where T <: Sum = return all(is_affine_indicator, component_types(T)) ? true : false +@generated is_smooth(::Type{T}) where T <: Sum = return all(is_smooth, component_types(T)) ? true : false +@generated is_locally_smooth(::Type{T}) where T <: Sum = return all(is_locally_smooth, component_types(T)) ? true : false +@generated is_generalized_quadratic(::Type{T}) where T <: Sum = return all(is_generalized_quadratic, component_types(T)) ? true : false +@generated is_strongly_convex(::Type{T}) where T <: Sum = return (all(is_convex, component_types(T)) && any(is_strongly_convex, component_types(T))) ? true : false function (sumobj::Sum)(x) sum = real(eltype(x))(0) diff --git a/src/calculus/tilt.jl b/src/calculus/tilt.jl index 5d7f40b..c5b7d76 100644 --- a/src/calculus/tilt.jl +++ b/src/calculus/tilt.jl @@ -17,10 +17,11 @@ struct Tilt{T, S, R} end is_separable(::Type{<:Tilt{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Tilt{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Tilt{T}}) where T = is_proximable(T) is_convex(::Type{<:Tilt{T}}) where T = is_convex(T) -is_singleton(::Type{<:Tilt{T}}) where T = is_singleton(T) +is_singleton_indicator(::Type{<:Tilt{T}}) where T = is_singleton_indicator(T) is_smooth(::Type{<:Tilt{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Tilt{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Tilt{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Tilt{T}}) where T = is_strongly_convex(T) diff --git a/src/calculus/translate.jl b/src/calculus/translate.jl index 902afca..e00cc71 100644 --- a/src/calculus/translate.jl +++ b/src/calculus/translate.jl @@ -14,13 +14,14 @@ struct Translate{T, V} end is_separable(::Type{<:Translate{T}}) where T = is_separable(T) -is_prox_accurate(::Type{<:Translate{T}}) where T = is_prox_accurate(T) +is_proximable(::Type{<:Translate{T}}) where T = is_proximable(T) is_convex(::Type{<:Translate{T}}) where T = is_convex(T) -is_set(::Type{<:Translate{T}}) where T = is_set(T) -is_singleton(::Type{<:Translate{T}}) where T = is_singleton(T) -is_cone(::Type{<:Translate{T}}) where T = is_cone(T) -is_affine(::Type{<:Translate{T}}) where T = is_affine(T) +is_set_indicator(::Type{<:Translate{T}}) where T = is_set_indicator(T) +is_singleton_indicator(::Type{<:Translate{T}}) where T = is_singleton_indicator(T) +is_cone_indicator(::Type{<:Translate{T}}) where T = is_cone_indicator(T) +is_affine_indicator(::Type{<:Translate{T}}) where T = is_affine_indicator(T) is_smooth(::Type{<:Translate{T}}) where T = is_smooth(T) +is_locally_smooth(::Type{<:Translate{T}}) where T = is_locally_smooth(T) is_generalized_quadratic(::Type{<:Translate{T}}) where T = is_generalized_quadratic(T) is_strongly_convex(::Type{<:Translate{T}}) where T = is_strongly_convex(T) diff --git a/src/functions/elasticNet.jl b/src/functions/elasticNet.jl index 744eb58..d92eb39 100644 --- a/src/functions/elasticNet.jl +++ b/src/functions/elasticNet.jl @@ -24,7 +24,7 @@ struct ElasticNet{R, S} end is_separable(f::Type{<:ElasticNet}) = true -is_prox_accurate(f::Type{<:ElasticNet}) = true +is_proximable(f::Type{<:ElasticNet}) = true is_convex(f::Type{<:ElasticNet}) = true ElasticNet(mu::R=1, lambda::S=1) where {R, S} = ElasticNet{R, S}(mu, lambda) diff --git a/src/functions/indAffine.jl b/src/functions/indAffine.jl index 9dda046..fe2c1ac 100644 --- a/src/functions/indAffine.jl +++ b/src/functions/indAffine.jl @@ -10,7 +10,7 @@ export IndAffine abstract type IndAffine end -is_affine(f::Type{<:IndAffine}) = true +is_affine_indicator(f::Type{<:IndAffine}) = true is_generalized_quadratic(f::Type{<:IndAffine}) = true fun_name(f::IndAffine) = "Indicator of an affine subspace" diff --git a/src/functions/indAffineIterative.jl b/src/functions/indAffineIterative.jl index 6dd8d7e..d3b1a7c 100644 --- a/src/functions/indAffineIterative.jl +++ b/src/functions/indAffineIterative.jl @@ -15,7 +15,7 @@ struct IndAffineIterative{M, V} <: IndAffine end end -is_prox_accurate(f::Type{<:IndAffineIterative}) = false +is_proximable(f::Type{<:IndAffineIterative}) = false IndAffineIterative(A::M, b::V) where {M, V} = IndAffineIterative{M, V}(A, b) diff --git a/src/functions/indBallL0.jl b/src/functions/indBallL0.jl index b2377bc..1e28c26 100644 --- a/src/functions/indBallL0.jl +++ b/src/functions/indBallL0.jl @@ -22,7 +22,7 @@ struct IndBallL0{I} end end -is_set(f::Type{<:IndBallL0}) = true +is_set_indicator(f::Type{<:IndBallL0}) = true IndBallL0(r::I) where {I} = IndBallL0{I}(r) diff --git a/src/functions/indBallL1.jl b/src/functions/indBallL1.jl index b646dac..b37a060 100644 --- a/src/functions/indBallL1.jl +++ b/src/functions/indBallL1.jl @@ -23,8 +23,8 @@ struct IndBallL1{R} end is_convex(f::Type{<:IndBallL1}) = true -is_set(f::Type{<:IndBallL1}) = true -is_prox_accurate(f::Type{<:IndBallL1}) = false +is_set_indicator(f::Type{<:IndBallL1}) = true +is_proximable(f::Type{<:IndBallL1}) = false IndBallL1(r::R=1.0) where R = IndBallL1{R}(r) diff --git a/src/functions/indBallL2.jl b/src/functions/indBallL2.jl index 03dd0ab..0e3871e 100644 --- a/src/functions/indBallL2.jl +++ b/src/functions/indBallL2.jl @@ -23,7 +23,7 @@ struct IndBallL2{R} end is_convex(f::Type{<:IndBallL2}) = true -is_set(f::Type{<:IndBallL2}) = true +is_set_indicator(f::Type{<:IndBallL2}) = true IndBallL2(r::R=1) where R = IndBallL2{R}(r) diff --git a/src/functions/indBallRank.jl b/src/functions/indBallRank.jl index deeaca9..f799dd7 100644 --- a/src/functions/indBallRank.jl +++ b/src/functions/indBallRank.jl @@ -25,8 +25,8 @@ struct IndBallRank{I} end end -is_set(f::Type{<:IndBallRank}) = true -is_prox_accurate(f::Type{<:IndBallRank}) = false +is_set_indicator(f::Type{<:IndBallRank}) = true +is_proximable(f::Type{<:IndBallRank}) = false IndBallRank(r::I=1) where I = IndBallRank{I}(r) diff --git a/src/functions/indBinary.jl b/src/functions/indBinary.jl index 282d763..403d9b7 100644 --- a/src/functions/indBinary.jl +++ b/src/functions/indBinary.jl @@ -16,7 +16,7 @@ struct IndBinary{T, S} high::S end -is_set(f::Type{<:IndBinary}) = true +is_set_indicator(f::Type{<:IndBinary}) = true IndBinary() = IndBinary(0, 1) diff --git a/src/functions/indBox.jl b/src/functions/indBox.jl index 685fead..e757fef 100644 --- a/src/functions/indBox.jl +++ b/src/functions/indBox.jl @@ -28,7 +28,7 @@ end is_separable(f::Type{<:IndBox}) = true is_convex(f::Type{<:IndBox}) = true -is_set(f::Type{<:IndBox}) = true +is_set_indicator(f::Type{<:IndBox}) = true compatible_bounds(::Real, ::Real) = true compatible_bounds(::Real, ::AbstractArray) = true diff --git a/src/functions/indExp.jl b/src/functions/indExp.jl index 8dffcbf..09bb45e 100644 --- a/src/functions/indExp.jl +++ b/src/functions/indExp.jl @@ -14,7 +14,7 @@ C = \\mathrm{cl} \\{ (r,s,t) : s > 0, s⋅e^{r/s} \\leq t \\} \\subset \\mathbb{ struct IndExpPrimal end is_convex(f::Type{<:IndExpPrimal}) = true -is_cone(f::Type{<:IndExpPrimal}) = true +is_cone_indicator(f::Type{<:IndExpPrimal}) = true """ IndExpDual() diff --git a/src/functions/indFree.jl b/src/functions/indFree.jl index 1afb0ec..b5ef5f6 100644 --- a/src/functions/indFree.jl +++ b/src/functions/indFree.jl @@ -12,8 +12,8 @@ struct IndFree end is_separable(f::Type{<:IndFree}) = true is_convex(f::Type{<:IndFree}) = true -is_affine(f::Type{<:IndFree}) = true -is_cone(f::Type{<:IndFree}) = true +is_affine_indicator(f::Type{<:IndFree}) = true +is_cone_indicator(f::Type{<:IndFree}) = true is_smooth(f::Type{<:IndFree}) = true is_generalized_quadratic(f::Type{<:IndFree}) = true diff --git a/src/functions/indGraph.jl b/src/functions/indGraph.jl index c86c99e..aafdafe 100644 --- a/src/functions/indGraph.jl +++ b/src/functions/indGraph.jl @@ -29,8 +29,8 @@ function IndGraph(A::AbstractMatrix) end is_convex(f::Type{<:IndGraph}) = true -is_set(f::Type{<:IndGraph}) = true -is_cone(f::Type{<:IndGraph}) = true +is_set_indicator(f::Type{<:IndGraph}) = true +is_cone_indicator(f::Type{<:IndGraph}) = true IndGraph(a::AbstractVector) = IndGraph(a') diff --git a/src/functions/indHalfspace.jl b/src/functions/indHalfspace.jl index 8c3237b..d44a45d 100644 --- a/src/functions/indHalfspace.jl +++ b/src/functions/indHalfspace.jl @@ -26,7 +26,7 @@ end IndHalfspace(a::T, b::R) where {R, T} = IndHalfspace{R, T}(a, b) is_convex(f::Type{<:IndHalfspace}) = true -is_set(f::Type{<:IndHalfspace}) = true +is_set_indicator(f::Type{<:IndHalfspace}) = true function (f::IndHalfspace)(x) R = real(eltype(x)) diff --git a/src/functions/indHyperslab.jl b/src/functions/indHyperslab.jl index e32274e..9602324 100644 --- a/src/functions/indHyperslab.jl +++ b/src/functions/indHyperslab.jl @@ -27,7 +27,7 @@ end IndHyperslab(low::R, a::T, upp::R) where {R, T} = IndHyperslab{R, T}(low, a, upp) is_convex(f::Type{<:IndHyperslab}) = true -is_set(f::Type{<:IndHyperslab}) = true +is_set_indicator(f::Type{<:IndHyperslab}) = true function (f::IndHyperslab)(x) R = real(eltype(x)) diff --git a/src/functions/indNonnegative.jl b/src/functions/indNonnegative.jl index ba5b3bf..fc08a6c 100644 --- a/src/functions/indNonnegative.jl +++ b/src/functions/indNonnegative.jl @@ -14,7 +14,7 @@ struct IndNonnegative end is_separable(f::Type{<:IndNonnegative}) = true is_convex(f::Type{<:IndNonnegative}) = true -is_cone(f::Type{<:IndNonnegative}) = true +is_cone_indicator(f::Type{<:IndNonnegative}) = true function (::IndNonnegative)(x) R = eltype(x) diff --git a/src/functions/indNonpositive.jl b/src/functions/indNonpositive.jl index ba481e9..7dac78f 100644 --- a/src/functions/indNonpositive.jl +++ b/src/functions/indNonpositive.jl @@ -14,7 +14,7 @@ struct IndNonpositive end is_separable(f::Type{<:IndNonpositive}) = true is_convex(f::Type{<:IndNonpositive}) = true -is_cone(f::Type{<:IndNonpositive}) = true +is_cone_indicator(f::Type{<:IndNonpositive}) = true function (::IndNonpositive)(x) R = eltype(x) diff --git a/src/functions/indPSD.jl b/src/functions/indPSD.jl index 36e72a6..9ec5794 100644 --- a/src/functions/indPSD.jl +++ b/src/functions/indPSD.jl @@ -46,7 +46,7 @@ function (::IndPSD)(X::Union{Symmetric, Hermitian}) end is_convex(f::Type{<:IndPSD}) = true -is_cone(f::Type{<:IndPSD}) = true +is_cone_indicator(f::Type{<:IndPSD}) = true function prox!(Y::Union{Symmetric, Hermitian}, ::IndPSD, X::Union{Symmetric, Hermitian}, gamma) R = real(eltype(X)) diff --git a/src/functions/indPoint.jl b/src/functions/indPoint.jl index afea354..fe7f40f 100644 --- a/src/functions/indPoint.jl +++ b/src/functions/indPoint.jl @@ -20,8 +20,8 @@ end is_separable(f::Type{<:IndPoint}) = true is_convex(f::Type{<:IndPoint}) = true -is_singleton(f::Type{<:IndPoint}) = true -is_affine(f::Type{<:IndPoint}) = true +is_singleton_indicator(f::Type{<:IndPoint}) = true +is_affine_indicator(f::Type{<:IndPoint}) = true IndPoint(p::T=0) where T = IndPoint{T}(p) diff --git a/src/functions/indPolyhedral.jl b/src/functions/indPolyhedral.jl index 77c894a..8405c2c 100644 --- a/src/functions/indPolyhedral.jl +++ b/src/functions/indPolyhedral.jl @@ -3,7 +3,7 @@ export IndPolyhedral abstract type IndPolyhedral end is_convex(::Type{<:IndPolyhedral}) = true -is_set(::Type{<:IndPolyhedral}) = true +is_set_indicator(::Type{<:IndPolyhedral}) = true """ IndPolyhedral([l,] A, [u, xmin, xmax]) diff --git a/src/functions/indPolyhedralOSQP.jl b/src/functions/indPolyhedralOSQP.jl index 3d2b490..6fb67d7 100644 --- a/src/functions/indPolyhedralOSQP.jl +++ b/src/functions/indPolyhedralOSQP.jl @@ -24,7 +24,7 @@ end # properties -is_prox_accurate(::Type{<:IndPolyhedralOSQP}) = false +is_proximable(::Type{<:IndPolyhedralOSQP}) = false # constructors diff --git a/src/functions/indSOC.jl b/src/functions/indSOC.jl index 2f3e6da..55547f0 100644 --- a/src/functions/indSOC.jl +++ b/src/functions/indSOC.jl @@ -22,7 +22,7 @@ function (::IndSOC)(x) end is_convex(f::Type{<:IndSOC}) = true -is_cone(f::Type{<:IndSOC}) = true +is_cone_indicator(f::Type{<:IndSOC}) = true function prox!(y, ::IndSOC, x, gamma) T = eltype(x) @@ -84,7 +84,7 @@ function (::IndRotatedSOC)(x) end is_convex(f::IndRotatedSOC) = true -is_set(f::IndRotatedSOC) = true +is_set_indicator(f::IndRotatedSOC) = true function prox!(y, ::IndRotatedSOC, x, gamma) T = eltype(x) diff --git a/src/functions/indSimplex.jl b/src/functions/indSimplex.jl index 451423d..65f7bce 100644 --- a/src/functions/indSimplex.jl +++ b/src/functions/indSimplex.jl @@ -24,7 +24,7 @@ struct IndSimplex{R} end is_convex(f::Type{<:IndSimplex}) = true -is_set(f::Type{<:IndSimplex}) = true +is_set_indicator(f::Type{<:IndSimplex}) = true IndSimplex(a::R=1) where R = IndSimplex{R}(a) diff --git a/src/functions/indSphereL2.jl b/src/functions/indSphereL2.jl index ce56871..8a31f87 100644 --- a/src/functions/indSphereL2.jl +++ b/src/functions/indSphereL2.jl @@ -22,7 +22,7 @@ struct IndSphereL2{R} end end -is_set(f::Type{<:IndSphereL2}) = true +is_set_indicator(f::Type{<:IndSphereL2}) = true IndSphereL2(r::R=1) where R = IndSphereL2{R}(r) diff --git a/src/functions/indStiefel.jl b/src/functions/indStiefel.jl index 433cc3a..3da5725 100644 --- a/src/functions/indStiefel.jl +++ b/src/functions/indStiefel.jl @@ -14,7 +14,7 @@ are inferred from the matrix provided as input. """ struct IndStiefel end -is_set(f::Type{<:IndStiefel}) = true +is_set_indicator(f::Type{<:IndStiefel}) = true function (::IndStiefel)(X) R = real(eltype(X)) diff --git a/src/functions/indZero.jl b/src/functions/indZero.jl index 8876bb5..c95efdb 100644 --- a/src/functions/indZero.jl +++ b/src/functions/indZero.jl @@ -11,9 +11,9 @@ struct IndZero end is_separable(f::Type{<:IndZero}) = true is_convex(f::Type{<:IndZero}) = true -is_singleton(f::Type{<:IndZero}) = true -is_cone(f::Type{<:IndZero}) = true -is_affine(f::Type{<:IndZero}) = true +is_singleton_indicator(f::Type{<:IndZero}) = true +is_cone_indicator(f::Type{<:IndZero}) = true +is_affine_indicator(f::Type{<:IndZero}) = true function (::IndZero)(x) C = eltype(x) diff --git a/src/functions/leastSquaresIterative.jl b/src/functions/leastSquaresIterative.jl index 29739ea..a9a848a 100644 --- a/src/functions/leastSquaresIterative.jl +++ b/src/functions/leastSquaresIterative.jl @@ -16,7 +16,7 @@ struct LeastSquaresIterative{N, R, RC, M, V, O, IsConvex} <: LeastSquares q::Array{RC, N} # n (by-p) end -is_prox_accurate(f::Type{<:LeastSquaresIterative}) = false +is_proximable(f::Type{<:LeastSquaresIterative}) = false is_convex(::Type{LeastSquaresIterative{N, R, RC, M, V, O, IsConvex}}) where {N, R, RC, M, V, O, IsConvex} = IsConvex function LeastSquaresIterative(A::M, b, lambda) where M diff --git a/src/functions/logBarrier.jl b/src/functions/logBarrier.jl index d082664..66001e6 100644 --- a/src/functions/logBarrier.jl +++ b/src/functions/logBarrier.jl @@ -26,6 +26,7 @@ end is_separable(f::Type{<:LogBarrier}) = true is_convex(f::Type{<:LogBarrier}) = true +is_locally_smooth(f::Type{<:LogBarrier}) = true LogBarrier(a::R=1, b::S=0, mu::T=1) where {R, S, T} = LogBarrier{R, S, T}(a, b, mu) diff --git a/src/functions/logisticLoss.jl b/src/functions/logisticLoss.jl index 0681c42..0f60f91 100644 --- a/src/functions/logisticLoss.jl +++ b/src/functions/logisticLoss.jl @@ -27,7 +27,7 @@ LogisticLoss(y::T, mu::R=1) where {R, T} = LogisticLoss{T, R}(y, mu) is_separable(f::Type{<:LogisticLoss}) = true is_convex(f::Type{<:LogisticLoss}) = true is_smooth(f::Type{<:LogisticLoss}) = true -is_prox_accurate(f::Type{<:LogisticLoss}) = false +is_proximable(f::Type{<:LogisticLoss}) = false # f(x) = mu log(1 + exp(-y x)) diff --git a/src/functions/quadraticIterative.jl b/src/functions/quadraticIterative.jl index 6753d64..15c5c7b 100644 --- a/src/functions/quadraticIterative.jl +++ b/src/functions/quadraticIterative.jl @@ -9,7 +9,7 @@ struct QuadraticIterative{M, V} <: Quadratic temp::V end -is_prox_accurate(f::Type{<:QuadraticIterative}) = false +is_proximable(f::Type{<:QuadraticIterative}) = false function QuadraticIterative(Q::M, q::V) where {M, V} if size(Q, 1) != size(Q, 2) || length(q) != size(Q, 2) diff --git a/src/functions/sqrNormL2.jl b/src/functions/sqrNormL2.jl index 0069c59..bec760f 100644 --- a/src/functions/sqrNormL2.jl +++ b/src/functions/sqrNormL2.jl @@ -25,11 +25,12 @@ struct SqrNormL2{T,SC} end end -is_convex(f::Type{<:SqrNormL2}) = true -is_smooth(f::Type{<:SqrNormL2}) = true -is_separable(f::Type{<:SqrNormL2}) = true -is_generalized_quadratic(f::Type{<:SqrNormL2}) = true -is_strongly_convex(f::Type{SqrNormL2{T,SC}}) where {T,SC} = SC +is_proximable(::Type{<:SqrNormL2}) = true +is_convex(::Type{<:SqrNormL2}) = true +is_smooth(::Type{<:SqrNormL2}) = true +is_separable(::Type{<:SqrNormL2}) = true +is_generalized_quadratic(::Type{<:SqrNormL2}) = true +is_strongly_convex(::Type{SqrNormL2{T,SC}}) where {T,SC} = SC SqrNormL2(lambda::T=1) where T = SqrNormL2{T,all(lambda .> 0)}(lambda) diff --git a/src/functions/sumPositive.jl b/src/functions/sumPositive.jl index ac0f1d6..8089626 100644 --- a/src/functions/sumPositive.jl +++ b/src/functions/sumPositive.jl @@ -14,6 +14,7 @@ struct SumPositive end is_separable(f::Type{<:SumPositive}) = true is_convex(f::Type{<:SumPositive}) = true +is_positively_homogeneous(f::Type{<:SumPositive}) = true function (::SumPositive)(x) return sum(xi -> max(xi, eltype(x)(0)), x) diff --git a/src/utilities/traits.jl b/src/utilities/traits.jl deleted file mode 100644 index b5d46f5..0000000 --- a/src/utilities/traits.jl +++ /dev/null @@ -1,32 +0,0 @@ -is_prox_accurate(::Type) = true -is_prox_accurate(::T) where T = is_prox_accurate(T) - -is_separable(::Type) = false -is_separable(::T) where T = is_separable(T) - -is_singleton(::Type) = false -is_singleton(::T) where T = is_singleton(T) - -is_cone(::Type) = false -is_cone(::T) where T = is_cone(T) - -is_affine(T::Type) = is_singleton(T) -is_affine(::T) where T = is_affine(T) - -is_set(T::Type) = is_cone(T) || is_affine(T) -is_set(::T) where T = is_set(T) - -is_positively_homogeneous(T::Type) = is_cone(T) -is_positively_homogeneous(::T) where T = is_positively_homogeneous(T) - -is_support(T::Type) = is_convex(T) && is_positively_homogeneous(T) -is_support(::T) where T = is_support(T) - -is_smooth(::Type) = false -is_smooth(::T) where T = is_smooth(T) - -is_quadratic(T::Type) = is_generalized_quadratic(T) && is_smooth(T) -is_quadratic(::T) where T = is_quadratic(T) - -is_strongly_convex(::Type) = false -is_strongly_convex(::T) where T = is_strongly_convex(T) diff --git a/test/Project.toml b/test/Project.toml index 1b3a6c1..96a3119 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index f91f581..761c5c8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,16 +1,15 @@ using Test using ProximalOperators - -using ProximalOperators: - ArrayOrTuple, - is_prox_accurate, +using ProximalOperators: ArrayOrTuple +using ProximalCore: + is_proximable, is_separable, is_convex, - is_singleton, - is_cone, - is_affine, - is_set, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, is_smooth, is_quadratic, is_generalized_quadratic, @@ -49,19 +48,19 @@ function prox_test(f, x::ArrayOrTuple{R}, gamma=1) where R <: Real @test typeof(fy_naive) == R - rtol = if ProximalOperators.is_prox_accurate(f) sqrt(eps(R)) else 1e-4 end + rtol = if is_proximable(f) sqrt(eps(R)) else 1e-4 end - if ProximalOperators.is_convex(f) + if is_convex(f) @test all(isapprox.(y_prealloc, y, rtol=rtol, atol=100*eps(R))) @test all(isapprox.(y_naive, y, rtol=rtol, atol=100*eps(R))) - if ProximalOperators.is_set(f) + if is_set_indicator(f) @test fy_prealloc == 0 end @test isapprox(fy_prealloc, fy, rtol=rtol, atol=100*eps(R)) @test isapprox(fy_naive, fy, rtol=rtol, atol=100*eps(R)) end - if !ProximalOperators.is_set(f) || ProximalOperators.is_prox_accurate(f) + if !is_set_indicator(f) || is_proximable(f) f_at_y = call_test(f, y) if f_at_y !== nothing @test isapprox(f_at_y, fy, rtol=rtol, atol=100*eps(R)) @@ -88,10 +87,10 @@ function predicates_test(f) is_generalized_quadratic, is_quadratic, is_smooth, - is_singleton, - is_cone, - is_affine, - is_set, + is_singleton_indicator, + is_cone_indicator, + is_affine_indicator, + is_set_indicator, is_positively_homogeneous, is_support, ] @@ -104,9 +103,9 @@ function predicates_test(f) # quadratic => generalized_quadratic && smooth @test !is_quadratic(f) || (is_generalized_quadratic(f) && is_smooth(f)) # (singleton || cone || affine) => set - @test !(is_singleton(f) || is_cone(f) || is_affine(f)) || is_set(f) + @test !(is_singleton_indicator(f) || is_cone_indicator(f) || is_affine_indicator(f)) || is_set_indicator(f) # cone => positively homogeneous - @test !is_cone(f) || is_positively_homogeneous(f) + @test !is_cone_indicator(f) || is_positively_homogeneous(f) # (convex && positively homogeneous) <=> (convex && support) @test (is_convex(f) && is_positively_homogeneous(f)) == (is_convex(f) && is_support(f)) # strongly_convex => convex diff --git a/test/test_calls.jl b/test/test_calls.jl index ee24121..516d200 100644 --- a/test/test_calls.jl +++ b/test/test_calls.jl @@ -598,7 +598,7 @@ test_cases_spec = [ y, fy = prox_test(f, x, gam) ##### compute prox with multiple random gammas - if ProximalOperators.is_separable(f) + if is_separable(f) gam = real(T)(0.5) .+ 2 .* rand(real(T), size(x)) y, fy = prox_test(f, x, gam) end diff --git a/test/test_gradients.jl b/test/test_gradients.jl index aa70041..350b972 100644 --- a/test/test_gradients.jl +++ b/test/test_gradients.jl @@ -158,7 +158,7 @@ for i in eachindex(stuff) ∇f, fx = gradient_test(f, x) for k = 1:10 # Test conditions in different directions - if ProximalOperators.is_convex(f) + if is_convex(f) # Test ∇f is subgradient if typeof(f) <: CrossEntropy d = x.*(rand(Float64, size(x)).-1)./2 # assures 0 <= x+d <= 1 diff --git a/test/test_huberLoss.jl b/test/test_huberLoss.jl index b97c57f..12b67f7 100644 --- a/test/test_huberLoss.jl +++ b/test/test_huberLoss.jl @@ -8,9 +8,9 @@ f = HuberLoss(1.5, 0.7) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test ProximalOperators.is_quadratic(f) == false -@test ProximalOperators.is_set(f) == false +@test is_smooth(f) == true +@test is_quadratic(f) == false +@test is_set_indicator(f) == false x = randn(10) x = 1.6*x/norm(x) diff --git a/test/test_indAffine.jl b/test/test_indAffine.jl index 5111802..b8a2121 100644 --- a/test/test_indAffine.jl +++ b/test/test_indAffine.jl @@ -16,10 +16,10 @@ x = randn(n) predicates_test(f) -@test ProximalOperators.is_smooth(f) == false -@test ProximalOperators.is_quadratic(f) == false -@test ProximalOperators.is_generalized_quadratic(f) == true -@test ProximalOperators.is_set(f) == true +@test is_smooth(f) == false +@test is_quadratic(f) == false +@test is_generalized_quadratic(f) == true +@test is_set_indicator(f) == true call_test(f, x) y, fy = prox_test(f, x) diff --git a/test/test_indPolyhedral.jl b/test/test_indPolyhedral.jl index 97b4efb..0d0a0bb 100644 --- a/test/test_indPolyhedral.jl +++ b/test/test_indPolyhedral.jl @@ -30,8 +30,8 @@ p = similar(x) () -> IndPolyhedral(l, A, u, xmin, xmax), ] f = constr() - @test ProximalOperators.is_convex(f) == true - @test ProximalOperators.is_set(f) == true + @test is_convex(f) == true + @test is_set_indicator(f) == true fx = call_test(f, x) p, fp = prox_test(f, x) end diff --git a/test/test_leastSquares.jl b/test/test_leastSquares.jl index d3ced45..5c50c49 100644 --- a/test/test_leastSquares.jl +++ b/test/test_leastSquares.jl @@ -33,10 +33,10 @@ x = randn(T, shape_x...) f = LeastSquares(A, b, iterative=(mode == :iterative)) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test ProximalOperators.is_quadratic(f) == true -@test ProximalOperators.is_generalized_quadratic(f) == true -@test ProximalOperators.is_set(f) == false +@test is_smooth(f) == true +@test is_quadratic(f) == true +@test is_generalized_quadratic(f) == true +@test is_set_indicator(f) == false grad_fx, fx = gradient_test(f, x) lsres = A*x - b diff --git a/test/test_moreauEnvelope.jl b/test/test_moreauEnvelope.jl index 882c0f1..a200665 100644 --- a/test/test_moreauEnvelope.jl +++ b/test/test_moreauEnvelope.jl @@ -14,9 +14,9 @@ using LinearAlgebra predicates_test(g) - @test ProximalOperators.is_smooth(g) == true - @test ProximalOperators.is_quadratic(g) == false - @test ProximalOperators.is_set(g) == false + @test is_smooth(g) == true + @test is_quadratic(g) == false + @test is_set_indicator(g) == false x = R[1.0, 2.0, 3.0, 4.0, 5.0] @@ -40,9 +40,9 @@ end predicates_test(g) - @test ProximalOperators.is_smooth(g) == true - @test ProximalOperators.is_quadratic(g) == false - @test ProximalOperators.is_set(g) == false + @test is_smooth(g) == true + @test is_quadratic(g) == false + @test is_set_indicator(g) == false x = R[1.0, 2.0, 3.0, 4.0, 5.0] diff --git a/test/test_pointwiseMinimum.jl b/test/test_pointwiseMinimum.jl index 60c964f..eff5a1d 100644 --- a/test/test_pointwiseMinimum.jl +++ b/test/test_pointwiseMinimum.jl @@ -9,8 +9,8 @@ f = PointwiseMinimum(IndPoint(T[-1.0]), IndPoint(T[1.0])) x = T[0.1] predicates_test(f) -@test ProximalOperators.is_set(f) == true -@test ProximalOperators.is_cone(f) == false +@test is_set_indicator(f) == true +@test is_cone_indicator(f) == false y, fy = prox_test(f, x) @test all(y .== T[1.0]) diff --git a/test/test_precompose.jl b/test/test_precompose.jl index d57bdfb..1dc5b01 100644 --- a/test/test_precompose.jl +++ b/test/test_precompose.jl @@ -17,9 +17,9 @@ g = Precompose(f, Q, 1.0) predicates_test(g) -@test ProximalOperators.is_smooth(g) == false -@test ProximalOperators.is_quadratic(g) == false -@test ProximalOperators.is_set(g) == true +@test is_smooth(g) == false +@test is_quadratic(g) == false +@test is_set_indicator(g) == true x = randn(10) diff --git a/test/test_quadratic.jl b/test/test_quadratic.jl index 6112e64..400cf4a 100644 --- a/test/test_quadratic.jl +++ b/test/test_quadratic.jl @@ -17,9 +17,9 @@ f = Quadratic(Q, q) predicates_test(f) -@test ProximalOperators.is_smooth(f) == true -@test ProximalOperators.is_quadratic(f) == true -@test ProximalOperators.is_set(f) == false +@test is_smooth(f) == true +@test is_quadratic(f) == true +@test is_set_indicator(f) == false x = randn(n) diff --git a/test/test_results.jl b/test/test_results.jl index d10125a..f80f775 100644 --- a/test/test_results.jl +++ b/test/test_results.jl @@ -340,7 +340,7 @@ stuff = [ y, fy = prox_test(f, x, gamma) @test y ≈ ref_y - if ProximalOperators.is_prox_accurate(f) + if is_proximable(f) @test fy ≈ ref_fy end diff --git a/test/test_sum.jl b/test/test_sum.jl index 24364c2..a1cdb07 100644 --- a/test/test_sum.jl +++ b/test/test_sum.jl @@ -10,9 +10,9 @@ f = Sum(f1, f2) predicates_test(f) -@test ProximalOperators.is_quadratic(f) == true -@test ProximalOperators.is_strongly_convex(f) == true -@test ProximalOperators.is_set(f) == false +@test is_quadratic(f) == true +@test is_strongly_convex(f) == true +@test is_set_indicator(f) == false xtest = randn(10) @@ -33,9 +33,9 @@ g = Sum(g1, g2) predicates_test(g) -@test ProximalOperators.is_smooth(g) == false -@test ProximalOperators.is_strongly_convex(g) == true -@test ProximalOperators.is_set(g) == false +@test is_smooth(g) == false +@test is_strongly_convex(g) == true +@test is_set_indicator(g) == false xtest = randn(10)