forked from dingodb/dingo-store
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat][index] Add SIMD CPU instruction set support. Automatically select the CPU
instruction set according to the running environment.
- Loading branch information
Showing
19 changed files
with
2,271 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule faiss
updated
37 files
Submodule hnswlib
updated
8 files
+1 −1 | README.md | |
+17 −7 | hnswlib/bruteforce.h | |
+1 −0 | hnswlib/hnswalg.h | |
+1 −0 | hnswlib/hnswlib.h | |
+85 −0 | hnswlib/hnswlibhook.h | |
+24 −7 | hnswlib/space_ip.h | |
+23 −7 | hnswlib/space_l2.h | |
+10 −33 | python_bindings/bindings.cpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Copyright (C) 2019-2023 Zilliz. All rights reserved. | ||
|
||
#if defined(__x86_64__) | ||
|
||
#include "simd/distances_avx.h" | ||
|
||
#include <immintrin.h> | ||
|
||
#include <cassert> | ||
|
||
namespace dingodb { | ||
|
||
// Declares the following variable with x-byte alignment using the GCC/Clang | ||
// __attribute__((aligned)) syntax. | ||
#define ALIGNED(x) __attribute__((aligned(x))) | ||
|
||
// Reads the first d floats of x (0 <= d < 4) into the low lanes of an __m128; | ||
// all remaining lanes are zero. | ||
// d: number of floats to copy from x; the range is checked only by assert. | ||
// x: source floats; only x[0] .. x[d-1] are accessed. | ||
// Returns the zero-padded 4-lane vector. | ||
static inline __m128 masked_read(int d, const float* x) { | ||
assert(0 <= d && d < 4); | ||
// Zero-filled, 16-byte aligned staging buffer so the aligned load below is valid. | ||
ALIGNED(16) float buf[4] = {0, 0, 0, 0}; | ||
// Cases cascade on purpose: d == 3 copies elements 2, 1, 0; d == 0 matches no | ||
// case and leaves the buffer all zero. | ||
switch (d) { | ||
case 3: | ||
buf[2] = x[2]; | ||
// fall through | ||
case 2: | ||
buf[1] = x[1]; | ||
// fall through | ||
case 1: | ||
buf[0] = x[0]; | ||
} | ||
return _mm_load_ps(buf); | ||
// cannot use AVX2 _mm_mask_set1_epi32 | ||
} | ||
|
||
// Inner product of x and y: returns the sum of x[i] * y[i] for i in [0, d). | ||
// x, y: input vectors, each read for d floats using unaligned loads. | ||
// d: number of floats to process; d == 0 returns 0. | ||
float fvec_inner_product_avx(const float* x, const float* y, size_t d) { | ||
__m256 msum1 = _mm256_setzero_ps(); | ||
|
||
// Main loop: accumulate 8 products per iteration in a 256-bit register. | ||
while (d >= 8) { | ||
__m256 mx = _mm256_loadu_ps(x); | ||
x += 8; | ||
__m256 my = _mm256_loadu_ps(y); | ||
y += 8; | ||
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my)); | ||
d -= 8; | ||
} | ||
|
||
// Fold the upper and lower 128-bit halves of the accumulator into 4 lanes. | ||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1); | ||
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); | ||
|
||
// At most one 4-float step can remain after the 8-wide loop. | ||
if (d >= 4) { | ||
__m128 mx = _mm_loadu_ps(x); | ||
x += 4; | ||
__m128 my = _mm_loadu_ps(y); | ||
y += 4; | ||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my)); | ||
d -= 4; | ||
} | ||
|
||
// Last 1-3 floats: masked_read zero-pads, so the unused lanes add 0 * 0 = 0. | ||
if (d > 0) { | ||
__m128 mx = masked_read(d, x); | ||
__m128 my = masked_read(d, y); | ||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my)); | ||
} | ||
|
||
// Two horizontal adds leave the sum of all four lanes in lane 0. | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
return _mm_cvtss_f32(msum2); | ||
} | ||
|
||
// Squared L2 distance: returns the sum of (x[i] - y[i])^2 for i in [0, d). | ||
// x, y: input vectors, each read for d floats using unaligned loads. | ||
// d: number of floats to process; d == 0 returns 0. | ||
float fvec_L2sqr_avx(const float* x, const float* y, size_t d) { | ||
__m256 msum1 = _mm256_setzero_ps(); | ||
|
||
// Main loop: accumulate 8 squared differences per iteration. | ||
while (d >= 8) { | ||
__m256 mx = _mm256_loadu_ps(x); | ||
x += 8; | ||
__m256 my = _mm256_loadu_ps(y); | ||
y += 8; | ||
const __m256 a_m_b1 = _mm256_sub_ps(mx, my); | ||
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1)); | ||
d -= 8; | ||
} | ||
|
||
// Fold the upper and lower 128-bit halves of the accumulator into 4 lanes. | ||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1); | ||
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); | ||
|
||
// At most one 4-float step can remain after the 8-wide loop. | ||
if (d >= 4) { | ||
__m128 mx = _mm_loadu_ps(x); | ||
x += 4; | ||
__m128 my = _mm_loadu_ps(y); | ||
y += 4; | ||
const __m128 a_m_b1 = _mm_sub_ps(mx, my); | ||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1)); | ||
d -= 4; | ||
} | ||
|
||
// Last 1-3 floats: masked_read zero-pads both inputs, so the unused lanes | ||
// contribute (0 - 0)^2 = 0. | ||
if (d > 0) { | ||
__m128 mx = masked_read(d, x); | ||
__m128 my = masked_read(d, y); | ||
__m128 a_m_b1 = _mm_sub_ps(mx, my); | ||
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1)); | ||
} | ||
|
||
// Two horizontal adds leave the sum of all four lanes in lane 0. | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
return _mm_cvtss_f32(msum2); | ||
} | ||
|
||
// L1 distance: returns the sum of |x[i] - y[i]| for i in [0, d). | ||
// x, y: input vectors, each read for d floats using unaligned loads. | ||
// d: number of floats to process; d == 0 returns 0. | ||
float fvec_L1_avx(const float* x, const float* y, size_t d) { | ||
__m256 msum1 = _mm256_setzero_ps(); | ||
// 0x7fffffff keeps every bit except the float sign bit, so AND-ing a lane | ||
// with this mask yields its absolute value. | ||
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); | ||
|
||
// Main loop: accumulate 8 absolute differences per iteration. | ||
while (d >= 8) { | ||
__m256 mx = _mm256_loadu_ps(x); | ||
x += 8; | ||
__m256 my = _mm256_loadu_ps(y); | ||
y += 8; | ||
const __m256 a_m_b = _mm256_sub_ps(mx, my); | ||
msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); | ||
d -= 8; | ||
} | ||
|
||
// Fold the upper and lower 128-bit halves of the accumulator into 4 lanes. | ||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1); | ||
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); | ||
// 128-bit version of the same sign-clearing mask for the tail steps. | ||
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); | ||
|
||
// At most one 4-float step can remain after the 8-wide loop. | ||
if (d >= 4) { | ||
__m128 mx = _mm_loadu_ps(x); | ||
x += 4; | ||
__m128 my = _mm_loadu_ps(y); | ||
y += 4; | ||
const __m128 a_m_b = _mm_sub_ps(mx, my); | ||
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); | ||
d -= 4; | ||
} | ||
|
||
// Last 1-3 floats: masked_read zero-pads both inputs, so the unused lanes | ||
// contribute |0 - 0| = 0. | ||
if (d > 0) { | ||
__m128 mx = masked_read(d, x); | ||
__m128 my = masked_read(d, y); | ||
__m128 a_m_b = _mm_sub_ps(mx, my); | ||
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); | ||
} | ||
|
||
// Two horizontal adds leave the sum of all four lanes in lane 0. | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
msum2 = _mm_hadd_ps(msum2, msum2); | ||
return _mm_cvtss_f32(msum2); | ||
} | ||
|
||
// L-infinity distance: returns the maximum of |x[i] - y[i]| for i in [0, d). | ||
// x, y: input vectors, each read for d floats using unaligned loads. | ||
// d: number of floats to process; d == 0 returns 0. | ||
float fvec_Linf_avx(const float* x, const float* y, size_t d) { | ||
// The running maximum starts at 0; every candidate below is an absolute value. | ||
__m256 msum1 = _mm256_setzero_ps(); | ||
// 0x7fffffff keeps every bit except the float sign bit, so AND-ing a lane | ||
// with this mask yields its absolute value. | ||
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); | ||
|
||
// Main loop: track the lane-wise maximum of 8 absolute differences per iteration. | ||
while (d >= 8) { | ||
__m256 mx = _mm256_loadu_ps(x); | ||
x += 8; | ||
__m256 my = _mm256_loadu_ps(y); | ||
y += 8; | ||
const __m256 a_m_b = _mm256_sub_ps(mx, my); | ||
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); | ||
d -= 8; | ||
} | ||
|
||
// Combine the upper and lower 128-bit halves with a lane-wise max. | ||
__m128 msum2 = _mm256_extractf128_ps(msum1, 1); | ||
msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); | ||
// 128-bit version of the same sign-clearing mask for the tail steps. | ||
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); | ||
|
||
// At most one 4-float step can remain after the 8-wide loop. | ||
if (d >= 4) { | ||
__m128 mx = _mm_loadu_ps(x); | ||
x += 4; | ||
__m128 my = _mm_loadu_ps(y); | ||
y += 4; | ||
const __m128 a_m_b = _mm_sub_ps(mx, my); | ||
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); | ||
d -= 4; | ||
} | ||
|
||
// Last 1-3 floats: masked_read zero-pads both inputs, so the unused lanes | ||
// only offer |0 - 0| = 0 to the max. | ||
if (d > 0) { | ||
__m128 mx = masked_read(d, x); | ||
__m128 my = masked_read(d, y); | ||
__m128 a_m_b = _mm_sub_ps(mx, my); | ||
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); | ||
} | ||
|
||
// Horizontal max: first combine lanes {0,1} with lanes {2,3}, then lane 0 | ||
// with lane 1, leaving the overall maximum in lane 0. | ||
msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); | ||
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); | ||
return _mm_cvtss_f32(msum2); | ||
} | ||
|
||
} // namespace dingodb | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Copyright (C) 2019-2023 Zilliz. All rights reserved. | ||
|
||
#ifndef DINGODB_SIMD_DISTANCES_AVX_H_ | ||
#define DINGODB_SIMD_DISTANCES_AVX_H_ | ||
|
||
#include <cstddef> | ||
#include <cstdint> | ||
|
||
namespace dingodb { | ||
|
||
/// Squared L2 distance between two vectors: sum of (x[i] - y[i])^2. | ||
/// x, y: input vectors, each read for d floats. | ||
/// d: number of floats to process. | ||
float fvec_L2sqr_avx(const float* x, const float* y, size_t d); | ||
|
||
/// Inner product of two vectors: sum of x[i] * y[i]. | ||
/// x, y: input vectors, each read for d floats. | ||
/// d: number of floats to process. | ||
float fvec_inner_product_avx(const float* x, const float* y, size_t d); | ||
|
||
/// L1 distance between two vectors: sum of |x[i] - y[i]|. | ||
/// x, y: input vectors, each read for d floats. | ||
/// d: number of floats to process. | ||
float fvec_L1_avx(const float* x, const float* y, size_t d); | ||
|
||
/// L-infinity distance between two vectors: maximum of |x[i] - y[i]|. | ||
/// x, y: input vectors, each read for d floats. | ||
/// d: number of floats to process. | ||
float fvec_Linf_avx(const float* x, const float* y, size_t d); | ||
|
||
} // namespace dingodb | ||
|
||
#endif // DINGODB_SIMD_DISTANCES_AVX_H_ //NOLINT |
Oops, something went wrong.