forked from codeplaysoftware/portBLAS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
blas3_interface.h
86 lines (82 loc) · 4.26 KB
/
blas3_interface.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/***************************************************************************
*
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SYCL-BLAS: BLAS implementation using SYCL
*
* @filename blas3_interface.h
*
**************************************************************************/
#ifndef SYCL_BLAS_BLAS3_INTERFACE_H
#define SYCL_BLAS_BLAS3_INTERFACE_H
namespace blas {
namespace internal {
/*!
 * @brief This is a top-level wrapper for GemmFactory, which provides a
 * "standard" BLAS gemm interface.
 *
 * See the netlib blas interface documentation for more details of the high
 * level interface:
 * http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
 *
 * Only declared here; the definition lives elsewhere (not visible in this
 * header). Parameters have the same meaning as in blas::_gemm below, except
 * that a, b and c are the buffers already obtained from the policy handler.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm(executor_t& ex, char trans_a,
                                             char trans_b, index_t m,
                                             index_t n, index_t k,
                                             element_t alpha, container_0_t a,
                                             index_t lda, container_1_t b,
                                             index_t ldb, element_t beta,
                                             container_2_t c, index_t ldc);

/*!
 * @brief Batched counterpart of internal::_gemm.
 *
 * Only declared here; the definition lives elsewhere (not visible in this
 * header). Parameters have the same meaning as in blas::_gemm_batched below,
 * except that a, b and c are the buffers already obtained from the policy
 * handler.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm_batched(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda,
    container_1_t b, index_t ldb, element_t beta, container_2_t c,
    index_t ldc, index_t batch_size);
}  // namespace internal

/*!
 * @brief General matrix-matrix product with the netlib BLAS xGEMM interface:
 * C := alpha * op(A) * op(B) + beta * C.
 *
 * Each container argument is turned into a buffer with
 * ex.get_policy_handler().get_buffer(...). All other arguments are forwarded
 * unchanged to internal::_gemm, which does the actual work.
 *
 * Parameter names deliberately avoid a leading underscore followed by an
 * uppercase letter (e.g. _M, _C): such identifiers are reserved to the
 * implementation in every scope.
 *
 * @param ex       Executor; its policy handler supplies the buffers.
 * @param trans_a  Transpose flag for A (netlib convention, e.g. 'n', 't',
 *                 'c'); interpreted by internal::_gemm.
 * @param trans_b  Transpose flag for B (same convention as trans_a).
 * @param m        Number of rows of op(A) and of C.
 * @param n        Number of columns of op(B) and of C.
 * @param k        Number of columns of op(A) and rows of op(B).
 * @param alpha    Scalar multiplier applied to op(A) * op(B).
 * @param a        Container holding matrix A.
 * @param lda      Leading dimension of A.
 * @param b        Container holding matrix B.
 * @param ldb      Leading dimension of B.
 * @param beta     Scalar multiplier applied to the incoming C.
 * @param c        Container holding matrix C (per netlib, updated in place).
 * @param ldc      Leading dimension of C.
 * @return The event produced by internal::_gemm.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm(executor_t& ex, char trans_a,
                                             char trans_b, index_t m,
                                             index_t n, index_t k,
                                             element_t alpha, container_0_t a,
                                             index_t lda, container_1_t b,
                                             index_t ldb, element_t beta,
                                             container_2_t c, index_t ldc) {
  return internal::_gemm(ex, trans_a, trans_b, m, n, k, alpha,
                         ex.get_policy_handler().get_buffer(a), lda,
                         ex.get_policy_handler().get_buffer(b), ldb, beta,
                         ex.get_policy_handler().get_buffer(c), ldc);
}

/*!
 * @brief Runs batch_size general matrix-matrix products with the same
 * shapes, scalars and transpose flags.
 *
 * Each container argument is turned into a buffer with
 * ex.get_policy_handler().get_buffer(...). All other arguments are forwarded
 * unchanged to internal::_gemm_batched.
 *
 * NOTE(review): how the batch_size matrices are laid out inside a, b and c
 * (e.g. the stride between consecutive matrices) is decided by
 * internal::_gemm_batched, which is not visible here. Presumably they are
 * packed back to back; confirm before relying on a specific layout.
 *
 * @param ex          Executor; its policy handler supplies the buffers.
 * @param trans_a     Transpose flag for every A (netlib convention).
 * @param trans_b     Transpose flag for every B (netlib convention).
 * @param m           Rows of each op(A) and each C.
 * @param n           Columns of each op(B) and each C.
 * @param k           Columns of each op(A) and rows of each op(B).
 * @param alpha       Scalar multiplier applied to each op(A) * op(B).
 * @param a           Container holding the A matrices.
 * @param lda         Leading dimension of each A.
 * @param b           Container holding the B matrices.
 * @param ldb         Leading dimension of each B.
 * @param beta        Scalar multiplier applied to each incoming C.
 * @param c           Container holding the C matrices.
 * @param ldc         Leading dimension of each C.
 * @param batch_size  Number of GEMMs in the batch.
 * @return The event produced by internal::_gemm_batched.
 */
template <typename executor_t, typename container_0_t, typename container_1_t,
          typename container_2_t, typename element_t, typename index_t>
typename executor_t::policy_t::event_t _gemm_batched(
    executor_t& ex, char trans_a, char trans_b, index_t m, index_t n,
    index_t k, element_t alpha, container_0_t a, index_t lda,
    container_1_t b, index_t ldb, element_t beta, container_2_t c,
    index_t ldc, index_t batch_size) {
  return internal::_gemm_batched(ex, trans_a, trans_b, m, n, k, alpha,
                                 ex.get_policy_handler().get_buffer(a), lda,
                                 ex.get_policy_handler().get_buffer(b), ldb,
                                 beta, ex.get_policy_handler().get_buffer(c),
                                 ldc, batch_size);
}
}  // namespace blas
#endif  // SYCL_BLAS_BLAS3_INTERFACE_H