Skip to content

Commit

Permalink
Add vector-vector and matrix-matrix Kronecker product (#575)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio authored Dec 20, 2024
1 parent 8094ded commit 87b95a9
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ Manifest.toml

# MacOS generated files
*.DS_Store

/.vscode
77 changes: 77 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,80 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}

Array(y)[]
end

## Kronecker product

function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, y::AbstractGPUVector{T3}) where {T1,T2,T3}
@assert length(z) == length(x) * length(y)

@kernel function kron_kernel!(z, @Const(x), @Const(y))
i, j = @index(Global, NTuple)

@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
end

backend = KernelAbstractions.get_backend(z)
kernel = kron_kernel!(backend)

kernel(z, x, y, ndrange=(length(x), length(y)))

return z
end

function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
T = promote_type(T1, T2)
z = similar(x, T, length(x) * length(y))
return LinearAlgebra.kron!(z, x, y)
end

trans_adj_wrappers = ((T -> :(AbstractGPUMatrix{$T}), T -> 'N', identity),
(T -> :(Transpose{$T, <:AbstractGPUMatrix{$T}}), T -> 'T', A -> :(parent($A))),
(T -> :(Adjoint{$T, <:AbstractGPUMatrix{$T}}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))

for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in trans_adj_wrappers
TypeA = wrapa(:(T1))
TypeB = wrapb(:(T2))
TypeC = :(AbstractGPUMatrix{T3})

@eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1,T2,T3}
@assert size(C, 1) == size(A, 1) * size(B, 1)
@assert size(C, 2) == size(A, 2) * size(B, 2)

ta = $transa(T1)
tb = $transb(T2)

@kernel function kron_kernel!(C, @Const(A), @Const(B))
ai, aj = @index(Global, NTuple) # Indices in the result matrix

# lb1, lb2 = size(B) # Dimensions of B
lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))

# Map global indices (ai, aj) to submatrices of the Kronecker product
i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
j_a = (aj - 1) ÷ lb2 + 1 # Corresponding col index in A
j_b = (aj - 1) % lb2 + 1 # Corresponding col index in B

@inbounds begin
a_ij = ta == 'N' ? A[i_a, j_a] : (ta == 'T' ? A[j_a, i_a] : conj(A[j_a, i_a]))
b_ij = tb == 'N' ? B[i_b, j_b] : (tb == 'T' ? B[j_b, i_b] : conj(B[j_b, i_b]))

C[ai, aj] = a_ij * b_ij
end
end

backend = KernelAbstractions.get_backend(C)
kernel = kron_kernel!(backend)

kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))

return C
end

@eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
T = promote_type(T1, T2)
size_C = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
C = similar(A, T, size_C...)
return kron!(C, A, B)
end
end
9 changes: 9 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@
@test iszero(A)
@test isone(A) == false
end

@testset "kron" begin
for T in eltypes
@test compare(kron, AT, rand(T, 32), rand(T, 64))
for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
@test compare(kron, AT, opa(rand(T, 32, 64)), opb(rand(T, 128, 16)))
end
end
end
end

@testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin
Expand Down

0 comments on commit 87b95a9

Please sign in to comment.