Skip to content

Commit

Permalink
Turn deps into weakdeps (#218)
Browse files Browse the repository at this point in the history
* add 1.6 and 1.9 to CI

* weakdep CovarianceEstimation

* weakdep Tracker

* weakdep ChainRulesCore

* weakdep AbstractFFTs

* bump
  • Loading branch information
aplavin authored Mar 2, 2024
1 parent cd36392 commit e55c11c
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 125 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ jobs:
fail-fast: false
matrix:
version:
- "1.6"
- "1.9"
- "1" # Latest Release
os:
- ubuntu-latest
Expand Down
19 changes: 17 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDims"
uuid = "356022a1-0364-5f58-8944-0da4b18d706f"
authors = ["Invenia Technical Computing Corporation"]
version = "1.2.1"
version = "1.2.2"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -12,6 +12,19 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
AbstractFFTsExt = "AbstractFFTs"
ChainRulesCoreExt = "ChainRulesCore"
CovarianceEstimationExt = "CovarianceEstimation"
TrackerExt = "Tracker"

[compat]
AbstractFFTs = "0.4, 0.5, 1"
BenchmarkTools = "0.5"
Expand All @@ -24,12 +37,14 @@ julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["BenchmarkTools", "ChainRulesTestUtils", "FFTW", "OffsetArrays", "SparseArrays", "Test", "Tracker"]
test = ["BenchmarkTools", "ChainRulesCore", "ChainRulesTestUtils", "CovarianceEstimation", "FFTW", "OffsetArrays", "SparseArrays", "Test", "Tracker"]
105 changes: 105 additions & 0 deletions ext/AbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
module AbstractFFTsExt

using AbstractFFTs
using NamedDims
using NamedDims: wave_name, _rename


################################################
# FFT

for fun in (:fft, :ifft, :bfft, :rfft, :irfft, :brfft)
plan_fun = Symbol(:plan_, fun)

if fun in (:irfft, :brfft) # These take one more argument, a size
arg, str = (:(d::Integer),), "d, "
else
arg, str = (), ""
end

@eval begin

"""
$($fun)(A, $($str):time => :freq, :x => :kx)
Acting on a `NamedDimsArray`, this specifies to take the transform along the dimensions
named `:time, :x`, and return an array with names `:freq` and `:kx` in their places.
$($fun)(A, $($str):x) # => :x∿
If new names are not given, then the default is `:x => :x∿` and `:x∿ => :x`,
applied to all dimensions, or to those specified as usual, e.g. `$($fun)(A, $($str)(1,2))`
or `$($fun)(A, $($str):time)`. The symbol "∿" can be typed by `\\sinewave<tab>`.
"""
function AbstractFFTs.$fun(A::NamedDimsArray{L}, $(arg...)) where {L}
data = AbstractFFTs.$fun(parent(A), $(arg...))
return NamedDimsArray(data, wave_name(L))
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), dims) where {L,T,N}
numerical_dims = dim(A, dims)
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = wave_name(L, numerical_dims)
return NamedDimsArray(data, newL)
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), p::Pair{Symbol,Symbol}, ps::Pair{Symbol,Symbol}...) where {L,T,N}
numerical_dims = dim(A, (first(p), first.(ps)...))
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = _rename(L, p, ps...)
return NamedDimsArray(data, newL)
end

"""
F = $($plan_fun)(A, $($str):time)
A∿ = F * A
A ≈ F \\ A∿ ≈ inv(F) * A∿
FFT plans for `NamedDimsArray`s, identical to `A∿ = $($fun)(A, $($str):time)`.
Note you cannot specify the final name, it always transforms `:time => :time∿`.
And that the plan `F` stores which dimension number to act on, not which name.
"""
function AbstractFFTs.$plan_fun(A::NamedDimsArray, $(arg...), dims = ntuple(identity, ndims(A)); kw...)
dims isa Pair && throw(ArgumentError("$($plan_fun) does not store final names, got Pair $dims"))
numerical_dims = Tuple(dim(A, dims))
AbstractFFTs.$plan_fun(parent(A), $(arg...), numerical_dims; kw...)
end
end

end

for shift in (:fftshift, :ifftshift)
@eval begin

function AbstractFFTs.$shift(A::NamedDimsArray)
data = AbstractFFTs.$shift(parent(A))
NamedDimsArray(data, dimnames(A))
end

function AbstractFFTs.$shift(A::NamedDimsArray, dims)
numerical_dims = dim(A, dims)
data = AbstractFFTs.$shift(parent(A), numerical_dims)
NamedDimsArray(data, dimnames(A))
end

end
end

# The dimensions on which plans act are not part of the type, unfortunately
for plan_type in (:Plan, :ScaledPlan)
@eval function Base.:*(plan::AbstractFFTs.$plan_type, A::NamedDimsArray{L,T,N}) where {L,T,N}
data = plan * parent(A)
if Base.sym_in(:region, propertynames(plan)) # true for plan_fft from FFTW
dims = plan.region # dims can be 1, (1,3) or 1:3
elseif Base.sym_in(:p, propertynames(plan))
dims = plan.p.region
else
return data
end
newL = ntuple(d -> d in dims ? wave_name(L[d]) : L[d], N)::NTuple{N,Symbol}
# newL = wave_name(L, Tuple(dims)) # this, using compile_time_return_hack, is much slower
return NamedDimsArray(data, newL)
end
end

end
8 changes: 8 additions & 0 deletions src/chainrules.jl → ext/ChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
module ChainRulesCoreExt

using ChainRulesCore
using NamedDims: NamedDimsArray, dimnames, unify_names


_NamedDimsArray_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
_NamedDimsArray_pullback(ȳ::Tangent) = (NoTangent(), ȳ.data, NoTangent())
_NamedDimsArray_pullback(ȳ::AbstractThunk) = _NamedDimsArray_pullback(unthunk(ȳ))
Expand All @@ -19,3 +25,5 @@ function (project::ProjectTo{NDA})(dx) where {NDA<:NamedDimsArray}
names = unify_names(dimnames(NDA), dimnames(dx))
return NamedDimsArray{names}(project.data(parent(dx)))
end

end
19 changes: 19 additions & 0 deletions ext/CovarianceEstimationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module CovarianceEstimationExt

using NamedDims
using CovarianceEstimation


for E in (:LinearShrinkage, :SimpleCovariance, :AnalyticalNonlinearShrinkage)
@eval function CovarianceEstimation.cov(
estimator::$E, a::NamedDimsArray{L,T,2}; dims=1, kwargs...
) where {L,T}
numerical_dims = dim(a, dims)
# cov returns a Symmetric matrix which needs to be rewrapped in a NamedDimsArray
data = cov(estimator, parent(a); dims=numerical_dims, kwargs...)
names = NamedDims.symmetric_names(L, numerical_dims)
return NamedDimsArray{names}(data)
end
end

end
7 changes: 6 additions & 1 deletion src/tracker_compat.jl → ext/TrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Tracker.jl Compat
module TrackerExt

isdefined(Base, :get_extension) ? (using Tracker) : (using ..Tracker)
using NamedDims: NamedDims, dimnames, NamedDimsStyle, NamedDimsArray, @declare_matmul

# The following blocks ever constructing TrackedArrays of NamedDimArrays.
# This is not strictly forbidden (thus is commented out) but is useful for debugging things
Expand Down Expand Up @@ -33,3 +36,5 @@ for f in (:forward, :back, :back!, :grad, :istracked, :tracker)
return Tracker.$f(parent(nda), args...)
end
end

end
24 changes: 14 additions & 10 deletions src/NamedDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,12 @@ module NamedDims
using Base: @propagate_inbounds
using Base.Broadcast:
Broadcasted, BroadcastStyle, DefaultArrayStyle, AbstractArrayStyle, Unknown
using ChainRulesCore
using CovarianceEstimation
using LinearAlgebra
using AbstractFFTs
using Pkg
using Requires
using Statistics

export NamedDimsArray, dim, rename, unname, dimnames

function __init__()
# NOTE: NamedDims is only compatible with Tracker v0.2.2; but no nice way to enforce that.
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker_compat.jl")
end

# We use CoVector to workout if we are taking the tranpose of a tranpose etc
const CoVector = Union{Adjoint{<:Any,<:AbstractVector},Transpose{<:Any,<:AbstractVector}}

Expand All @@ -26,7 +17,6 @@ include("show.jl")
include("name_operations.jl")

include("broadcasting.jl")
include("chainrules.jl")

include("functions.jl")
include("functions_dims.jl")
Expand All @@ -39,4 +29,18 @@ include("fft.jl")

include("functions_linearalgebra.jl")

@static if !isdefined(Base, :get_extension)
using Requires
end
@static if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsExt.jl")
include("../ext/ChainRulesCoreExt.jl")
include("../ext/CovarianceEstimationExt.jl")

function __init__()
# NOTE: NamedDims is only compatible with Tracker v0.2.2; but no nice way to enforce that.
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TrackerExt.jl")
end
end

end # module
99 changes: 0 additions & 99 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,101 +1,3 @@

################################################
# FFT

for fun in (:fft, :ifft, :bfft, :rfft, :irfft, :brfft)
plan_fun = Symbol(:plan_, fun)

if fun in (:irfft, :brfft) # These take one more argument, a size
arg, str = (:(d::Integer),), "d, "
else
arg, str = (), ""
end

@eval begin

"""
$($fun)(A, $($str):time => :freq, :x => :kx)
Acting on a `NamedDimsArray`, this specifies to take the transform along the dimensions
named `:time, :x`, and return an array with names `:freq` and `:kx` in their places.
$($fun)(A, $($str):x) # => :x∿
If new names are not given, then the default is `:x => :x∿` and `:x∿ => :x`,
applied to all dimensions, or to those specified as usual, e.g. `$($fun)(A, $($str)(1,2))`
or `$($fun)(A, $($str):time)`. The symbol "∿" can be typed by `\\sinewave<tab>`.
"""
function AbstractFFTs.$fun(A::NamedDimsArray{L}, $(arg...)) where {L}
data = AbstractFFTs.$fun(parent(A), $(arg...))
return NamedDimsArray(data, wave_name(L))
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), dims) where {L,T,N}
numerical_dims = dim(A, dims)
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = wave_name(L, numerical_dims)
return NamedDimsArray(data, newL)
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), p::Pair{Symbol,Symbol}, ps::Pair{Symbol,Symbol}...) where {L,T,N}
numerical_dims = dim(A, (first(p), first.(ps)...))
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = _rename(L, p, ps...)
return NamedDimsArray(data, newL)
end

"""
F = $($plan_fun)(A, $($str):time)
A∿ = F * A
A ≈ F \\ A∿ ≈ inv(F) * A∿
FFT plans for `NamedDimsArray`s, identical to `A∿ = $($fun)(A, $($str):time)`.
Note you cannot specify the final name, it always transforms `:time => :time∿`.
And that the plan `F` stores which dimension number to act on, not which name.
"""
function AbstractFFTs.$plan_fun(A::NamedDimsArray, $(arg...), dims = ntuple(identity, ndims(A)); kw...)
dims isa Pair && throw(ArgumentError("$($plan_fun) does not store final names, got Pair $dims"))
numerical_dims = Tuple(dim(A, dims))
AbstractFFTs.$plan_fun(parent(A), $(arg...), numerical_dims; kw...)
end
end

end

for shift in (:fftshift, :ifftshift)
@eval begin

function AbstractFFTs.$shift(A::NamedDimsArray)
data = AbstractFFTs.$shift(parent(A))
NamedDimsArray(data, dimnames(A))
end

function AbstractFFTs.$shift(A::NamedDimsArray, dims)
numerical_dims = dim(A, dims)
data = AbstractFFTs.$shift(parent(A), numerical_dims)
NamedDimsArray(data, dimnames(A))
end

end
end

# The dimensions on which plans act are not part of the type, unfortunately
for plan_type in (:Plan, :ScaledPlan)
@eval function Base.:*(plan::AbstractFFTs.$plan_type, A::NamedDimsArray{L,T,N}) where {L,T,N}
data = plan * parent(A)
if Base.sym_in(:region, propertynames(plan)) # true for plan_fft from FFTW
dims = plan.region # dims can be 1, (1,3) or 1:3
elseif Base.sym_in(:p, propertynames(plan))
dims = plan.p.region
else
return data
end
newL = ntuple(d -> d in dims ? wave_name(L[d]) : L[d], N)::NTuple{N,Symbol}
# newL = wave_name(L, Tuple(dims)) # this, using compile_time_return_hack, is much slower
return NamedDimsArray(data, newL)
end
end

wave_name(s::Symbol) = wave_name(Val(s))

@generated function wave_name(::Val{sym}) where {sym}
Expand Down Expand Up @@ -123,4 +25,3 @@ end
wave_name(tup::Tuple, dims::Tuple) = wave_name(wave_name(tup, first(dims)), Base.tail(dims))
wave_name(tup::Tuple, dims::Tuple{}) = tup
# @btime NamedDims.wave_name((:k1, :k2, :k3), (1,3))

13 changes: 0 additions & 13 deletions src/functions_math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,6 @@ for fun in (:cor, :cov)
end
end

# CovarianceEstimation
for E in (:LinearShrinkage, :SimpleCovariance, :AnalyticalNonlinearShrinkage)
@eval function CovarianceEstimation.cov(
estimator::$E, a::NamedDimsArray{L,T,2}; dims=1, kwargs...
) where {L,T}
numerical_dims = dim(a, dims)
# cov returns a Symmetric matrix which needs to be rewrapped in a NamedDimsArray
data = cov(estimator, parent(a); dims=numerical_dims, kwargs...)
names = symmetric_names(L, numerical_dims)
return NamedDimsArray{names}(data)
end
end

function symmetric_names(L::Tuple{Symbol,Symbol}, dims::Integer)
# 0 Allocations. See `@btime (()-> symmetric_names((:foo, :bar), 1))()`
names = if dims == 1
Expand Down

2 comments on commit e55c11c

@mcabbott
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/102098

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.2.2 -m "<description of version>" e55c11cc3db94849dbf123e28e5636e423265648
git push origin v1.2.2

Please sign in to comment.