Skip to content

Commit

Permalink
weakdep AbstractFFTs
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin committed Mar 1, 2024
1 parent b2cea06 commit 91f38e4
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 111 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
AbstractFFTsExt = "AbstractFFTs"
ChainRulesCoreExt = "ChainRulesCore"
CovarianceEstimationExt = "CovarianceEstimation"
TrackerExt = "Tracker"
Expand Down
105 changes: 105 additions & 0 deletions ext/AbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
module AbstractFFTsExt

using AbstractFFTs
using NamedDims
using NamedDims: wave_name, _rename


################################################
# FFT

for fun in (:fft, :ifft, :bfft, :rfft, :irfft, :brfft)
plan_fun = Symbol(:plan_, fun)

if fun in (:irfft, :brfft) # These take one more argument, a size
arg, str = (:(d::Integer),), "d, "
else
arg, str = (), ""
end

@eval begin

"""
$($fun)(A, $($str):time => :freq, :x => :kx)
Acting on a `NamedDimsArray`, this specifies to take the transform along the dimensions
named `:time, :x`, and return an array with names `:freq` and `:kx` in their places.
$($fun)(A, $($str):x) # => :x∿
If new names are not given, then the default is `:x => :x∿` and `:x∿ => :x`,
applied to all dimensions, or to those specified as usual, e.g. `$($fun)(A, $($str)(1,2))`
or `$($fun)(A, $($str):time)`. The symbol "∿" can be typed by `\\sinewave<tab>`.
"""
function AbstractFFTs.$fun(A::NamedDimsArray{L}, $(arg...)) where {L}
data = AbstractFFTs.$fun(parent(A), $(arg...))
return NamedDimsArray(data, wave_name(L))
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), dims) where {L,T,N}
numerical_dims = dim(A, dims)
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = wave_name(L, numerical_dims)
return NamedDimsArray(data, newL)
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), p::Pair{Symbol,Symbol}, ps::Pair{Symbol,Symbol}...) where {L,T,N}
numerical_dims = dim(A, (first(p), first.(ps)...))
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = _rename(L, p, ps...)
return NamedDimsArray(data, newL)
end

"""
F = $($plan_fun)(A, $($str):time)
A∿ = F * A
A ≈ F \\ A∿ ≈ inv(F) * A∿
FFT plans for `NamedDimsArray`s, identical to `A∿ = $($fun)(A, $($str):time)`.
Note you cannot specify the final name, it always transforms `:time => :time∿`.
And that the plan `F` stores which dimension number to act on, not which name.
"""
function AbstractFFTs.$plan_fun(A::NamedDimsArray, $(arg...), dims = ntuple(identity, ndims(A)); kw...)
dims isa Pair && throw(ArgumentError("$($plan_fun) does not store final names, got Pair $dims"))
numerical_dims = Tuple(dim(A, dims))
AbstractFFTs.$plan_fun(parent(A), $(arg...), numerical_dims; kw...)
end
end

end

for shift in (:fftshift, :ifftshift)
@eval begin

function AbstractFFTs.$shift(A::NamedDimsArray)
data = AbstractFFTs.$shift(parent(A))
NamedDimsArray(data, dimnames(A))
end

function AbstractFFTs.$shift(A::NamedDimsArray, dims)
numerical_dims = dim(A, dims)
data = AbstractFFTs.$shift(parent(A), numerical_dims)
NamedDimsArray(data, dimnames(A))
end

end
end

# The dimensions on which plans act are not part of the type, unfortunately
for plan_type in (:Plan, :ScaledPlan)
@eval function Base.:*(plan::AbstractFFTs.$plan_type, A::NamedDimsArray{L,T,N}) where {L,T,N}
data = plan * parent(A)
if Base.sym_in(:region, propertynames(plan)) # true for plan_fft from FFTW
dims = plan.region # dims can be 1, (1,3) or 1:3
elseif Base.sym_in(:p, propertynames(plan))
dims = plan.p.region
else
return data

Check warning on line 97 in ext/AbstractFFTsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsExt.jl#L97

Added line #L97 was not covered by tests
end
newL = ntuple(d -> d in dims ? wave_name(L[d]) : L[d], N)::NTuple{N,Symbol}
# newL = wave_name(L, Tuple(dims)) # this, using compile_time_return_hack, is much slower
return NamedDimsArray(data, newL)
end
end

end
2 changes: 1 addition & 1 deletion src/NamedDims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using Base: @propagate_inbounds
using Base.Broadcast:
Broadcasted, BroadcastStyle, DefaultArrayStyle, AbstractArrayStyle, Unknown
using LinearAlgebra
using AbstractFFTs
using Pkg
using Statistics

Expand Down Expand Up @@ -34,6 +33,7 @@ include("functions_linearalgebra.jl")
using Requires
end
@static if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsExt.jl")
include("../ext/ChainRulesCoreExt.jl")
include("../ext/CovarianceEstimationExt.jl")

Expand Down
99 changes: 0 additions & 99 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,101 +1,3 @@

################################################
# FFT

for fun in (:fft, :ifft, :bfft, :rfft, :irfft, :brfft)
plan_fun = Symbol(:plan_, fun)

if fun in (:irfft, :brfft) # These take one more argument, a size
arg, str = (:(d::Integer),), "d, "
else
arg, str = (), ""
end

@eval begin

"""
$($fun)(A, $($str):time => :freq, :x => :kx)
Acting on a `NamedDimsArray`, this specifies to take the transform along the dimensions
named `:time, :x`, and return an array with names `:freq` and `:kx` in their places.
$($fun)(A, $($str):x) # => :x∿
If new names are not given, then the default is `:x => :x∿` and `:x∿ => :x`,
applied to all dimensions, or to those specified as usual, e.g. `$($fun)(A, $($str)(1,2))`
or `$($fun)(A, $($str):time)`. The symbol "∿" can be typed by `\\sinewave<tab>`.
"""
function AbstractFFTs.$fun(A::NamedDimsArray{L}, $(arg...)) where {L}
data = AbstractFFTs.$fun(parent(A), $(arg...))
return NamedDimsArray(data, wave_name(L))
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), dims) where {L,T,N}
numerical_dims = dim(A, dims)
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = wave_name(L, numerical_dims)
return NamedDimsArray(data, newL)
end

function AbstractFFTs.$fun(A::NamedDimsArray{L,T,N}, $(arg...), p::Pair{Symbol,Symbol}, ps::Pair{Symbol,Symbol}...) where {L,T,N}
numerical_dims = dim(A, (first(p), first.(ps)...))
data = AbstractFFTs.$fun(parent(A), $(arg...), numerical_dims)
newL = _rename(L, p, ps...)
return NamedDimsArray(data, newL)
end

"""
F = $($plan_fun)(A, $($str):time)
A∿ = F * A
A ≈ F \\ A∿ ≈ inv(F) * A∿
FFT plans for `NamedDimsArray`s, identical to `A∿ = $($fun)(A, $($str):time)`.
Note you cannot specify the final name, it always transforms `:time => :time∿`.
And that the plan `F` stores which dimension number to act on, not which name.
"""
function AbstractFFTs.$plan_fun(A::NamedDimsArray, $(arg...), dims = ntuple(identity, ndims(A)); kw...)
dims isa Pair && throw(ArgumentError("$($plan_fun) does not store final names, got Pair $dims"))
numerical_dims = Tuple(dim(A, dims))
AbstractFFTs.$plan_fun(parent(A), $(arg...), numerical_dims; kw...)
end
end

end

for shift in (:fftshift, :ifftshift)
@eval begin

function AbstractFFTs.$shift(A::NamedDimsArray)
data = AbstractFFTs.$shift(parent(A))
NamedDimsArray(data, dimnames(A))
end

function AbstractFFTs.$shift(A::NamedDimsArray, dims)
numerical_dims = dim(A, dims)
data = AbstractFFTs.$shift(parent(A), numerical_dims)
NamedDimsArray(data, dimnames(A))
end

end
end

# The dimensions on which plans act are not part of the type, unfortunately
for plan_type in (:Plan, :ScaledPlan)
@eval function Base.:*(plan::AbstractFFTs.$plan_type, A::NamedDimsArray{L,T,N}) where {L,T,N}
data = plan * parent(A)
if Base.sym_in(:region, propertynames(plan)) # true for plan_fft from FFTW
dims = plan.region # dims can be 1, (1,3) or 1:3
elseif Base.sym_in(:p, propertynames(plan))
dims = plan.p.region
else
return data
end
newL = ntuple(d -> d in dims ? wave_name(L[d]) : L[d], N)::NTuple{N,Symbol}
# newL = wave_name(L, Tuple(dims)) # this, using compile_time_return_hack, is much slower
return NamedDimsArray(data, newL)
end
end

wave_name(s::Symbol) = wave_name(Val(s))

@generated function wave_name(::Val{sym}) where {sym}
Expand Down Expand Up @@ -123,4 +25,3 @@ end
wave_name(tup::Tuple, dims::Tuple) = wave_name(wave_name(tup, first(dims)), Base.tail(dims))
wave_name(tup::Tuple, dims::Tuple{}) = tup
# @btime NamedDims.wave_name((:k1, :k2, :k3), (1,3))

22 changes: 11 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ using Test
include("test_helpers.jl")

const testfiles = (
# "name_core.jl",
# "wrapper_array.jl",
# "name_operations.jl",
# "functions.jl",
# "functions_dims.jl",
# "functions_math.jl",
# "cat.jl",
# "functions_linearalgebra.jl",
# "broadcasting.jl",
"name_core.jl",
"wrapper_array.jl",
"name_operations.jl",
"functions.jl",
"functions_dims.jl",
"functions_math.jl",
"cat.jl",
"functions_linearalgebra.jl",
"broadcasting.jl",
"chainrules.jl",
# "fft.jl",
# "tracker_compat.jl",
"fft.jl",
"tracker_compat.jl",
)

@testset "NamedDims.jl" begin
Expand Down

0 comments on commit 91f38e4

Please sign in to comment.