This repository has been archived by the owner on Mar 12, 2021. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 83
/
highlevel.jl
135 lines (112 loc) · 4.82 KB
/
highlevel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import Base: A_mul_B!, At_mul_B!, A_mul_Bt!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt!
# Return the `(rows, cols)` pair of `M` as seen through the transpose flag `t`:
# `'N'` reports `(size(M, 1), size(M, 2))`, any other flag swaps the two.
#
# The dimension choice is written as a spaced ternary over tuples; the former
# `t=='N' ? 1:2` spelling is easily misread as a range and is rejected by newer
# Julia parsers, which require whitespace around `?` and `:`.
function cublas_size(t::Char, M::CuVecOrMat)
    rowdim, coldim = t == 'N' ? (1, 2) : (2, 1)
    return (size(M, rowdim), size(M, coldim))
end
# Alias for a `CuArray` whose element type `T` is restricted to `CublasFloat`.
CublasArray{T<:CublasFloat} = CuArray{T}
###########
#
# BLAS 1
#
###########
# Scale `x` in place by `k`. The factor is converted to the element type of
# `x` and the work is handed to `scal!` over all `length(x)` entries with an
# increment of 1.
function Base.scale!(x::CuArray{T}, k::Number) where T<:CublasFloat
    factor = convert(eltype(x), k)
    return scal!(length(x), factor, x, 1)
end
# Work around ambiguity with GPUArrays wrapper: `Real` factors are forwarded
# explicitly to the `Number` method of `scale!` for this array type.
function Base.scale!(x::CuArray{T}, k::Real) where T<:CublasFloat
    # `invoke` expects the target signature as a tuple *type*; the
    # tuple-of-types spelling `(typeof(x), Number)` is deprecated.
    return invoke(scale!, Tuple{typeof(x), Number}, x, k)
end
# `dot` for two `CuArray`s with the same real element type.
# Throws `DimensionMismatch` when the arrays differ in length; otherwise
# forwards to the five-argument `dot` with increments of 1.
function Base.BLAS.dot(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{Float32,Float64}
    len = length(DX)
    if len != length(DY)
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    end
    return dot(len, DX, 1, DY, 1)
end
# `dotc` for two `CuArray`s with the same complex element type.
# Throws `DimensionMismatch` when the arrays differ in length; otherwise
# forwards to the five-argument `dotc` with increments of 1.
function Base.BLAS.dotc(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{Complex64,Complex128}
    len = length(DX)
    if len != length(DY)
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    end
    return dotc(len, DX, 1, DY, 1)
end
# For complex element types `dot` simply forwards to `Base.BLAS.dotc`.
Base.BLAS.dot(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{Complex64,Complex128} =
    Base.BLAS.dotc(DX, DY)
# `dotu` for two `CuArray`s with the same complex element type.
# Throws `DimensionMismatch` when the arrays differ in length; otherwise
# forwards to the five-argument `dotu` with increments of 1.
function Base.BLAS.dotu(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{Complex64,Complex128}
    len = length(DX)
    if len != length(DY)
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    end
    return dotu(len, DX, 1, DY, 1)
end
# `At_mul_B` of two `CuVector`s with the same `CublasReal` element type is
# forwarded to `Base.BLAS.dot`.
function Base.At_mul_B(x::CuVector{T}, y::CuVector{T}) where T<:CublasReal
    return Base.BLAS.dot(x, y)
end
# `norm` of a `CublasArray` is forwarded to `nrm2`.
# NOTE(review): the signature accepts arrays of any rank, whereas Base's
# `norm` of a matrix is the operator norm — confirm that the `nrm2` result is
# what callers expect for matrices.
function Base.norm(x::CublasArray)
    return nrm2(x)
end
# One-argument `asum` for a `CublasArray`: forwards to the three-argument
# `asum` over all `length(x)` entries with an increment of 1.
function Base.BLAS.asum(x::CublasArray)
    return asum(length(x), x, 1)
end
# `axpy!` for two `CuArray`s with the same CUBLAS element type.
# `alpha` is converted to the element type `T` and the call is forwarded to
# the strided `axpy!` method over all entries with increments of 1.
#
# Throws `DimensionMismatch` when `x` and `y` differ in length; the message
# now reports both lengths (it used to be an empty string).
function Base.axpy!(alpha::Number, x::CuArray{T}, y::CuArray{T}) where T<:CublasFloat
    n = length(x)
    if n != length(y)
        throw(DimensionMismatch("axpy! arguments have lengths $n and $(length(y))"))
    end
    return axpy!(n, convert(T, alpha), x, 1, y, 1)
end
# `indmin` for real `CublasArray`s is forwarded to `iamin`.
# NOTE(review): BLAS `amin` routines conventionally select by absolute value —
# confirm this matches `indmin` for arrays that contain negative entries.
function Base.indmin(xs::CublasArray{T}) where T <: CublasReal
    return iamin(xs)
end
# `indmax` for real `CublasArray`s is forwarded to `iamax`.
# NOTE(review): BLAS `amax` routines conventionally select by absolute value —
# confirm this matches `indmax` for arrays that contain negative entries.
function Base.indmax(xs::CublasArray{T}) where T <: CublasReal
    return iamax(xs)
end
############
#
# BLAS 2
#
############
#########
# GEMV
##########
# Dimension-checked front end for `gemv!`.
#
# Arguments:
#   y      - output vector, updated in place and returned
#   tA     - transpose flag for `A`, interpreted by `cublas_size` and `gemv!`
#   A, x   - matrix and vector operands
#   alpha  - scalar passed to `gemv!` (default `one(T)`)
#   beta   - scalar passed to `gemv!` (default `zero(T)`)
#
# Throws `DimensionMismatch` when the (possibly transposed) shape of `A` does
# not agree with `length(x)` or `length(y)`.
function gemv_wrapper!(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
                       alpha = one(T), beta = zero(T)) where T<:CublasFloat
    mA, nA = cublas_size(tA, A)

    if nA != length(x)
        throw(DimensionMismatch("second dimension of A, $nA, does not match length of x, $(length(x))"))
    end

    if mA != length(y)
        throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))
    end

    # Empty output: there is nothing to compute.
    if mA == 0
        return y
    end

    # With no columns in op(A) the product term vanishes, so under the BLAS
    # convention y <- alpha*op(A)*x + beta*y only `beta * y` remains. This was
    # previously hard-coded to a scale by 0, which ignored a caller-supplied
    # `beta`; for the default `beta` the result is unchanged.
    if nA == 0
        return scale!(y, beta)
    end

    return gemv!(tA, alpha, A, x, beta, y)
end
# Matrix-vector products. Each method only selects the transpose flag for `A`
# and defers to `gemv_wrapper!`, leaving `alpha`/`beta` at their defaults.
A_mul_B!(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) where T<:CublasFloat =
    gemv_wrapper!(y, 'N', A, x)
At_mul_B!(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) where T<:CublasFloat =
    gemv_wrapper!(y, 'T', A, x)
# The general `CublasFloat` adjoint uses the plain transpose flag; the
# `CublasComplex` method below uses `'C'` instead.
# NOTE(review): presumably `CublasComplex` is a subset of `CublasFloat`, so the
# second method wins for complex element types — confirm against the type
# definitions.
Ac_mul_B!(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) where T<:CublasFloat =
    gemv_wrapper!(y, 'T', A, x)
Ac_mul_B!(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T}) where T<:CublasComplex =
    gemv_wrapper!(y, 'C', A, x)
############
#
# BLAS 3
#
############
########
# GEMM
########
# Dimension-checked front end for `gemm!`.
#
# Arguments:
#   C       - output array, updated in place and returned
#   tA, tB  - transpose flags for `A` and `B`, interpreted by `cublas_size`
#             and `gemm!`
#   A, B    - operands
#   alpha   - scalar passed to `gemm!` (default `one(T)`)
#   beta    - scalar passed to `gemm!` (default `zero(T)`)
#
# Throws `DimensionMismatch` when the inner dimensions of the (possibly
# transposed) operands disagree, or when a degenerate product does not match
# the shape of `C`. Throws `ArgumentError` when `C` is the same object as `A`
# or `B`.
function gemm_wrapper!(C::CuVecOrMat{T}, tA::Char, tB::Char,
                       A::CuVecOrMat{T},
                       B::CuVecOrMat{T},
                       alpha = one(T),
                       beta = zero(T)) where T <: CublasFloat
    mA, nA = cublas_size(tA, A)
    mB, nB = cublas_size(tB, B)

    if nA != mB
        throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
    end

    if C === A || B === C
        throw(ArgumentError("output matrix must not be aliased with input matrix"))
    end

    # Degenerate products are handled here without calling `gemm!`.
    if mA == 0 || nA == 0 || nB == 0
        # Compare dimension by dimension: `size(C)` of a vector is a 1-tuple
        # and could never equal `(mA, nB)`, even for a valid `(mA, 1)` result.
        if (size(C, 1), size(C, 2)) != (mA, nB)
            throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
        end
        # The product term is empty, so under the BLAS convention
        # C <- alpha*op(A)*op(B) + beta*C only `beta * C` remains. This was
        # previously a scale by 0 regardless of `beta`; the default `beta`
        # gives the same result as before.
        return scale!(C, beta)
    end

    return gemm!(tA, tB, alpha, A, B, beta, C)
end
# Mutating
# Matrix-matrix products. Each method only selects the transpose flags and
# defers to `gemm_wrapper!`, leaving `alpha`/`beta` at their defaults.
A_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where T<:CublasFloat =
    gemm_wrapper!(C, 'N', 'N', A, B)

# Transposed variants. These signatures place no constraint on the element
# type; `gemm_wrapper!` itself only accepts matching `CublasFloat` operands.
At_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'T', 'N', A, B)
A_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'N', 'T', A, B)
At_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'T', 'T', A, B)

# Adjoint variants: `CublasReal` element types reuse the transposed methods,
# every other case passes the `'C'` flag.
Ac_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where T<:CublasReal =
    At_mul_B!(C, A, B)
Ac_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'C', 'N', A, B)
A_mul_Bc!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where T<:CublasReal =
    A_mul_Bt!(C, A, B)
A_mul_Bc!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'N', 'C', A, B)
Ac_mul_Bc!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where T<:CublasReal =
    At_mul_Bt!(C, A, B)
Ac_mul_Bc!(C::CuMatrix, A::CuMatrix, B::CuMatrix) =
    gemm_wrapper!(C, 'C', 'C', A, B)
# Untransposed product into a matrix `C` where either operand may be a vector
# or a matrix; forwarded to `gemm_wrapper!` with both flags set to `'N'`.
# The element type is unconstrained here; `gemm_wrapper!` only accepts
# `CublasFloat` element types.
A_mul_B!(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}) where T =
    gemm_wrapper!(C, 'N', 'N', A, B)