From 364a413942c7a80ce1abbe075dea4664b71680f8 Mon Sep 17 00:00:00 2001 From: Jamie Date: Fri, 16 Aug 2024 13:02:41 +0800 Subject: [PATCH] repo-sync-2024-08-16T11:45:41+0800 (#373) --- ALGORITHMS.md | 41 +-- README.md | 2 +- bazel/openssl.BUILD | 1 + examples/psu/BUILD.bazel | 2 +- examples/psu/krtw19_psu.cc | 11 +- examples/psu/krtw19_psu.h | 2 +- yacl/crypto/experimental/dpf/BUILD.bazel | 52 +++ yacl/crypto/experimental/dpf/dcf.cc | 267 ++++++++++++++ yacl/crypto/experimental/dpf/dcf.h | 113 ++++++ yacl/crypto/experimental/dpf/dcf_test.cc | 86 +++++ yacl/crypto/experimental/dpf/dpf.cc | 290 ++++++++------- yacl/crypto/experimental/dpf/dpf.h | 200 ++++------- yacl/crypto/experimental/dpf/dpf_test.cc | 169 +++------ yacl/crypto/experimental/dpf/ge2n.h | 104 ++++++ yacl/crypto/experimental/dpf/pprf.cc | 155 ++++++++ yacl/crypto/experimental/dpf/pprf.h | 70 ++++ yacl/crypto/experimental/dpf/pprf_test.cc | 49 +++ yacl/crypto/rand/drbg/native_factory.cc | 3 + yacl/crypto/rand/drbg/native_factory.h | 2 - yacl/crypto/rand/drbg/openssl_factory.cc | 3 + yacl/crypto/rand/drbg/openssl_factory.h | 2 - yacl/crypto/rand/rand.cc | 12 +- yacl/crypto/rand/rand.h | 1 + yacl/io/circuit/bristol_fashion.cc | 9 +- yacl/kernel/algorithms/BUILD.bazel | 22 +- yacl/kernel/algorithms/base_vole.h | 10 +- yacl/kernel/algorithms/base_vole_test.cc | 2 +- yacl/kernel/algorithms/ferret_ote_rn.h | 6 +- yacl/kernel/algorithms/kos_ote.cc | 8 +- yacl/kernel/algorithms/mp_vole.h | 2 +- yacl/kernel/algorithms/mp_vole_test.cc | 2 +- yacl/kernel/algorithms/mpfss.cc | 2 +- yacl/kernel/algorithms/mpfss_test.cc | 2 +- yacl/kernel/algorithms/silent_vole_test.cc | 2 +- yacl/kernel/algorithms/softspoken_ote.cc | 26 +- yacl/math/f2k/BUILD.bazel | 54 --- yacl/math/f2k/f2k.h | 317 ----------------- yacl/math/f2k/f2k_bench.cc | 221 ------------ yacl/math/f2k/f2k_utils.h | 107 ------ yacl/math/galois_field/BUILD.bazel | 31 +- yacl/math/galois_field/factory/gf_vector.h | 46 +-- yacl/math/galois_field/gf_intrinsic.cc | 332 ++++++++++++++++++ yacl/math/galois_field/gf_intrinsic.h | 190 ++++++++++ .../gf_intrinsic_test.cc} | 119 +++---- yacl/utils/serializer_adapter.h | 2 + yacl/utils/spi/argument/arg_k.h | 5 +- yacl/utils/spi/argument/arg_kv.h | 5 +- 47 files changed, 1891 insertions(+), 1268 deletions(-) create mode 100644 yacl/crypto/experimental/dpf/dcf.cc create mode 100644 yacl/crypto/experimental/dpf/dcf.h create mode 100644 yacl/crypto/experimental/dpf/dcf_test.cc create mode 100644 yacl/crypto/experimental/dpf/ge2n.h create mode 100644 yacl/crypto/experimental/dpf/pprf.cc create mode 100644 yacl/crypto/experimental/dpf/pprf.h create mode 100644 yacl/crypto/experimental/dpf/pprf_test.cc delete mode 100644 yacl/math/f2k/BUILD.bazel delete mode 100644 yacl/math/f2k/f2k.h delete mode 100644 yacl/math/f2k/f2k_bench.cc delete mode 100644 yacl/math/f2k/f2k_utils.h create mode 100644 yacl/math/galois_field/gf_intrinsic.cc create mode 100644 yacl/math/galois_field/gf_intrinsic.h rename yacl/math/{f2k/f2k_test.cc => galois_field/gf_intrinsic_test.cc} (51%) diff --git a/ALGORITHMS.md b/ALGORITHMS.md index c2fa4b6f..02f7d387 100644 --- a/ALGORITHMS.md +++ b/ALGORITHMS.md @@ -1,8 +1,6 @@ -# Supported Crypto Algorithms +# Supported Advanced Crypto Algorithms -## Primitives - -### Oblivious Transfer and Extensions +## Oblivious Transfer and Extensions - The Simplest Protocol for Oblivious Transfer\ *Tung Chou, Claudio Orlandi*\ @@ -36,7 +34,7 @@ *Lawrence Roy*\ Crypto 2022, [publisher](https://www.iacr.org/cryptodb//data/paper.php?pubkey=32258), Roy22 -### Vector Oblivious Linear Evaluation (over Field 2k) +## Vector Oblivious Linear Evaluation (over Field 2k) Base VOLE: @@ -58,8 +56,13 @@ Silent VOLE: *Elette Boyle, Geoffroy Couteau, Niv Gilboa, Yuval Ishai, Lisa Kohl, Nicolas Resch, Peter Scholl*\ Crypto 2022, [eprint](https://eprint.iacr.org/2022/1014), BCG+22 +Subfield VOLE: + +- Wolverine: Fast, Scalable, and Communication-Efficient Zero-Knowledge Proofs for Boolean and Arithmetic Circuits\ + *Chenkai Weng, Kang Yang, Jonathan Katz, Xiao Wang* + SP, 2021, [eprint](https://eprint.iacr.org/2020/925), WYKW21 -### Codes +## Codes Local Linear Code @@ -80,19 +83,13 @@ Expanding Accumulation Code Crypto 2022, [eprint](https://eprint.iacr.org/2022/1014), BCG+22 -## Theoretical Tools - -Random Oracle (RO) - -- TBD +## Distributed Point Functions -Random Permutation (RP) +- Function secret sharing: improvements and extensions\ + *Elette Boyle, Niv Gilboa, Yuval Ishai*\ + CCS 2016, [eprint](https://eprint.iacr.org/2018/707), BGI16 -- TBD - -Pseudorandom Generator (PRG) - -- TBD +## Theoretical Tools Correlation-Robust Hash (CrHash) @@ -106,13 +103,3 @@ Circular Correlation-Robust Hash (CCR Hash) *Chun Guo, Jonathan Katz, Xiao Wang, Yu Yu*\ Preprint 2019, [eprint](https://eprint.iacr.org/2019/074), GKWY19 -## Basic (Traditional) algorithms (TBD) - -- AEAD -- AES -- Block Cipher -- ECC -- Hash -- HMAC -- Public-Key Encryption: RSA, SM2 -- Digital Signature: RSA, SM2 diff --git a/README.md b/README.md index 9ae0c02a..4c6e0411 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Supported platforms: ## Getting Started -Yacl uses the [bazel](https://bazel.build/) build system, you may use the following codes to build and test yacl modules. For more guidelines about how to **do crypto research on Yacl**, **use Yacl's shipped crypto tools**, or **integrate Yacl into your system**, please check the [Getting Started Guide](GETTING_STARTED.md). +Yacl uses the [bazel](https://bazel.build/) build system, you may use the following codes to build and test yacl modules. For more guidelines about **how to develop on yacl**, please check the [Getting Started Guide](GETTING_STARTED.md). ## License diff --git a/bazel/openssl.BUILD b/bazel/openssl.BUILD index 37442510..15210084 100644 --- a/bazel/openssl.BUILD +++ b/bazel/openssl.BUILD @@ -59,6 +59,7 @@ yacl_configure_make( }), lib_name = "openssl", lib_source = ":all_srcs", + linkopts = ["-ldl"], # Note that for Linux builds, libssl must come before libcrypto on the linker command-line. # As such, libssl must be listed before libcrypto out_static_libs = [ diff --git a/examples/psu/BUILD.bazel b/examples/psu/BUILD.bazel index 64954c26..fa40f93f 100644 --- a/examples/psu/BUILD.bazel +++ b/examples/psu/BUILD.bazel @@ -34,7 +34,7 @@ yacl_cc_library( "//yacl/kernel/algorithms:kkrt_ote", "//yacl/kernel/algorithms:softspoken_ote", "//yacl/link", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "@com_google_absl//absl/types:span", ], ) diff --git a/examples/psu/krtw19_psu.cc b/examples/psu/krtw19_psu.cc index 8dec7211..19817cc1 100644 --- a/examples/psu/krtw19_psu.cc +++ b/examples/psu/krtw19_psu.cc @@ -55,7 +55,7 @@ auto HashInputs(const std::vector& elem_hashes, size_t count) { uint64_t Evaluate(const std::vector& coeffs, uint64_t x) { uint64_t y = coeffs.back(); for (auto it = std::next(coeffs.rbegin()); it != coeffs.rend(); ++it) { - y = yacl::GfMul64(y, x) ^ *it; + y = yacl::math::Gf64Mul(y, x) ^ *it; } return y; } @@ -71,7 +71,7 @@ std::vector Interpolate(const std::vector& xs, for (size_t j = 0; j < size; ++j) { uint64_t sum = 0; for (size_t k = 0; k <= j + 1; ++k) { - sum = std::exchange(poly[k], yacl::GfMul64(poly[k], xs[j]) ^ sum); + sum = std::exchange(poly[k], yacl::math::Gf64Mul(poly[k], xs[j]) ^ sum); } } @@ -83,13 +83,14 @@ std::vector Interpolate(const std::vector& xs, uint64_t xi = xs[i]; subpoly[size - 1] = 1; for (int32_t k = size - 2; k >= 0; --k) { - subpoly[k] = poly[k + 1] ^ yacl::GfMul64(subpoly[k + 1], xi); + subpoly[k] = poly[k + 1] ^ yacl::math::Gf64Mul(subpoly[k + 1], xi); } - auto prod = yacl::GfMul64(ys[i], yacl::GfInv64(Evaluate(subpoly, xi))); + auto prod = + yacl::math::Gf64Mul(ys[i], yacl::math::Gf64Inv(Evaluate(subpoly, xi))); // update coeff for (size_t k = 0; k < size; ++k) { - coeffs[k] = coeffs[k] ^ yacl::GfMul64(subpoly[k], prod); + coeffs[k] = coeffs[k] ^ yacl::math::Gf64Mul(subpoly[k], prod); } } diff --git a/examples/psu/krtw19_psu.h b/examples/psu/krtw19_psu.h index 7186e69f..4c02f4ba 100644 --- a/examples/psu/krtw19_psu.h +++ b/examples/psu/krtw19_psu.h @@ -19,7 +19,7 @@ #include "yacl/base/int128.h" #include "yacl/link/link.h" -#include "yacl/math/f2k/f2k.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include "yacl/secparam.h" /* submodules */ diff --git a/yacl/crypto/experimental/dpf/BUILD.bazel b/yacl/crypto/experimental/dpf/BUILD.bazel index 6bafccff..e0a14388 100644 --- a/yacl/crypto/experimental/dpf/BUILD.bazel +++ b/yacl/crypto/experimental/dpf/BUILD.bazel @@ -16,12 +16,23 @@ load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") package(default_visibility = ["//visibility:public"]) +yacl_cc_library( + name = "ge2n", + srcs = ["ge2n.h"], + deps = [ + "//yacl/base:exception", + "//yacl/base:int128", + ], +) + yacl_cc_library( name = "dpf", srcs = ["dpf.cc"], hdrs = ["dpf.h"], deps = [ + ":ge2n", "//yacl/base:int128", + "//yacl/crypto/rand", "//yacl/crypto/tools:prg", "//yacl/link", ], @@ -34,3 +45,44 @@ yacl_cc_test( ":dpf", ], ) + +yacl_cc_library( + name = "dcf", + srcs = ["dcf.cc"], + hdrs = ["dcf.h"], + deps = [ + ":ge2n", + "//yacl/base:int128", + "//yacl/crypto/rand", + "//yacl/crypto/tools:prg", + "//yacl/link", + ], +) + +yacl_cc_test( + name = "dcf_test", + srcs = ["dcf_test.cc"], + deps = [ + ":dcf", + ], +) + +yacl_cc_library( + name = "pprf", + srcs = ["pprf.cc"], + hdrs = ["pprf.h"], + deps = [ + ":ge2n", + "//yacl/base:int128", + "//yacl/crypto/tools:prg", + ], +) + +yacl_cc_test( + name = "pprf_test", + srcs = ["pprf_test.cc"], + deps = [ + ":pprf", + "//yacl/crypto/rand", + ], +) diff --git a/yacl/crypto/experimental/dpf/dcf.cc b/yacl/crypto/experimental/dpf/dcf.cc new file mode 100644 index 00000000..10ba759d --- /dev/null +++ b/yacl/crypto/experimental/dpf/dcf.cc @@ -0,0 +1,267 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/experimental/dpf/dcf.h" + +#include + +#include "yacl/crypto/experimental/dpf/ge2n.h" +#include "yacl/secparam.h" +#include "yacl/utils/serializer.h" +#include "yacl/utils/serializer_adapter.h" + +namespace yacl::crypto { + +namespace { + +template +GE2n DcfPRG(uint128_t seed) { + Prg prng(seed); + return GE2n(prng()); +} + +std::tuple SplitDcfSeed( + uint128_t seed) { + uint128_t seed_left = 0; + uint128_t seed_right = 0; + bool t_left; + bool t_right; + uint128_t v_seed_left; + uint128_t v_seed_right; + + // TODO(@shanzhu.cjm): check if this implementation is secure and efficient + Prg prng(seed); + + seed_left = prng(); + seed_right = prng(); + v_seed_left = prng(); + v_seed_right = prng(); + uint128_t tmp = prng(); + + t_left = tmp >> 1 & 1; + t_right = tmp >> 2 & 1; + + return {seed_left, v_seed_left, t_left, seed_right, v_seed_right, t_right}; +} + +} // namespace + +// ----------------------------------------- +// Full domain key generation and evaluation +// ----------------------------------------- + +template +void DcfKeyGen(DcfKey* first_key, DcfKey* second_key, const GE2n& alpha, + const GE2n& beta, uint128_t first_mk, uint128_t second_mk) { + // enable the early termination + const uint32_t term_level = M; + + // set up the return keys + *first_key = DcfKey(false, first_mk); + *second_key = DcfKey(true, second_mk); + first_key->cws_vec.resize(term_level); + second_key->cws_vec.resize(term_level); + + std::array seeds_working; + seeds_working[0] = first_mk; + seeds_working[1] = second_mk; + + std::array t_working; + t_working[0] = false; // default by definition + t_working[1] = true; // default by definition + + auto v_working = GE2n(0); + + for (uint32_t i = 0; i < term_level; ++i) { + std::array seed_left; + std::array seed_right; + std::array t_left; + std::array t_right; + std::array v_seed_left; + std::array v_seed_right; + + bool alpha_bit = (alpha.GetBit(M - i - 1) != 0U); + + // Use working seed to generate seeds + // Note: this is the most time-consuming process + std::tie(seed_left[0], v_seed_left[0], t_left[0], seed_right[0], + v_seed_right[0], t_right[0]) = SplitDcfSeed(seeds_working[0]); + std::tie(seed_left[1], v_seed_left[1], t_left[1], seed_right[1], + v_seed_right[1], t_right[1]) = SplitDcfSeed(seeds_working[1]); + + const auto keep_seed = alpha_bit ? seed_right : seed_left; + const auto lose_seed = alpha_bit ? seed_left : seed_right; + const auto v_seed_keep = alpha_bit ? v_seed_right : v_seed_left; + const auto v_seed_lose = alpha_bit ? v_seed_left : v_seed_right; + const auto t_keep = alpha_bit ? t_right : t_left; + + bool cw_t_left; + bool cw_t_right; + GE2n cw_v; + + uint128_t cw_seed = lose_seed[0] ^ lose_seed[1]; + + // ----------------------------------------------------- + + GE2n prg_lose_0 = DcfPRG(v_seed_lose.at(0)); + GE2n prg_lose_1 = DcfPRG(v_seed_lose.at(1)); + GE2n prg_keep_0 = DcfPRG(v_seed_keep.at(0)); + GE2n prg_keep_1 = DcfPRG(v_seed_keep.at(1)); + + cw_v = prg_lose_1 + prg_lose_0.GetReverse() + v_working.GetReverse(); + + if (t_working.at(1)) { + cw_v.ReverseInplace(); + + if (alpha_bit) { + cw_v += beta.GetReverse(); + } + + // update v_working + v_working += prg_keep_1.GetReverse() + prg_keep_0 + cw_v.GetReverse(); + + } else { + if (alpha_bit) { + cw_v += beta; + } + + // update v_working + v_working += prg_keep_1.GetReverse() + prg_keep_0 + cw_v; + } + + // ----------------------------------------------------- + + cw_t_left = t_left[0] ^ t_left[1] ^ alpha_bit ^ 1; + cw_t_right = t_right[0] ^ t_right[1] ^ alpha_bit; + const auto& cw_t_keep = alpha_bit ? cw_t_right : cw_t_left; + + // get the seeds_working and t_working for next level + seeds_working[0] = t_working[0] ? keep_seed[0] ^ cw_seed : keep_seed[0]; + seeds_working[1] = t_working[1] ? keep_seed[1] ^ cw_seed : keep_seed[1]; + + t_working[0] = t_keep[0] ^ t_working[0] * cw_t_keep; + t_working[1] = t_keep[1] ^ t_working[1] * cw_t_keep; + + first_key->cws_vec[i].SetSeed(cw_seed); + first_key->cws_vec[i].SetLT(cw_t_left); + first_key->cws_vec[i].SetRT(cw_t_right); + first_key->cws_vec[i].SetV(cw_v.GetVal()); + } + + // Expand final seed_working + // get the final correlation words (has the same length as seeds) + // notice the notation is `somewhat' incorrect in the original paper + // + // First, we get the Convert(S_0 ^ key_block) and Convert(S_1 ^ key_block) + // + auto prg0 = DcfPRG(seeds_working[0]); + auto prg1 = DcfPRG(seeds_working[1]); + + // if !enable_evalall, we have only one last_cw_vec, otherwise, we + // have multiple last_cw_vec + YACL_ENFORCE(first_key->last_cw_vec.empty()); + YACL_ENFORCE(second_key->last_cw_vec.empty()); + + auto last_cw_ge2n = prg0.GetReverse() + prg1 + v_working.GetReverse(); + if (t_working[1]) { + last_cw_ge2n.ReverseInplace(); + } + + first_key->last_cw_vec.push_back(last_cw_ge2n.GetVal()); + + second_key->cws_vec = first_key->cws_vec; + second_key->last_cw_vec.push_back(first_key->last_cw_vec[0]); +} + +template +void DcfEval(const DcfKey& key, const GE2n& in, GE2n* out) { + uint128_t seed_working = key.GetSeed(); // the initial value + bool t_working = key.GetRank(); // the initial value + *out = GE2n(0); // init the out value + + for (uint32_t i = 0; i < M; ++i) { + const auto cw_seed = key.cws_vec[i].GetSeed(); + const GE2n cw_v(key.cws_vec[i].GetV()); + const auto cw_t_left = key.cws_vec[i].GetLT(); + const auto cw_t_right = key.cws_vec[i].GetRT(); + + uint128_t seed_left; + uint128_t seed_right; + bool t_left; + bool t_right; + uint128_t v_seed_left; + uint128_t v_seed_right; + + std::tie(seed_left, v_seed_left, t_left, seed_right, v_seed_right, + t_right) = SplitDcfSeed(seed_working); + + seed_left = t_working ? seed_left ^ cw_seed : seed_left; + t_left = t_left ^ (t_working * cw_t_left); + seed_right = t_working ? seed_right ^ cw_seed : seed_right; + t_right = t_right ^ (t_working * cw_t_right); + + GE2n prg0 = DcfPRG(v_seed_left); + GE2n prg1 = DcfPRG(v_seed_right); + + if (in.GetBit(M - i - 1) != 0U) { + if (t_working) { + *out += key.GetRank() ? (prg1 + cw_v).GetReverse() : prg1 + cw_v; + } else { + *out += key.GetRank() ? prg1.GetReverse() : prg1; + } + seed_working = seed_right; + t_working = t_right; + } else { + if (t_working) { + *out += key.GetRank() ? (prg0 + cw_v).GetReverse() : prg0 + cw_v; + } else { + *out += key.GetRank() ? prg0.GetReverse() : prg0; + } + seed_working = seed_left; + t_working = t_left; + } + } + + auto prg = DcfPRG(seed_working); + auto tmp = t_working ? prg + GE2n(key.last_cw_vec[0]) : prg; + *out += key.GetRank() ? tmp.GetReverse() : tmp; +} + +// template specification for different M and N +#define DCF_T_SPECIFY_FUNC(M, N) \ + template void DcfKeyGen(DcfKey * first_key, DcfKey * second_key, \ + const GE2n& alpha, const GE2n& beta, \ + uint128_t first_mk, uint128_t second_mk); \ + \ + template void DcfEval(const DcfKey& key, const GE2n& in, \ + GE2n* out); + +DCF_T_SPECIFY_FUNC(64, 64) +DCF_T_SPECIFY_FUNC(32, 64) +DCF_T_SPECIFY_FUNC(16, 64) +DCF_T_SPECIFY_FUNC(8, 64) +DCF_T_SPECIFY_FUNC(4, 64) +DCF_T_SPECIFY_FUNC(2, 64) +DCF_T_SPECIFY_FUNC(1, 64) + +DCF_T_SPECIFY_FUNC(64, 128) +DCF_T_SPECIFY_FUNC(32, 128) +DCF_T_SPECIFY_FUNC(16, 128) +DCF_T_SPECIFY_FUNC(8, 128) +DCF_T_SPECIFY_FUNC(4, 128) +DCF_T_SPECIFY_FUNC(2, 128) +DCF_T_SPECIFY_FUNC(1, 128) + +#undef DCF_T_SPECIFY_FUNC +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/dcf.h b/yacl/crypto/experimental/dpf/dcf.h new file mode 100644 index 00000000..447f362b --- /dev/null +++ b/yacl/crypto/experimental/dpf/dcf.h @@ -0,0 +1,113 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/base/int128.h" +#include "yacl/crypto/experimental/dpf/ge2n.h" + +/* submodules */ +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/prg.h" +#include "yacl/secparam.h" + +YACL_MODULE_DECLARE("dcf", SecParam::C::k128, SecParam::S::INF); + +namespace yacl::crypto { + +// Distributed Point Function (DCF) +// +// For more details, please see: https://eprint.iacr.org/2018/707 +// +class DcfKey { + public: + // Constructors + DcfKey() = default; + + explicit DcfKey(bool rank, const uint128_t mseed = SecureRandSeed()) + : rank_(rank), mseed_(mseed) {} + + // internal type definition + class CW { + public: + CW() = default; + CW(uint128_t seed, uint8_t t_store) : seed_(seed), t_store_(t_store) {} + + uint8_t GetLT() const { return t_store_ & 1; } + uint8_t GetRT() const { return (t_store_ >> 1) & 1; } + + uint128_t GetSeed() const { return seed_; } + uint8_t GetTStore() const { return t_store_; } + uint128_t GetV() const { return this->v_; } + + void SetLT(uint8_t t_left) { + YACL_ENFORCE(t_left == 0 || t_left == 1); + t_store_ = (GetRT() << 1) + t_left; + } + + void SetRT(uint8_t t_right) { + YACL_ENFORCE(t_right == 0 || t_right == 1); + t_store_ = (t_right << 1) + GetLT(); + } + + void SetSeed(uint128_t seed) { seed_ = seed; } + + void SetV(uint128_t v) { this->v_ = v; } + + private: + uint128_t seed_ = 0; // this level's seed, default = 0 + uint128_t v_; + uint8_t t_store_ = 0; // 1st bit=> t_left, 2nd bit=> t_right + }; + + std::vector cws_vec; // correlated words for each level + std::vector last_cw_vec; // the final correlation word + + bool GetRank() const { return rank_; } + void SetRank(bool rank) { rank_ = rank; } + + uint128_t GetSeed() const { return mseed_; } + void SetSeed(uint128_t seed) { mseed_ = seed; } + + private: + bool rank_{}; // only support two parties (0/1), compulsory param + uint128_t mseed_ = 0; // the master seed +}; + +// ---------------------------------------------------------------------------- +// Core Functions of DCF +// ---------------------------------------------------------------------------- +// NOTE: Supported (M, N) parameter pairs are: +// - (M = {8, 16, 32, 64}, N = {8, 16, 32, 64, 128}) +// +template +void DcfKeyGen(DcfKey* first_key, DcfKey* second_key, const GE2n& alpha, + const GE2n& beta, uint128_t first_mk, uint128_t second_mk); + +template +void DcfEval(const DcfKey& key, const GE2n& in, GE2n* out); + +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/dcf_test.cc b/yacl/crypto/experimental/dpf/dcf_test.cc new file mode 100644 index 00000000..20863d02 --- /dev/null +++ b/yacl/crypto/experimental/dpf/dcf_test.cc @@ -0,0 +1,86 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/experimental/dpf/dcf.h" + +#include +#include + +#include "gtest/gtest.h" + +#include "yacl/base/int128.h" +#include "yacl/crypto/experimental/dpf/ge2n.h" +#include "yacl/crypto/rand/rand.h" + +namespace yacl::crypto { + +TEST(DcfTest, Gen) { + DcfKey k0; + DcfKey k1; + uint128_t first_mk = SecureRandSeed(); + uint128_t second_mk = SecureRandSeed(); + + constexpr size_t k_in_bitnum = 16; + constexpr size_t k_out_bitnum = 64; + + auto alpha = GE2n(FastRandU64()); + auto beta = GE2n(FastRandU64()); + + DcfKeyGen(&k0, &k1, alpha, beta, first_mk, second_mk); +} + +TEST(DcfTest, Eval) { + DcfKey k0; + DcfKey k1; + uint128_t first_mk = SecureRandSeed(); + uint128_t second_mk = SecureRandSeed(); + + constexpr size_t k_in_bitnum = 16; + constexpr size_t k_out_bitnum = 64; + + auto alpha = GE2n(FastRandU64()); + auto beta = GE2n(FastRandU64()); + + DcfKeyGen(&k0, &k1, alpha, beta, first_mk, second_mk); + + /* smaller input */ + { + auto in = GE2n(FastRandU64()); + while (in > alpha) { + in = GE2n(FastRandU64()); + } + auto out1 = GE2n(0); + auto out2 = GE2n(0); + DcfEval(k0, in, &out1); + DcfEval(k1, in, &out2); + + EXPECT_EQ(out1 + out2, beta); + } + + /* larger input */ + { + auto in = GE2n(FastRandU64()); + while (in < alpha) { + in = GE2n(FastRandU64()); + } + auto out1 = GE2n(0); + auto out2 = GE2n(0); + DcfEval(k0, in, &out1); + DcfEval(k1, in, &out2); + + EXPECT_EQ(out1 + out2, GE2n(0)); + } +} + +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/dpf.cc b/yacl/crypto/experimental/dpf/dpf.cc index e183688d..b5125957 100644 --- a/yacl/crypto/experimental/dpf/dpf.cc +++ b/yacl/crypto/experimental/dpf/dpf.cc @@ -16,6 +16,8 @@ #include +#include "yacl/crypto/experimental/dpf/ge2n.h" +#include "yacl/secparam.h" #include "yacl/utils/serializer.h" #include "yacl/utils/serializer_adapter.h" @@ -23,15 +25,10 @@ namespace yacl::crypto { namespace { -// Get the i-th least significant bit of x -uint8_t GetBit(DpfInStore x, uint32_t i) { - YACL_ENFORCE(i < sizeof(DpfInStore) * 8, "GetBit: index out of range"); - return x >> i & 1; -} - -DpfOutStore DpfPRG(uint128_t seed) { +template +GE2n DpfPRG(uint128_t seed) { Prg prng(seed); - return prng(); + return GE2n(prng()); } std::tuple SplitDpfSeed(uint128_t seed) { @@ -45,7 +42,6 @@ std::tuple SplitDpfSeed(uint128_t seed) { seed_left = prng(); seed_right = prng(); - uint128_t tmp = prng(); t_left = tmp >> 1 & 1; @@ -54,30 +50,78 @@ std::tuple SplitDpfSeed(uint128_t seed) { return {seed_left, t_left, seed_right, t_right}; } +size_t GetTerminateLevel(bool enable_evalall, size_t m, size_t n) { + if (!enable_evalall) { + return m; + } + auto c = YACL_MODULE_SECPARAM_C_UINT("dpf"); + size_t x = ceil(m - log(c / n)); + return std::min(m, x); +} + +template +void Traverse(DpfKey* key, absl::Span> result, size_t current_level, + uint64_t current_pos, uint128_t seed_working, bool t_working, + size_t term_level) { + if (current_level < term_level) { + uint128_t seed_left; + uint128_t seed_right; + bool t_left; + bool t_right; + const auto cw_seed = key->cws_vec[current_level].GetSeed(); + const auto cw_t_left = key->cws_vec[current_level].GetLT(); + const auto cw_t_right = key->cws_vec[current_level].GetRT(); + + std::tie(seed_left, t_left, seed_right, t_right) = + SplitDpfSeed(seed_working); + + seed_left = t_working ? seed_left ^ cw_seed : seed_left; + t_left = t_left ^ (t_working * cw_t_left); + seed_right = t_working ? seed_right ^ cw_seed : seed_right; + t_right = t_right ^ (t_working * cw_t_right); + + uint64_t next_left_pos = current_pos; + uint64_t next_right_pos = (1ULL << current_level) + current_pos; + + Traverse(key, result, current_level + 1, next_left_pos, seed_left, + t_left, term_level); + Traverse(key, result, current_level + 1, next_right_pos, seed_right, + t_right, term_level); + + } else { + auto prg = DpfPRG(seed_working); + uint32_t expand_num = static_cast(1) << (M - term_level); + + for (uint32_t i = 0; i < expand_num; i++) { + auto tmp = GE2n(t_working * key->last_cw_vec[i]); + result[current_pos + (i << term_level)] = + key->GetRank() ? (prg + tmp).GetReverse() : (prg + tmp); + prg = DpfPRG(prg.GetVal()); + } + } +} + } // namespace // ----------------------------------------- // Full domain key generation and evaluation // ----------------------------------------- -void DpfContext::Gen(DpfKey& first_key, DpfKey& second_key, DpfInStore alpha, - DpfOutStore beta, uint128_t first_mk, uint128_t second_mk, - bool enable_evalall) { - YACL_ENFORCE(this->in_bitnum_ > 0); - YACL_ENFORCE(this->in_bitnum_ > log2(alpha)); - YACL_ENFORCE(this->in_bitnum_ <= 64); - YACL_ENFORCE(this->ss_bitnum_ > 0); - YACL_ENFORCE(this->ss_bitnum_ <= 64); +template +void DpfKeyGen(DpfKey* first_key, DpfKey* second_key, const GE2n& alpha, + const GE2n& beta, uint128_t first_mk, uint128_t second_mk, + bool enable_evalall) { + static_assert(M > 0 && M <= 64); // input bits number constrains + static_assert(N > 0 && N <= 128); // output bits number constrains // enable the early termination - uint32_t term_level = GetTerminateLevel(enable_evalall); + uint32_t term_level = GetTerminateLevel(enable_evalall, M, N); // set up the return keys - first_key = DpfKey(false, GetInBitNum(), GetSsBitNum(), sec_param_, first_mk); - second_key = - DpfKey(true, GetInBitNum(), GetSsBitNum(), sec_param_, second_mk); - first_key.cws_vec.resize(term_level); - second_key.cws_vec.resize(term_level); + *first_key = DpfKey(false, first_mk); + *second_key = DpfKey(true, second_mk); + first_key->cws_vec.resize(term_level); + second_key->cws_vec.resize(term_level); std::array seeds_working; seeds_working[0] = first_mk; @@ -93,11 +137,10 @@ void DpfContext::Gen(DpfKey& first_key, DpfKey& second_key, DpfInStore alpha, std::array t_left; std::array t_right; - bool alpha_bit = (GetBit(alpha, i) != 0U); + bool alpha_bit = (alpha.GetBit(i) != 0U); // Use working seed to generate seeds // Note: this is the most time-consuming process - // [TODO]: Make this parallel std::tie(seed_left[0], t_left[0], seed_right[0], t_right[0]) = SplitDpfSeed(seeds_working[0]); std::tie(seed_left[1], t_left[1], seed_right[1], t_right[1]) = @@ -122,9 +165,9 @@ void DpfContext::Gen(DpfKey& first_key, DpfKey& second_key, DpfInStore alpha, t_working[0] = t_keep[0] ^ t_working[0] * cw_t_keep; t_working[1] = t_keep[1] ^ t_working[1] * cw_t_keep; - first_key.cws_vec[i].SetSeed(cw_seed); - first_key.cws_vec[i].SetTLeft(cw_t_left); - first_key.cws_vec[i].SetTRight(cw_t_right); + first_key->cws_vec[i].SetSeed(cw_seed); + first_key->cws_vec[i].SetLT(cw_t_left); + first_key->cws_vec[i].SetRT(cw_t_right); } // Expand final seed_working @@ -133,64 +176,64 @@ void DpfContext::Gen(DpfKey& first_key, DpfKey& second_key, DpfInStore alpha, // // First, we get the Convert(S_0 ^ key_block) and Convert(S_1 ^ key_block) // - DpfOutStore prg0 = DpfPRG(seeds_working[0]); - DpfOutStore prg1 = DpfPRG(seeds_working[1]); + auto prg0 = DpfPRG(seeds_working[0]); + auto prg1 = DpfPRG(seeds_working[1]); // if !enable_evalall, we have only one last_cw_vec, otherwise, we // have multiple last_cw_vec - YACL_ENFORCE(first_key.last_cw_vec.empty()); - YACL_ENFORCE(second_key.last_cw_vec.empty()); + YACL_ENFORCE(first_key->last_cw_vec.empty()); + YACL_ENFORCE(second_key->last_cw_vec.empty()); if (!enable_evalall) { - first_key.last_cw_vec.push_back(TruncateSs(beta + ReverseSs(prg0) + prg1)); + first_key->last_cw_vec.push_back( + (beta + prg0.GetReverse() + prg1).GetVal()); if (t_working[1]) { - first_key.last_cw_vec[0] = ReverseSs(first_key.last_cw_vec[0]); + first_key->last_cw_vec[0] = + GE2n(first_key->last_cw_vec[0]).GetReverse().GetVal(); } - second_key.cws_vec = first_key.cws_vec; - second_key.last_cw_vec.push_back(first_key.last_cw_vec[0]); + second_key->cws_vec = first_key->cws_vec; + second_key->last_cw_vec.push_back(first_key->last_cw_vec[0]); } else { - first_key.EnableEvalAll(); - second_key.EnableEvalAll(); + first_key->EnableEvalAll(); + second_key->EnableEvalAll(); - uint32_t alpha_pos_term_level = alpha >> term_level; - uint32_t expand_num = static_cast(1) - << (GetInBitNum() - term_level); + uint32_t alpha_pos_term_level = alpha.GetVal() >> term_level; + uint32_t expand_num = static_cast(1) << (M - term_level); for (uint32_t i = 0; i < expand_num; i++) { - DpfOutStore last_cw = 0; + GE2n last_cw; if (i == alpha_pos_term_level) { - last_cw = TruncateSs(beta + ReverseSs(prg0) + TruncateSs(prg1)); + last_cw = beta + GE2n(prg0).GetReverse() + prg1; } else { - last_cw = TruncateSs(ReverseSs(prg0) + TruncateSs(prg1)); + last_cw = GE2n(prg0).GetReverse() + prg1; } if (t_working[1]) { - first_key.last_cw_vec.push_back(ReverseSs(last_cw)); + first_key->last_cw_vec.push_back(last_cw.GetReverse().GetVal()); } else { - first_key.last_cw_vec.push_back(last_cw); + first_key->last_cw_vec.push_back(last_cw.GetVal()); } - second_key.cws_vec = first_key.cws_vec; - second_key.last_cw_vec.push_back(first_key.last_cw_vec[i]); + second_key->cws_vec = first_key->cws_vec; + second_key->last_cw_vec.push_back(first_key->last_cw_vec[i]); - prg0 = DpfPRG(prg0); - prg1 = DpfPRG(prg1); + prg0 = DpfPRG(prg0.GetVal()); + prg1 = DpfPRG(prg1.GetVal()); } } - // return {std::move(first_key), std::move(second_key)}; } -DpfOutStore DpfContext::Eval(DpfKey& key, DpfInStore x) { - YACL_ENFORCE(this->in_bitnum_ > log2(x)); +template +void DpfEval(const DpfKey& key, const GE2n& in, GE2n* out) { YACL_ENFORCE(key.enable_evalall == false); uint128_t seed_working = key.GetSeed(); // the initial value bool t_working = key.GetRank(); // the initial value - for (uint32_t i = 0; i < GetInBitNum(); i++) { + for (uint32_t i = 0; i < M; i++) { const auto cw_seed = key.cws_vec[i].GetSeed(); - const auto cw_t_left = key.cws_vec[i].GetTLeft(); - const auto cw_t_right = key.cws_vec[i].GetTRight(); + const auto cw_t_left = key.cws_vec[i].GetLT(); + const auto cw_t_right = key.cws_vec[i].GetRT(); uint128_t seed_left; uint128_t seed_right; @@ -205,7 +248,7 @@ DpfOutStore DpfContext::Eval(DpfKey& key, DpfInStore x) { seed_right = t_working ? seed_right ^ cw_seed : seed_right; t_right = t_right ^ (t_working * cw_t_right); - if (GetBit(x, i) != 0U) { + if (in.GetBit(i) != 0U) { seed_working = seed_right; t_working = t_right; } else { @@ -214,105 +257,60 @@ DpfOutStore DpfContext::Eval(DpfKey& key, DpfInStore x) { } } - DpfOutStore prg = TruncateSs(DpfPRG(seed_working)); + auto prg = DpfPRG(seed_working); - DpfOutStore result = key.GetRank() - ? ReverseSs(prg + t_working * key.last_cw_vec[0]) - : TruncateSs(prg + t_working * key.last_cw_vec[0]); + auto tmp = GE2n(t_working * key.last_cw_vec[0]); + uint128_t result = + key.GetRank() ? (prg + tmp).GetReverse().GetVal() : (prg + tmp).GetVal(); - return TruncateSs(result); + *out = GE2n(result); } -void DpfContext::Traverse(DpfKey& key, std::vector& result, - size_t current_level, uint64_t current_pos, - uint128_t seed_working, bool t_working, - size_t term_level) { - if (current_level < term_level) { - uint128_t seed_left; - uint128_t seed_right; - bool t_left; - bool t_right; - const auto cw_seed = key.cws_vec[current_level].GetSeed(); - const auto cw_t_left = key.cws_vec[current_level].GetTLeft(); - const auto cw_t_right = key.cws_vec[current_level].GetTRight(); - - std::tie(seed_left, t_left, seed_right, t_right) = - SplitDpfSeed(seed_working); - - seed_left = t_working ? seed_left ^ cw_seed : seed_left; - t_left = t_left ^ (t_working * cw_t_left); - seed_right = t_working ? seed_right ^ cw_seed : seed_right; - t_right = t_right ^ (t_working * cw_t_right); +template +void DpfEvalAll(DpfKey* key, absl::Span> out) { + YACL_ENFORCE(key->enable_evalall == true); - uint64_t next_left_pos = current_pos; - uint64_t next_right_pos = (1ULL << current_level) + current_pos; + uint128_t seed_working = key->GetSeed(); // the initial value + bool t_working = key->GetRank(); // the initial value + uint32_t term_level = GetTerminateLevel(true, M, N); - Traverse(key, result, current_level + 1, next_left_pos, seed_left, t_left, - term_level); - Traverse(key, result, current_level + 1, next_right_pos, seed_right, - t_right, term_level); - - } else { - DpfOutStore prg = DpfPRG(seed_working); - uint32_t expand_num = static_cast(1) - << (GetInBitNum() - term_level); - - for (uint32_t i = 0; i < expand_num; i++) { - result[current_pos + (i << term_level)] = - key.GetRank() - ? ReverseSs(TruncateSs(prg) + t_working * key.last_cw_vec[i]) - : TruncateSs(TruncateSs(prg) + t_working * key.last_cw_vec[i]); - prg = DpfPRG(prg); - } - } -} - -std::vector DpfContext::EvalAll(DpfKey& key) { - YACL_ENFORCE(key.enable_evalall == true); - - uint128_t seed_working = key.GetSeed(); // the initial value - bool t_working = key.GetRank(); // the initial value - uint32_t term_level = GetTerminateLevel(true); - - YACL_ENFORCE(GetInBitNum() <= 25); // only support in_bin_num < 25 - - uint64_t num = 1ULL << GetInBitNum(); - std::vector result(num); + auto num = (uint128_t)1 << M; + std::vector result(num); uint64_t current_pos = 0; uint64_t current_level = 0; // we start from the top level - Traverse(key, result, current_level, current_pos, seed_working, t_working, - term_level); - - return result; -} - -Buffer DpfKey::Serialize() const { - // var "cws_vec" 's type 'std::vector' not supported, convert to STL - // type - std::vector> dpf_cws; - dpf_cws.reserve(cws_vec.size()); - for (const auto& cws : cws_vec) { - dpf_cws.emplace_back(cws.GetSeed(), cws.GetTStore()); - } - - // do serialize - return SerializeVars(enable_evalall, dpf_cws, last_cw_vec, rank_, in_bitnum_, - ss_bitnum_, sec_param_, mseed_); -} - -void DpfKey::Deserialize(ByteContainerView in) { - std::vector> dpf_cws; - DeserializeVarsTo(in, &enable_evalall, &dpf_cws, &last_cw_vec, &rank_, - &in_bitnum_, &ss_bitnum_, &sec_param_, &mseed_); - - // recover "cws_vec" with type std::vector - cws_vec.clear(); - cws_vec.reserve(dpf_cws.size()); - for (const auto& cws : dpf_cws) { - cws_vec.emplace_back(cws.first, cws.second); - } + Traverse(key, out, current_level, current_pos, seed_working, t_working, + term_level); } +// template specialization for different M and N +#define DPF_T_SPECIFY_FUNC(M, N) \ + template void DpfKeyGen(DpfKey * first_key, DpfKey * second_key, \ + const GE2n& alpha, const GE2n& beta, \ + uint128_t first_mk, uint128_t second_mk, \ + bool enable_evalall = false); \ + \ + template void DpfEval(const DpfKey& key, const GE2n& in, \ + GE2n* out); \ + \ + template void DpfEvalAll(DpfKey * key, absl::Span> out); + +DPF_T_SPECIFY_FUNC(64, 64) +DPF_T_SPECIFY_FUNC(32, 64) +DPF_T_SPECIFY_FUNC(16, 64) +DPF_T_SPECIFY_FUNC(8, 64) +DPF_T_SPECIFY_FUNC(4, 64) +DPF_T_SPECIFY_FUNC(2, 64) +DPF_T_SPECIFY_FUNC(1, 64) + +DPF_T_SPECIFY_FUNC(64, 128) +DPF_T_SPECIFY_FUNC(32, 128) +DPF_T_SPECIFY_FUNC(16, 128) +DPF_T_SPECIFY_FUNC(8, 128) +DPF_T_SPECIFY_FUNC(4, 128) +DPF_T_SPECIFY_FUNC(2, 128) +DPF_T_SPECIFY_FUNC(1, 128) + +#undef DPF_T_SPECIFY_FUNC } // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/dpf.h b/yacl/crypto/experimental/dpf/dpf.h index c5b933dd..bbd3cb46 100644 --- a/yacl/crypto/experimental/dpf/dpf.h +++ b/yacl/crypto/experimental/dpf/dpf.h @@ -27,65 +27,61 @@ #include "yacl/base/exception.h" #include "yacl/base/int128.h" +#include "yacl/crypto/experimental/dpf/ge2n.h" /* submodules */ +#include "yacl/crypto/rand/rand.h" #include "yacl/crypto/tools/prg.h" +#include "yacl/secparam.h" + +YACL_MODULE_DECLARE("dpf", SecParam::C::k128, SecParam::S::INF); namespace yacl::crypto { -// Implementation of Distributed Point Function (DPF) -// title : Function Secret Sharing: Improvements and Extensions -// eprint: https://eprint.iacr.org/2018/707 +// Distributed Point Function (DPF) // -// Assume we have a function F(*), where F(alpha)=beta, F(*!=alpha)=0. -// DPF splits the finction into two parts F1 and F2, and ensures F1(alpha)=r, -// F2(alpha)=-r+beta, and F1(*!=alpha)=r, F2(*!=alpha)=-r +// For more details, please see: https://eprint.iacr.org/2018/707 // -// alpha : arbitrary length mapping input -// beta : 128bit mapping output -// Note: result is A-share +class DpfKey { + public: + // internal type definition + class CW { + public: + CW() = default; + CW(uint128_t seed, uint8_t t_store) : seed_(seed), t_store_(t_store) {} -using DpfInStore = uint128_t; // the input room -using DpfOutStore = uint128_t; // the secret sharing room + uint8_t GetLT() const { return t_store_ & 1; } + uint8_t GetRT() const { return (t_store_ >> 1) & 1; } -struct DpfCW { - public: - DpfCW() = default; - DpfCW(uint128_t seed, uint8_t t_store) : seed_(seed), t_store_(t_store) {} + uint128_t GetSeed() const { return seed_; } + uint8_t GetTStore() const { return t_store_; } - bool GetTLeft() const { return t_store_ & 1; } - bool GetTRight() const { return (t_store_ >> 1) & 1; } + void SetLT(uint8_t t_left) { + YACL_ENFORCE(t_left == 0 || t_left == 1); + t_store_ = (GetRT() << 1) + t_left; + } - uint128_t GetSeed() const { return seed_; } - uint8_t GetTStore() const { return t_store_; } + void SetRT(uint8_t t_right) { + YACL_ENFORCE(t_right == 0 || t_right == 1); + t_store_ = (t_right << 1) + GetLT(); + } - void SetTLeft(bool t_left) { t_store_ = (GetTRight() << 1) + t_left; } - void SetTRight(bool t_right) { t_store_ = (t_right << 1) + GetTLeft(); } - void SetSeed(uint128_t seed) { seed_ = seed; } + void SetSeed(uint128_t seed) { seed_ = seed; } - private: - uint128_t seed_ = 0; // this level's seed, default = 0 - uint8_t t_store_ = 0; // 1st bit=> t_left, 2nd bit=> t_right -}; + private: + uint128_t seed_ = 0; // this level's seed, default = 0 + uint8_t t_store_ = 0; // 1st bit=> t_left, 2nd bit=> t_right + }; -class DpfKey { - public: - bool enable_evalall = false; // full domain eval - std::vector cws_vec; // correlated words for each level - std::vector last_cw_vec; // the final correlation word + bool enable_evalall = false; // full domain eval + std::vector cws_vec; // correlated words for each level + std::vector last_cw_vec; // the final correlation word // empty constructor DpfKey() = default; - DpfKey(bool rank, const uint128_t mseed) : rank_(rank), mseed_(mseed) {} - - DpfKey(bool rank, size_t in_bitnum, size_t ss_bitnum, uint32_t sec_param, - const uint128_t mseed) - : rank_(rank), - in_bitnum_(in_bitnum), - ss_bitnum_(ss_bitnum), - sec_param_(sec_param), - mseed_(mseed) {} + explicit DpfKey(bool rank, const uint128_t mseed = SecureRandSeed()) + : rank_(rank), mseed_(mseed) {} void EnableEvalAll() { enable_evalall = true; } void DisableFullEval() { enable_evalall = false; } @@ -96,100 +92,40 @@ class DpfKey { uint128_t GetSeed() const { return mseed_; } void SetSeed(uint128_t seed) { mseed_ = seed; } - size_t GetInBitNum() const { return in_bitnum_; } - size_t GetSsBitNum() const { return ss_bitnum_; } - - uint32_t GetSecParam() const { return sec_param_; } - - Buffer Serialize() const; - void Deserialize(ByteContainerView s); - private: - bool rank_{}; // only support two parties (0/1), compulsory param - size_t in_bitnum_ = 64; // bit number (for point), default = 64 - size_t ss_bitnum_ = 64; // bit number (for output value), default = 64 - uint32_t sec_param_ = 128; // we assume 128 bit security (fixed) - uint128_t mseed_ = 0; // the master seed (the default is not secure) + bool rank_{}; // only support two parties (0/1), compulsory param + uint128_t mseed_ = 0; // the master seed }; -class DpfContext { - public: - // constructors - DpfContext() = default; - - explicit DpfContext(size_t in_bitnum) : in_bitnum_(in_bitnum) {} - - DpfContext(size_t in_bitnum, size_t ss_bitnum) - : in_bitnum_(in_bitnum), ss_bitnum_(ss_bitnum) {} - - void SetInBitNum(size_t in_bitnum) { - YACL_ENFORCE(in_bitnum <= 64); - in_bitnum_ = in_bitnum; - } - size_t GetInBitNum() const { return in_bitnum_; } - - void SetSsBitNum(size_t ss_bitnum) { - YACL_ENFORCE(ss_bitnum <= 64); - ss_bitnum_ = ss_bitnum; - } - size_t GetSsBitNum() const { return ss_bitnum_; } - - // -------------------------------------- - // Original key generation and evaluation - // -------------------------------------- - std::pair Gen(DpfInStore alpha, DpfOutStore beta, - uint128_t first_mk, uint128_t second_mk, - bool enable_evalall = false) { - DpfKey k0; - DpfKey k1; - Gen(k0, k1, alpha, beta, first_mk, second_mk, enable_evalall); - return {std::move(k0), std::move(k1)}; - } - - void Gen(DpfKey& first_key, DpfKey& second_key, DpfInStore alpha, - DpfOutStore beta, uint128_t first_mk, uint128_t second_mk, - bool enable_evalall = false); - - DpfOutStore Eval(DpfKey& key, DpfInStore input); - - std::vector EvalAll(DpfKey& key); - - DpfOutStore GetSsMask() const { - YACL_ENFORCE(ss_bitnum_ <= 64); - if (ss_bitnum_ == 64) { - return 0xFFFFFFFFFFFFFFFF; - } - return (static_cast(1) << ss_bitnum_) - 1; - } - - DpfOutStore TruncateSs(DpfOutStore input) const { - YACL_ENFORCE(ss_bitnum_ <= 64); - return input & GetSsMask(); - } - - DpfOutStore ReverseSs(DpfOutStore input) const { - YACL_ENFORCE(ss_bitnum_ <= 64); - return TruncateSs(GetSsMask() - TruncateSs(input) + 1); - } +// ---------------------------------------------------------------------------- +// Core Functions of DPF +// ---------------------------------------------------------------------------- +// NOTE: Supported (M, N) parameter pairs are: +// - (M = {8, 16, 32, 64}, N = {8, 16, 32, 64, 128}) +// +// TODO maybe type traits +// template +// struct IsSupportedDpfType { +// static constexpr std::array m_set = {8, 16, 32, 64}; +// static constexpr std::array n_set = {8, 16, 32, 64, 128}; +// static constexpr bool value = []() constexpr { +// return std::find(std::begin(m_set), std::end(m_set), M) != +// std::end(m_set) && +// std::find(std::begin(n_set), std::end(m_set), N) != +// std::end(n_set); +// ; +// }; +// }; + +template +void DpfKeyGen(DpfKey* first_key, DpfKey* second_key, const GE2n& alpha, + const GE2n& beta, uint128_t first_mk, uint128_t second_mk, + bool enable_evalall = false); + +template +void DpfEval(const DpfKey& key, const GE2n& in, GE2n* out); + +template +void DpfEvalAll(DpfKey* key, absl::Span> out); - private: - void Traverse(DpfKey& key, std::vector& result, - size_t current_level, uint64_t current_pos, - uint128_t seed_working, bool t_working, size_t term_level); - - // Note that for the case of sec_param = 128 and ss_bitnum = 64, we - // always have term_level = in_bitnum - size_t GetTerminateLevel(bool enable_evalall) const { - if (!enable_evalall) { - return in_bitnum_; - } - size_t n = in_bitnum_; - size_t x = ceil(n - log(sec_param_ / ss_bitnum_)); - return std::min(n, x); - } - - size_t in_bitnum_ = 64; - size_t ss_bitnum_ = 64; - uint32_t sec_param_ = 128; // we assume 128 bit security (fixed) -}; } // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/dpf_test.cc b/yacl/crypto/experimental/dpf/dpf_test.cc index 06ff858c..6fe5445e 100644 --- a/yacl/crypto/experimental/dpf/dpf_test.cc +++ b/yacl/crypto/experimental/dpf/dpf_test.cc @@ -19,150 +19,93 @@ #include "gtest/gtest.h" -namespace yacl::crypto { - -struct TestParams { - DpfInStore alpha; - DpfOutStore beta; - uint32_t InBitnum; - uint32_t SsBitnum; -}; +#include "yacl/base/int128.h" +#include "yacl/crypto/experimental/dpf/ge2n.h" +#include "yacl/crypto/rand/rand.h" -class FssDpfGenTest : public testing::TestWithParam {}; +namespace yacl::crypto { -class FssDpfEvalTest : public testing::TestWithParam {}; +TEST(DpfTest, Gen) { + DpfKey k0; + DpfKey k1; + uint128_t first_mk = SecureRandSeed(); + uint128_t second_mk = SecureRandSeed(); -class FssDpfEvalAllTest : public testing::TestWithParam {}; + constexpr size_t k_in_bitnum = 16; + constexpr size_t k_out_bitnum = 64; -TEST_P(FssDpfGenTest, Works) { - auto params = GetParam(); - DpfKey k0, k1; - uint128_t first_mk = 0; - uint128_t second_mk = 1; - DpfContext context; - context.SetInBitNum(params.InBitnum); - context.SetSsBitNum(params.SsBitnum); + auto alpha = GE2n(FastRandU64()); + auto beta = GE2n(FastRandU64()); - std::tie(k0, k1) = - context.Gen(params.alpha, params.beta, first_mk, second_mk, false); + DpfKeyGen(&k0, &k1, alpha, beta, first_mk, second_mk, false); } -TEST_P(FssDpfEvalTest, Works) { - auto params = GetParam(); +TEST(DpfTest, Eval) { DpfKey k0; DpfKey k1; - DpfContext context; - uint128_t first_mk = 0; - uint128_t second_mk = 1; + uint128_t first_mk = SecureRandSeed(); + uint128_t second_mk = SecureRandSeed(); - context.SetInBitNum(params.InBitnum); - context.SetSsBitNum(params.SsBitnum); + constexpr size_t k_in_bitnum = 16; + constexpr size_t k_out_bitnum = 64; - std::tie(k0, k1) = - context.Gen(params.alpha, params.beta, first_mk, second_mk, false); + auto alpha = GE2n(FastRandU64()); + auto beta = GE2n(FastRandU64()); - size_t range = 1 << context.GetInBitNum(); + DpfKeyGen(&k0, &k1, alpha, beta, first_mk, second_mk, false); - for (size_t i = 0; i < range; i++) { - DpfOutStore temp0 = context.Eval(k0, i); - DpfOutStore temp1 = context.Eval(k1, i); - DpfOutStore result = context.TruncateSs(temp0 + temp1); - if (i == params.alpha) { - EXPECT_EQ(result, params.beta); - } else { - EXPECT_EQ(result, 0); + /* wrong input */ + { + auto in = GE2n(FastRandU64()); + while (in == alpha) { + in = GE2n(FastRandU64()); } + auto out1 = GE2n(0); + auto out2 = GE2n(0); + DpfEval(k0, in, &out1); + DpfEval(k1, in, &out2); + EXPECT_EQ((out1 + out2).GetVal(), 0); } - DpfKey k1_copy; - auto k1_string = k1.Serialize(); - k1_copy.Deserialize(k1_string); - - for (size_t i = 0; i < range; i++) { - DpfOutStore temp0 = context.Eval(k0, i); - DpfOutStore temp1 = context.Eval(k1_copy, i); - DpfOutStore result = context.TruncateSs(temp0 + temp1); - if (i == params.alpha) { - EXPECT_EQ(result, params.beta); - } else { - EXPECT_EQ(result, 0); - } + /* correct input */ + { + auto out1 = GE2n(0); + auto out2 = GE2n(0); + DpfEval(k0, alpha, &out1); + DpfEval(k1, alpha, &out2); + EXPECT_EQ(out1 + out2, beta); } } -TEST_P(FssDpfEvalAllTest, Works) { - auto params = GetParam(); +TEST(DpfTest, EvalAll) { DpfKey k0; DpfKey k1; - DpfContext context; - uint128_t first_mk = 0; - uint128_t second_mk = 1; + uint128_t first_mk = SecureRandSeed(); + uint128_t second_mk = SecureRandSeed(); - context.SetInBitNum(params.InBitnum); - context.SetSsBitNum(params.SsBitnum); + constexpr size_t k_in_bitnum = 16; + constexpr size_t k_out_bitnum = 128; - std::tie(k0, k1) = - context.Gen(params.alpha, params.beta, first_mk, second_mk, true); + auto alpha = GE2n(FastRandU64()); + auto beta = GE2n(FastRandU64()); - // k0.Print(); - // k1.Print(); + DpfKeyGen(&k0, &k1, alpha, beta, first_mk, second_mk, true); - std::vector temp0 = context.EvalAll(k0); - std::vector temp1 = context.EvalAll(k1); - - size_t range = 1 << context.GetInBitNum(); - - for (size_t i = 0; i < range; i++) { - DpfOutStore result = context.TruncateSs(temp0.at(i) + temp1.at(i)); - - if (i == params.alpha) { - EXPECT_EQ(result, params.beta); - } else { - EXPECT_EQ(result, 0); - } - } - - DpfKey k1_copy; - auto k1_string = k1.Serialize(); - k1_copy.Deserialize(k1_string); - - temp0 = context.EvalAll(k0); - temp1 = context.EvalAll(k1_copy); + size_t range = 1 << k_in_bitnum; + auto out1 = std::vector>(range); + auto out2 = std::vector>(range); + DpfEvalAll(&k0, absl::MakeSpan(out1)); + DpfEvalAll(&k1, absl::MakeSpan(out2)); for (size_t i = 0; i < range; i++) { - DpfOutStore result = context.TruncateSs(temp0.at(i) + temp1.at(i)); + auto result = out1[i] + out2[i]; - if (i == params.alpha) { - EXPECT_EQ(result, params.beta); + if (i == alpha.GetVal()) { + EXPECT_EQ(result, beta); } else { - EXPECT_EQ(result, 0); + EXPECT_EQ(result.GetVal(), 0); } } } -INSTANTIATE_TEST_SUITE_P(Works_Instances, FssDpfGenTest, - testing::Values(TestParams{1, 1, 2, 1}, // - TestParams{1, 2, 2, 4}, // - TestParams{1, 2, 2, 8}, // - TestParams{3, 5, 4, 16}, // - TestParams{1, 2, 4, 32}, // - TestParams{1, 2, 8, 64})); - -INSTANTIATE_TEST_SUITE_P(Works_Instances, FssDpfEvalTest, - testing::Values(TestParams{1, 1, 2, 1}, // - TestParams{1, 2, 2, 4}, // - TestParams{1, 2, 2, 8}, // - TestParams{3, 5, 4, 16}, // - TestParams{1, 2, 4, 32}, // - TestParams{1, 2, 8, 64})); - -INSTANTIATE_TEST_SUITE_P(Works_Instances, FssDpfEvalAllTest, - testing::Values(TestParams{1, 1, 2, 1}, // - TestParams{1, 1, 4, 1}, // - TestParams{1, 1, 6, 1}, // - TestParams{1, 1, 8, 1}, // - TestParams{1, 2, 10, 4}, // - TestParams{1, 2, 12, 8}, // - TestParams{3, 5, 14, 16})); - } // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/ge2n.h b/yacl/crypto/experimental/dpf/ge2n.h new file mode 100644 index 00000000..641a8796 --- /dev/null +++ b/yacl/crypto/experimental/dpf/ge2n.h @@ -0,0 +1,104 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// Group Element in 2n + +#pragma once + +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/base/int128.h" + +namespace yacl::crypto { + +template ::value, int> = 0> +class GE2n { + public: + // Constructors + GE2n() { static_assert(N <= sizeof(StoreTy) * 8); } + explicit GE2n(StoreTy value) { store_ = value; } + + // Get the N-bit truncated value + StoreTy GetVal() const { return store_ & kMask_; } + + // Get the N-bit mask + StoreTy GetMask() const { return kMask_; } + + // Get the bit num of group + size_t GetN() const { return N; } + + // Get the i-th least significant bit + uint8_t GetBit(size_t i) const { + YACL_ENFORCE(i < sizeof(StoreTy) * 8, "GetBit: index out of range"); + return store_ >> i & 1; + } + + // Reverse a group element inplace + void ReverseInplace() { store_ = kMask_ - GetVal() + 1; } + + // Get the reversed group element + GE2n GetReverse() const { + return GE2n(kMask_ - GetVal() + 1); + } + + // supported operators +#define GE2N_OVERLOAD_BINARY_OP(OP) \ + [[nodiscard]] GE2n operator OP(GE2n other) const { \ + return GE2n(this->store_ OP other.store_); \ + } + + GE2N_OVERLOAD_BINARY_OP(+) + GE2N_OVERLOAD_BINARY_OP(-) +#undef GE2N_OVERLOAD_BINARY_OP + + void operator+=(GE2n other) { this->store_ += other.store_; } + + void operator-=(GE2n other) { this->store_ -= other.store_; } + + [[nodiscard]] bool operator==(GE2n other) const { + return GetVal() == other.GetVal(); + } + + [[nodiscard]] bool operator!=(GE2n other) const { + return GetVal() != other.GetVal(); + } + + [[nodiscard]] bool operator>(GE2n other) const { + return GetVal() > other.GetVal(); + } + + [[nodiscard]] bool operator>=(GE2n other) const { + return GetVal() >= other.GetVal(); + } + + [[nodiscard]] bool operator<(GE2n other) const { + return GetVal() < other.GetVal(); + } + + [[nodiscard]] bool operator<=(GE2n other) const { + return GetVal() <= other.GetVal(); + } + + private: + static constexpr StoreTy kMask_ = + N == 128 ? Uint128Max() : (StoreTy(1) << N) - 1; + StoreTy store_; +}; + +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/pprf.cc b/yacl/crypto/experimental/dpf/pprf.cc new file mode 100644 index 00000000..bc42e9b7 --- /dev/null +++ b/yacl/crypto/experimental/dpf/pprf.cc @@ -0,0 +1,155 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/experimental/dpf/pprf.h" + +#include + +#include + +#include "spdlog/spdlog.h" + +#include "yacl/base/exception.h" +#include "yacl/base/int128.h" +#include "yacl/crypto/tools/prg.h" + +namespace yacl::crypto { + +namespace { +void GgmPrg(uint128_t in, uint128_t* out1, uint128_t* out2) { + if (out1 != nullptr) { + FillPRand(SymmetricCrypto::CryptoType::AES128_CTR, in, 0, 0, (char*)out1, + sizeof(uint128_t)); + } + if (out2 != nullptr) { + FillPRand(SymmetricCrypto::CryptoType::AES128_CTR, in, 0, 1, (char*)out2, + sizeof(uint128_t)); + } +} + +[[maybe_unused]] void GgmFullExpand(absl::Span working_span) { + const size_t num = working_span.size(); + if (num > 1) { + GgmPrg(working_span[0], &working_span[0], &working_span[num / 2]); + GgmFullExpand(working_span.subspan(0, num / 2)); + GgmFullExpand(working_span.subspan(num / 2, num / 2)); + } else { + // SPDLOG_INFO(working_span[0]); /* for debug */ + } +} + +template +void GgmExpandAndPunc(absl::Span working_span, GE2n punc_point, + PprfPuncKey* out) { + const size_t num = working_span.size(); // total number of levels + const size_t i = M - log2(num); // current level, starting from 0 + if (num > 1) { + GgmPrg(working_span[0], &working_span[0], &working_span[num / 2]); + if (!punc_point.GetBit(i)) { /* 0 means left */ + GgmExpandAndPunc(working_span.subspan(0, num / 2), punc_point, out); + out->seeds.insert({i, working_span[num / 2]}); + } else { + out->seeds.insert({i, working_span[0]}); + GgmExpandAndPunc(working_span.subspan(num / 2, num / 2), punc_point, out); + } + } else { + // SPDLOG_INFO("({}, {})", i, working_span[0]); /* for debug */ + } +} + +} // namespace + +template +void PprfPunc(uint128_t prf_key, GE2n punc_point, PprfPuncKey* out) { + static_assert(M <= 64); + auto m = M; // m is a runtime var + auto num = (m == 64) ? std::numeric_limits::max() : 1 << m; + + std::vector working_vec(num); + working_vec[0] = prf_key; + GgmExpandAndPunc(absl::MakeSpan(working_vec), punc_point, out); + out->punc_point = punc_point.GetVal(); +} + +template +void PprfPuncEval(const PprfPuncKey& punc_key, GE2n point, GE2n* out) { + static_assert(M <= 64); + GE2n punc_point(punc_key.punc_point); + YACL_ENFORCE( + punc_point != point, + "You cannot evaluate the already-punctured point with PprfPuncEval!"); + + bool is_same = true; + bool retrived = false; + uint128_t current_seed = 0; + for (size_t i = 0; i < M; ++i) { + is_same &= point.GetBit(i) == punc_point.GetBit(i); + if (!is_same) { + if (!retrived) { + current_seed = punc_key.seeds.at(i); + retrived = true; + } else { + if (!point.GetBit(i)) { /* 0 means left */ + GgmPrg(current_seed, ¤t_seed, nullptr); + } else { + GgmPrg(current_seed, nullptr, ¤t_seed); + } + // SPDLOG_INFO(current_seed); + } + } + } + *out = GE2n(current_seed); +} + +template +void PprfEval(uint128_t prf_key, GE2n point, GE2n* out) { + static_assert(M <= 64); + + uint128_t current_seed = prf_key; + for (size_t i = 0; i < M; ++i) { + if (!point.GetBit(i)) { /* 0 means left */ + GgmPrg(current_seed, ¤t_seed, nullptr); + } else { + GgmPrg(current_seed, nullptr, ¤t_seed); + } + } + *out = GE2n(current_seed); +} + +// template specification for different M and N +// +#define PPRF_T_SPECIFY_FUNC(M, N) \ + template void PprfPunc(uint128_t prf_key, GE2n punc_point, \ + PprfPuncKey * out); \ + template void PprfPuncEval(const PprfPuncKey& punc_key, GE2n point, \ + GE2n* out); \ + template void PprfEval(uint128_t prf_key, GE2n point, GE2n * out); + +PPRF_T_SPECIFY_FUNC(64, 64) +PPRF_T_SPECIFY_FUNC(32, 64) +PPRF_T_SPECIFY_FUNC(16, 64) +PPRF_T_SPECIFY_FUNC(8, 64) +PPRF_T_SPECIFY_FUNC(4, 64) +PPRF_T_SPECIFY_FUNC(2, 64) + +PPRF_T_SPECIFY_FUNC(64, 128) +PPRF_T_SPECIFY_FUNC(32, 128) +PPRF_T_SPECIFY_FUNC(16, 128) +PPRF_T_SPECIFY_FUNC(8, 128) +PPRF_T_SPECIFY_FUNC(4, 128) +PPRF_T_SPECIFY_FUNC(2, 128) + +#undef PPRF_T_SPECIFY_FUNC + +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/pprf.h b/yacl/crypto/experimental/dpf/pprf.h new file mode 100644 index 00000000..656279f4 --- /dev/null +++ b/yacl/crypto/experimental/dpf/pprf.h @@ -0,0 +1,70 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/base/int128.h" +#include "yacl/crypto/experimental/dpf/ge2n.h" + +/* submodules */ +#include "yacl/crypto/tools/prg.h" + +namespace yacl::crypto { + +// Puncturable Psedu-Random Function (PPRF) +// +// NOTE: this algorithm is experimental, also, this implemention is not +// *Private* Puncturable Pseu-Random Function +// +struct PprfPuncKey { + uint128_t punc_point; // NOTE: PPRF does not protect the punctured index + std::unordered_map seeds; +}; + +// PPRF punctured key generation function. On input the prf_key and the +// punc_point (it is supposed to a positive integer that is smaller than M), +// outputs a PprfPuncKey. +template +void PprfPunc(uint128_t prf_key, GE2n punc_point, PprfPuncKey* out); + +template +void PprfPuncEval(const PprfPuncKey& punc_key, GE2n point, GE2n* out); + +template +void PprfEval(uint128_t prf_key, GE2n point, GE2n* out); + +// --------------------------------------------------------------------------- +// PPRF with uint128_t support, the validation of point values will be checked +// by the constructor of GE2n, there is no need to perform additional +// check. +// --------------------------------------------------------------------------- + +template +void PprfPunc(uint128_t prf_key, uint128_t punc_point, PprfPuncKey* out) { + PprfPunc(prf_key, GE2n(punc_point), out); +} + +template +void PprfPuncEval(const PprfPuncKey& punc_key, uint128_t point, GE2n* out) { + PprfPuncEval(punc_key, GE2n(point), out); +} + +template +void PprfEval(uint128_t prf_key, uint128_t point, GE2n* out) { + PprfEval(prf_key, GE2n(point), out); +} + +} // namespace yacl::crypto diff --git a/yacl/crypto/experimental/dpf/pprf_test.cc b/yacl/crypto/experimental/dpf/pprf_test.cc new file mode 100644 index 00000000..a43c6a39 --- /dev/null +++ b/yacl/crypto/experimental/dpf/pprf_test.cc @@ -0,0 +1,49 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/experimental/dpf/pprf.h" + +#include "gtest/gtest.h" + +#include "yacl/base/exception.h" +#include "yacl/crypto/rand/rand.h" + +namespace yacl::crypto { + +TEST(PprfTest, Works) { + /* GIVEN */ + constexpr size_t M = 2; + constexpr size_t N = 128; + + size_t num = 1 << M; + uint128_t punc_point = RandInRange(M); + auto prf_key = SecureRandSeed(); + PprfPuncKey punc_key; + PprfPunc(prf_key, punc_point, &punc_key); + + GE2n out1; + GE2n out2; + for (size_t i = 0; i < num; ++i) { + if (i != punc_point) { + /* WHEN */ + PprfEval(prf_key, i, &out1); + PprfPuncEval(punc_key, i, &out2); + + /* THEN */ + EXPECT_EQ(out1.GetVal(), out2.GetVal()); + } + } +} + +} // namespace yacl::crypto diff --git a/yacl/crypto/rand/drbg/native_factory.cc b/yacl/crypto/rand/drbg/native_factory.cc index ed6197c3..019c7313 100644 --- a/yacl/crypto/rand/drbg/native_factory.cc +++ b/yacl/crypto/rand/drbg/native_factory.cc @@ -308,4 +308,7 @@ Buffer Sm4Drbg::Generate(size_t len, ByteContainerView additional_input) { } } // namespace internal + +REGISTER_DRBG_LIBRARY("NativeImpl", 100, NativeDrbg::Check, NativeDrbg::Create); + } // namespace yacl::crypto diff --git a/yacl/crypto/rand/drbg/native_factory.h b/yacl/crypto/rand/drbg/native_factory.h index 0f821aa9..a3c358a5 100644 --- a/yacl/crypto/rand/drbg/native_factory.h +++ b/yacl/crypto/rand/drbg/native_factory.h @@ -142,6 +142,4 @@ class NativeDrbg : public Drbg { std::unique_ptr drbg_impl_; }; -REGISTER_DRBG_LIBRARY("NativeImpl", 100, NativeDrbg::Check, NativeDrbg::Create); - } // namespace yacl::crypto diff --git a/yacl/crypto/rand/drbg/openssl_factory.cc b/yacl/crypto/rand/drbg/openssl_factory.cc index 62715cd9..fa507dc5 100644 --- a/yacl/crypto/rand/drbg/openssl_factory.cc +++ b/yacl/crypto/rand/drbg/openssl_factory.cc @@ -127,4 +127,7 @@ void OpensslDrbg::ReSeed() { /* prediction resistance flag */ 1, nullptr, 0) > 0); } + +REGISTER_DRBG_LIBRARY("OpenSSL", 100, OpensslDrbg::Check, OpensslDrbg::Create); + } // namespace yacl::crypto diff --git a/yacl/crypto/rand/drbg/openssl_factory.h b/yacl/crypto/rand/drbg/openssl_factory.h index 4ba087aa..e6ccc5c2 100644 --- a/yacl/crypto/rand/drbg/openssl_factory.h +++ b/yacl/crypto/rand/drbg/openssl_factory.h @@ -74,6 +74,4 @@ class OpensslDrbg : public Drbg { openssl::UniqueRandCtx ctx_; }; -REGISTER_DRBG_LIBRARY("OpenSSL", 100, OpensslDrbg::Check, OpensslDrbg::Create); - } // namespace yacl::crypto diff --git a/yacl/crypto/rand/rand.cc b/yacl/crypto/rand/rand.cc index 7d4152e3..3089aacb 100644 --- a/yacl/crypto/rand/rand.cc +++ b/yacl/crypto/rand/rand.cc @@ -111,7 +111,7 @@ std::vector SecureRandBytes(uint64_t len) { return RandBytes(len, false); } -#define IMPL_RANDBIT_DYNAMIC_BIT_TYPE(T) \ +#define SPECIFY_RANDBIT_TEMPLATE(T) \ template <> \ dynamic_bitset RandBits>(uint64_t len, \ bool fast_mode) { \ @@ -133,9 +133,11 @@ std::vector SecureRandBytes(uint64_t len) { return out; \ } -IMPL_RANDBIT_DYNAMIC_BIT_TYPE(uint128_t); -IMPL_RANDBIT_DYNAMIC_BIT_TYPE(uint64_t); -IMPL_RANDBIT_DYNAMIC_BIT_TYPE(uint32_t); -IMPL_RANDBIT_DYNAMIC_BIT_TYPE(uint16_t); +SPECIFY_RANDBIT_TEMPLATE(uint128_t); +SPECIFY_RANDBIT_TEMPLATE(uint64_t); +SPECIFY_RANDBIT_TEMPLATE(uint32_t); +SPECIFY_RANDBIT_TEMPLATE(uint16_t); + +#undef SPECIFY_RANDBIT_TEMPLATE } // namespace yacl::crypto diff --git a/yacl/crypto/rand/rand.h b/yacl/crypto/rand/rand.h index 5c2c351f..72ee592a 100644 --- a/yacl/crypto/rand/rand.h +++ b/yacl/crypto/rand/rand.h @@ -81,6 +81,7 @@ inline uint32_t RandInRange(uint32_t n) { uint32_t tmp = FastRandU64(); return tmp % n; } + // ----------------------------- // Random Support for Yacl Types // ----------------------------- diff --git a/yacl/io/circuit/bristol_fashion.cc b/yacl/io/circuit/bristol_fashion.cc index eb447866..1ae35715 100644 --- a/yacl/io/circuit/bristol_fashion.cc +++ b/yacl/io/circuit/bristol_fashion.cc @@ -104,8 +104,7 @@ void CircuitReader::ReadAllGates() { YACL_ENFORCE(absl::SimpleAtoi(splits[1], &circ_->gates[i].now)); /* it's okay to have more columns, but we'll stick with the niw and now */ - YACL_ENFORCE(splits.size() >= - circ_->gates[i].niw + circ_->gates[i].now + 2); + YACL_ENFORCE(splits.size() > circ_->gates[i].niw + circ_->gates[i].now + 2); circ_->gates[i].iw.resize(circ_->gates[i].niw); circ_->gates[i].ow.resize(circ_->gates[i].now); @@ -119,21 +118,25 @@ void CircuitReader::ReadAllGates() { /* check gate inputs num and op */ auto op_str = splits[circ_->gates[i].niw + circ_->gates[i].now + 2]; - YACL_ENFORCE(circ_->gates[i].now == 1); if (op_str == "XOR") { + YACL_ENFORCE(circ_->gates[i].now == 1); YACL_ENFORCE(circ_->gates[i].niw == 2); circ_->gates[i].op = BFCircuit::Op::XOR; } else if (op_str == "AND") { + YACL_ENFORCE(circ_->gates[i].now == 1); YACL_ENFORCE(circ_->gates[i].niw == 2); circ_->gates[i].op = BFCircuit::Op::AND; } else if (op_str == "INV") { + YACL_ENFORCE(circ_->gates[i].now == 1); YACL_ENFORCE(circ_->gates[i].niw == 1); circ_->gates[i].op = BFCircuit::Op::INV; } else if (op_str == "EQ") { + YACL_ENFORCE(circ_->gates[i].now == 1); YACL_ENFORCE(circ_->gates[i].niw == 1); circ_->gates[i].op = BFCircuit::Op::EQ; } else if (op_str == "EQW") { + YACL_ENFORCE(circ_->gates[i].now == 1); YACL_ENFORCE(circ_->gates[i].niw == 1); circ_->gates[i].op = BFCircuit::Op::EQW; } else if (op_str == "MAND") { diff --git a/yacl/kernel/algorithms/BUILD.bazel b/yacl/kernel/algorithms/BUILD.bazel index 766386d5..723a8d1f 100644 --- a/yacl/kernel/algorithms/BUILD.bazel +++ b/yacl/kernel/algorithms/BUILD.bazel @@ -51,7 +51,7 @@ yacl_cc_library( "//yacl/crypto/tools:ro", "//yacl/link", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "@simplest_ot//:simplest_ot_x86_asm", ], ) @@ -227,7 +227,7 @@ yacl_cc_library( "//yacl/kernel/type:ot_store_utils", "//yacl/link", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "//yacl/utils:cuckoo_index", ], ) @@ -259,7 +259,7 @@ yacl_cc_library( "//yacl/crypto/tools:rp", "//yacl/kernel/type:ot_store_utils", "//yacl/link", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "//yacl/utils:matrix_utils", ], ) @@ -290,7 +290,7 @@ yacl_cc_library( "//yacl/crypto/tools:rp", "//yacl/kernel/type:ot_store_utils", "//yacl/link", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "//yacl/utils:matrix_utils", ] + select({ "@platforms//cpu:aarch64": [ @@ -326,7 +326,7 @@ yacl_cc_library( "//yacl/kernel/algorithms:sgrr_ote", "//yacl/kernel/type:ot_store_utils", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", ], ) @@ -339,7 +339,7 @@ yacl_cc_test( "//yacl/crypto/rand", "//yacl/link:test_util", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", ], ) @@ -356,7 +356,7 @@ yacl_cc_library( "//yacl/kernel/algorithms:softspoken_ote", "//yacl/kernel/type:ot_store_utils", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "//yacl/utils:serialize", ], ) @@ -370,7 +370,7 @@ yacl_cc_test( "//yacl/crypto/rand", "//yacl/link:test_util", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", ], ) @@ -389,7 +389,7 @@ yacl_cc_library( "//yacl/kernel/algorithms:mpfss", "//yacl/kernel/type:ot_store_utils", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", "//yacl/utils:serialize", ], ) @@ -403,7 +403,7 @@ yacl_cc_test( "//yacl/crypto/rand", "//yacl/link:test_util", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", ], ) @@ -437,6 +437,6 @@ yacl_cc_test( "//yacl/crypto/rand", "//yacl/link:test_util", "//yacl/math:gadget", - "//yacl/math/f2k", + "//yacl/math/galois_field:gf_intrinsic", ], ) diff --git a/yacl/kernel/algorithms/base_vole.h b/yacl/kernel/algorithms/base_vole.h index 7b54dab3..0bf0ba85 100644 --- a/yacl/kernel/algorithms/base_vole.h +++ b/yacl/kernel/algorithms/base_vole.h @@ -16,8 +16,8 @@ #include "yacl/base/exception.h" #include "yacl/base/int128.h" -#include "yacl/math/f2k/f2k_utils.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" /* submodules */ #include "yacl/kernel/algorithms/softspoken_ote.h" @@ -50,9 +50,9 @@ void inline Ot2VoleSend(OtSendStore& send_ot, absl::Span w) { std::array w_buff; std::array basis; if (std::is_same::value) { - memcpy(basis.data(), gf128_basis.data(), T_bits * sizeof(K)); + memcpy(basis.data(), math::kGf128Basis().data(), T_bits * sizeof(K)); } else if (std::is_same::value) { - memcpy(basis.data(), gf64_basis.data(), T_bits * sizeof(K)); + memcpy(basis.data(), math::kGf64Basis().data(), T_bits * sizeof(K)); } else { YACL_THROW("VoleSend Error!"); } @@ -81,9 +81,9 @@ void inline Ot2VoleRecv(OtRecvStore& recv_ot, absl::Span u, std::array v_buff; std::array basis; if (std::is_same::value) { - memcpy(basis.data(), gf128_basis.data(), T_bits * sizeof(K)); + memcpy(basis.data(), math::kGf128Basis().data(), T_bits * sizeof(K)); } else if (std::is_same::value) { - memcpy(basis.data(), gf64_basis.data(), T_bits * sizeof(K)); + memcpy(basis.data(), math::kGf64Basis().data(), T_bits * sizeof(K)); } else { YACL_THROW("VoleSend Error!"); } diff --git a/yacl/kernel/algorithms/base_vole_test.cc b/yacl/kernel/algorithms/base_vole_test.cc index 901fca11..279b6a45 100644 --- a/yacl/kernel/algorithms/base_vole_test.cc +++ b/yacl/kernel/algorithms/base_vole_test.cc @@ -25,8 +25,8 @@ #include "yacl/base/int128.h" #include "yacl/crypto/rand/rand.h" #include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" namespace yacl::crypto { diff --git a/yacl/kernel/algorithms/ferret_ote_rn.h b/yacl/kernel/algorithms/ferret_ote_rn.h index fbb86cdb..92f70544 100644 --- a/yacl/kernel/algorithms/ferret_ote_rn.h +++ b/yacl/kernel/algorithms/ferret_ote_rn.h @@ -20,7 +20,7 @@ #include "yacl/crypto/hash/hash_utils.h" #include "yacl/crypto/tools/common.h" -#include "yacl/math/f2k/f2k_utils.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include "yacl/secparam.h" /* submodules */ @@ -89,7 +89,7 @@ inline void MpCotRNSend(const std::shared_ptr& ctx, for (size_t i = 0; i < 128; ++i) { check_cot_data[i] = check_cot.GetBlock(i, choices[i]); } - auto diff = PackGf128(absl::MakeSpan(check_cot_data)); + auto diff = math::Gf128Pack(absl::MakeSpan(check_cot_data)); uhash = uhash ^ diff; auto hash = Blake3(SerializeUint128(uhash)); @@ -142,7 +142,7 @@ inline void MpCotRNRecv(const std::shared_ptr& ctx, uint128_t choices = check_cot.CopyBitBuf().data()[0]; auto check_cot_data = check_cot.CopyBlkBuf(); - auto diff = PackGf128(absl::MakeSpan(check_cot_data)); + auto diff = math::Gf128Pack(absl::MakeSpan(check_cot_data)); uhash = uhash ^ diff; // find punctured indexes diff --git a/yacl/kernel/algorithms/kos_ote.cc b/yacl/kernel/algorithms/kos_ote.cc index 20824bc7..8cb65868 100644 --- a/yacl/kernel/algorithms/kos_ote.cc +++ b/yacl/kernel/algorithms/kos_ote.cc @@ -22,7 +22,7 @@ #include "yacl/base/byte_container_view.h" #include "yacl/base/int128.h" #include "yacl/crypto/tools/common.h" -#include "yacl/math/f2k/f2k.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" @@ -165,7 +165,7 @@ void KosOtExtSend(const std::shared_ptr& ctx, for (size_t k = 0; k < kKappa; ++k) { auto k_msg_span = absl::MakeSpan( reinterpret_cast(ot_ext[k].data()), 2 * batch_num); - q_check[k] = GfMul64(absl::MakeSpan(rand_samples), k_msg_span); + q_check[k] = math::Gf64Mul(absl::MakeSpan(rand_samples), k_msg_span); } CheckMsg check_msgs; @@ -270,10 +270,10 @@ void KosOtExtRecv(const std::shared_ptr& ctx, // =================== CONSISTENCY CHECK =================== auto choice_span = absl::MakeSpan( reinterpret_cast(choice_ext.data()), batch_num * 2); - check_msgs.x = GfMul64(absl::MakeSpan(rand_samples), choice_span); + check_msgs.x = math::Gf64Mul(absl::MakeSpan(rand_samples), choice_span); for (size_t k = 0; k < kKappa; ++k) { - check_msgs.t[k] = GfMul64( + check_msgs.t[k] = math::Gf64Mul( absl::MakeSpan(rand_samples), absl::MakeSpan(reinterpret_cast(ot_ext.first[k].data()), batch_num * 2)); diff --git a/yacl/kernel/algorithms/mp_vole.h b/yacl/kernel/algorithms/mp_vole.h index dfa5b039..2f530610 100644 --- a/yacl/kernel/algorithms/mp_vole.h +++ b/yacl/kernel/algorithms/mp_vole.h @@ -22,8 +22,8 @@ #include "yacl/crypto/tools/common.h" #include "yacl/kernel/algorithms/mpfss.h" #include "yacl/kernel/type/ot_store_utils.h" -#include "yacl/math/f2k/f2k_utils.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include "yacl/secparam.h" YACL_MODULE_DECLARE("mp_vole", SecParam::C::INF, SecParam::S::INF); diff --git a/yacl/kernel/algorithms/mp_vole_test.cc b/yacl/kernel/algorithms/mp_vole_test.cc index fe5f4277..d27d4493 100644 --- a/yacl/kernel/algorithms/mp_vole_test.cc +++ b/yacl/kernel/algorithms/mp_vole_test.cc @@ -24,8 +24,8 @@ #include "yacl/base/exception.h" #include "yacl/crypto/rand/rand.h" #include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" namespace yacl::crypto { diff --git a/yacl/kernel/algorithms/mpfss.cc b/yacl/kernel/algorithms/mpfss.cc index 00b01707..30682b73 100644 --- a/yacl/kernel/algorithms/mpfss.cc +++ b/yacl/kernel/algorithms/mpfss.cc @@ -23,8 +23,8 @@ #include "yacl/crypto/tools/crhash.h" #include "yacl/kernel/algorithms/gywz_ote.h" #include "yacl/kernel/algorithms/sgrr_ote.h" -#include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" namespace yacl::crypto { diff --git a/yacl/kernel/algorithms/mpfss_test.cc b/yacl/kernel/algorithms/mpfss_test.cc index b746f1af..4c6d9ace 100644 --- a/yacl/kernel/algorithms/mpfss_test.cc +++ b/yacl/kernel/algorithms/mpfss_test.cc @@ -25,8 +25,8 @@ #include "yacl/base/exception.h" #include "yacl/crypto/rand/rand.h" #include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" namespace yacl::crypto { diff --git a/yacl/kernel/algorithms/silent_vole_test.cc b/yacl/kernel/algorithms/silent_vole_test.cc index ca2d114b..039dbbf0 100644 --- a/yacl/kernel/algorithms/silent_vole_test.cc +++ b/yacl/kernel/algorithms/silent_vole_test.cc @@ -25,8 +25,8 @@ #include "yacl/base/int128.h" #include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" +#include "yacl/math/galois_field/gf_intrinsic.h" namespace yacl::crypto { diff --git a/yacl/kernel/algorithms/softspoken_ote.cc b/yacl/kernel/algorithms/softspoken_ote.cc index f861ff6e..80c76ba5 100644 --- a/yacl/kernel/algorithms/softspoken_ote.cc +++ b/yacl/kernel/algorithms/softspoken_ote.cc @@ -25,7 +25,7 @@ #include "yacl/base/exception.h" #include "yacl/crypto/tools/common.h" #include "yacl/kernel/type/ot_store_utils.h" -#include "yacl/math/f2k/f2k.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include "yacl/utils/matrix_utils.h" #include "yacl/utils/serialize.h" @@ -727,7 +727,7 @@ void SoftspokenOtExtSender::Send( CheckMsg check_msgs; for (size_t i = 0; i < all_batch_num; ++i) { for (size_t k = 0; k < kKappa; ++k) { - check_msgs.t[k] ^= ClMul64( + check_msgs.t[k] ^= math::Gf64ClMul( absl::MakeSpan(rand_samples.data() + i * 2, 2), absl::MakeSpan(reinterpret_cast(allV[i].data() + k), 2)); } @@ -736,7 +736,7 @@ void SoftspokenOtExtSender::Send( CheckMsg msgs; std::array check_vals; for (size_t k = 0; k < kKappa; ++k) { - check_vals[k] = Reduce64(check_msgs.t[k]); + check_vals[k] = math::Gf64Reduce(check_msgs.t[k]); } msgs.Unpack(ctx->Recv(ctx->NextRank(), fmt::format("MAL-SS-CHECK-FINAL"))); @@ -850,7 +850,7 @@ void SoftspokenOtExtSender::Send(const std::shared_ptr& ctx, CheckMsg check_msgs; for (size_t i = 0; i < all_batch_num; ++i) { for (size_t k = 0; k < kKappa; ++k) { - check_msgs.t[k] ^= ClMul64( + check_msgs.t[k] ^= math::Gf64ClMul( absl::MakeSpan(rand_samples.data() + i * 2, 2), absl::MakeSpan(reinterpret_cast(allV[i].data() + k), 2)); } @@ -859,7 +859,7 @@ void SoftspokenOtExtSender::Send(const std::shared_ptr& ctx, CheckMsg msgs; std::array check_vals; for (size_t k = 0; k < kKappa; ++k) { - check_vals[k] = Reduce64(check_msgs.t[k]); + check_vals[k] = math::Gf64Reduce(check_msgs.t[k]); } msgs.Unpack(ctx->Recv(ctx->NextRank(), fmt::format("MAL-SS-CHECK-FINAL"))); @@ -965,20 +965,20 @@ void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, CheckMsg check_msgs; auto choice_span = absl::MakeSpan( reinterpret_cast(choice_ext.data()), all_batch_num * 2); - check_msgs.x ^= ClMul64(absl::MakeSpan(rand_samples), choice_span); + check_msgs.x ^= math::Gf64ClMul(absl::MakeSpan(rand_samples), choice_span); for (size_t i = 0; i < all_batch_num; ++i) { for (size_t k = 0; k < kKappa; ++k) { - check_msgs.t[k] ^= ClMul64( + check_msgs.t[k] ^= math::Gf64ClMul( absl::MakeSpan(rand_samples.data() + i * 2, 2), absl::MakeSpan(reinterpret_cast(allW[i].data() + k), 2)); } } CheckMsg msgs; - msgs.x = Reduce64(check_msgs.x); + msgs.x = math::Gf64Reduce(check_msgs.x); for (size_t k = 0; k < kKappa; ++k) { - msgs.t[k] = Reduce64(check_msgs.t[k]); + msgs.t[k] = math::Gf64Reduce(check_msgs.t[k]); } auto buf = msgs.Pack(); ctx->SendAsync(ctx->NextRank(), buf, fmt::format("MAL-SS-CHECK-FINAL")); @@ -1087,20 +1087,20 @@ void SoftspokenOtExtReceiver::Recv(const std::shared_ptr& ctx, CheckMsg check_msgs; auto choice_span = absl::MakeSpan( reinterpret_cast(choice_ext.data()), all_batch_num * 2); - check_msgs.x ^= ClMul64(absl::MakeSpan(rand_samples), choice_span); + check_msgs.x ^= math::Gf64ClMul(absl::MakeSpan(rand_samples), choice_span); for (size_t i = 0; i < all_batch_num; ++i) { for (size_t k = 0; k < kKappa; ++k) { - check_msgs.t[k] ^= ClMul64( + check_msgs.t[k] ^= math::Gf64ClMul( absl::MakeSpan(rand_samples.data() + i * 2, 2), absl::MakeSpan(reinterpret_cast(allW[i].data() + k), 2)); } } CheckMsg msgs; - msgs.x = Reduce64(check_msgs.x); + msgs.x = math::Gf64Reduce(check_msgs.x); for (size_t k = 0; k < kKappa; ++k) { - msgs.t[k] = Reduce64(check_msgs.t[k]); + msgs.t[k] = math::Gf64Reduce(check_msgs.t[k]); } auto buf = msgs.Pack(); ctx->SendAsync(ctx->NextRank(), buf, fmt::format("MAL-SS-CHECK-FINAL")); diff --git a/yacl/math/f2k/BUILD.bazel b/yacl/math/f2k/BUILD.bazel deleted file mode 100644 index 5f309b2a..00000000 --- a/yacl/math/f2k/BUILD.bazel +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//bazel:yacl.bzl", "AES_COPT_FLAGS", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") - -package(default_visibility = ["//visibility:public"]) - -yacl_cc_library( - name = "f2k", - hdrs = [ - "f2k.h", - "f2k_utils.h", - ], - deps = [ - "//yacl/base:aligned_vector", - "//yacl/base:block", - "//yacl/base:exception", - "//yacl/base:int128", - "//yacl/math:gadget", - "@com_google_absl//absl/types:span", - ], -) - -yacl_cc_test( - name = "f2k_test", - srcs = ["f2k_test.cc"], - copts = AES_COPT_FLAGS, - deps = [ - ":f2k", - "//yacl/crypto/rand", - ], -) - -yacl_cc_binary( - name = "f2k_bench", - srcs = ["f2k_bench.cc"], - copts = AES_COPT_FLAGS, - deps = [ - ":f2k", - "//yacl/crypto/rand", - "@com_github_google_benchmark//:benchmark_main", - ], -) diff --git a/yacl/math/f2k/f2k.h b/yacl/math/f2k/f2k.h deleted file mode 100644 index 5fc14a13..00000000 --- a/yacl/math/f2k/f2k.h +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "yacl/base/block.h" -#include "yacl/base/exception.h" -#include "yacl/base/int128.h" - -// Galois Field GF(2^n) implmentation -// (As of now, only support GF(2^64) & GF(2^128)) -// -// Galois Field GF(2^n) could be viewed as GF(2)[X]/(P), -// where P is an irreducible polynomial in GF(2)[X] of degree n. -// -// To achieve multiplication over GF(2^n): -// 1. Perform polynomial multiplication over GF(2)[X], as known as, carry-less -// multiplication. -// 2. Reduce the product modulo the irreducible polynomial. -// -// For example, in GF(2^8) = GF(2)[X]/(x^8+x^4+x^3+x^2+x+1) -// -// x^7 x^6 x^5 x^4 x^3 x^2 x^1 x^0 x^7 x^6 x^5 x^4 x^3 x^2 x^1 x^0 -// 1 0 0 0 0 1 1 1 * 0 0 0 0 0 0 1 0 -// (carry-less multiplication) -// x^8 x^7 x^6 x^5 x^4 x^3 x^2 x^1 x^0 -// = 1 0 0 0 0 1 1 1 0 -// (reducing by x^8+x^4+x^3+x^2+x+1) -// x^7 x^6 x^5 x^4 x^3 x^2 x^1 x^0 -// = 0 0 0 1 0 1 0 1 -// -// For more information, -// 1. properties about GF(2^n) -// https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture7.pdf -// 2. binary irreducible polynomials: -// https://www.hpl.hp.com/techreports/98/HPL-98-135.pdf - -namespace yacl { - -// Irreducible Polynomials of degree 128 and 64. -constexpr uint64_t kGfMod128 = (1 << 7) | (1 << 2) | (1 << 1) | 1; -constexpr uint64_t kGfMod64 = (1 << 4) | (1 << 3) | (1 << 1) | 1; - -// carry-less multiplication over Z_{2^128} -// ref: -// https://www.intel.com/content/dam/develop/external/us/en/documents/clmul-wp-rev-2-02-2014-04-20.pdf -// Figure 5 or Algorithm 1 -inline std::pair ClMul128(block x, block y) { - block low = _mm_clmulepi64_si128(x, y, 0x00); // low 64 of x, low 64 of y - block high = _mm_clmulepi64_si128(x, y, 0x11); // low 64 of x, low 64 of y - - block mid1 = _mm_clmulepi64_si128(x, y, 0x10); // low 64 of x, high 64 of y - block mid2 = _mm_clmulepi64_si128(x, y, 0x01); // high 64 of x, low 64 of y - block mid = _mm_xor_si128(mid1, mid2); - - mid1 = _mm_srli_si128(mid, 8); // mid1 = mid >> 64 - mid2 = _mm_slli_si128(mid, 8); // mid2 = mid << 64 - - high = _mm_xor_si128(high, mid1); // high ^ (mid >> 64) - low = _mm_xor_si128(low, mid2); // low ^ (mid << 64) - - return std::make_pair(high, low); -} - -inline std::pair ClMul128(uint128_t x, uint128_t y) { - auto [high, low] = ClMul128(block(x), block(y)); - return std::make_pair(toU128(high), toU128(low)); -} - -inline block Reduce128(block high, block low) { - const block modulo = block(kGfMod128); - - auto [upper, carry0] = ClMul128(high, modulo); - low = _mm_xor_si128(low, carry0); - - auto [zero, carry1] = ClMul128(upper, modulo); - low = _mm_xor_si128(low, carry1); - return low; -} - -inline uint128_t Reduce128(uint128_t high, uint128_t low) { - return toU128(Reduce128(block(high), block(low))); -} - -// multiplication over Galois Field F_{2^128} -inline block GfMul128(block x, block y) { - auto [high, low] = ClMul128(x, y); - return Reduce128(high, low); -} - -inline uint128_t GfMul128(uint128_t x, uint128_t y) { - return toU128(GfMul128(block(x), block(y))); -} - -// carry-less multiplication over Z_{2^64} -// ref: -// https://github.com/scipr-lab/libff/blob/9769030a06b7ab933d6c064db120019decd359f1/libff/algebra/fields/binary/gf64.cpp#L62 -inline uint128_t ClMul64(uint64_t x, uint64_t y) { - block rb = _mm_clmulepi64_si128(_mm_loadl_epi64((const __m128i*)&(x)), - _mm_loadl_epi64((const __m128i*)&(y)), 0x00); - return toU128(rb); -} - -inline uint64_t Reduce64(uint128_t x) { - const block modulo = block(0, kGfMod64); - auto xb = block(x); - - // low 64 of modulo, high 64 of x - // output is 96 bits, since modulo < 2^32 - auto temp = _mm_clmulepi64_si128(modulo, xb, 0x10); - xb = _mm_xor_si128(xb, temp); - - // low 64 of modulo, high 64 of temp - // output is 64 bits, since modulo < 2^32 && high 64 of temp < 2^32 - temp = _mm_clmulepi64_si128(modulo, temp, 0x10); - xb = _mm_xor_si128(xb, temp); - return xb.as()[0]; // low 64 bit -} - -// multiplication over Galois Field F_{2^64} -inline uint64_t GfMul64(uint64_t x, uint64_t y) { - return Reduce64(ClMul64(x, y)); -} - -// inverse over Galois Field F_{2^64} -inline uint64_t GfInv64(uint64_t x) { - uint64_t t0 = x; - uint64_t t1 = GfMul64(t0, t0); - uint64_t t2 = GfMul64(t1, t0); - t0 = GfMul64(t2, t2); - t0 = GfMul64(t0, t0); - t1 = GfMul64(t1, t0); - t2 = GfMul64(t2, t0); - t0 = GfMul64(t2, t2); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t1 = GfMul64(t1, t0); - t2 = GfMul64(t2, t0); - t0 = GfMul64(t2, t2); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t1 = GfMul64(t1, t0); - t2 = GfMul64(t2, t0); - t0 = GfMul64(t2, t2); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t0 = GfMul64(t0, t0); - t1 = GfMul64(t1, t0); - t0 = GfMul64(t0, t2); - for (int i = 0; i < 32; i++) { - t0 = GfMul64(t0, t0); - } - t0 = GfMul64(t0, t1); - return t0; -} - -// inverse over Galois Field F_{2^128} -inline uint128_t GfInv128(uint128_t x) { - uint128_t t0 = GfMul128(x, x); - uint128_t t1 = t0; - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - for (int i = 0; i < 60; i++) { - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - t0 = GfMul128(t0, t0); - t1 = GfMul128(t1, t0); - } - return t1; -} - -// Inner product -inline std::pair ClMul128(absl::Span x, - absl::Span y) { - YACL_ENFORCE(x.size() == y.size()); - - const uint64_t size = x.size(); - block ret_high = 0; - block ret_low = 0; - - for (uint64_t i = 0; i < size; ++i) { - auto [high, low] = ClMul128(block(x[i]), block(y[i])); - ret_high = _mm_xor_si128(ret_high, high); - ret_low = _mm_xor_si128(ret_low, low); - } - return std::make_pair(toU128(ret_high), toU128(ret_low)); -} - -inline uint128_t GfMul128(absl::Span x, - absl::Span y) { - YACL_ENFORCE(x.size() == y.size()); - auto [high, low] = ClMul128(x, y); - return Reduce128(high, low); -} - -inline uint128_t ClMul64(absl::Span x, - absl::Span y) { - YACL_ENFORCE(x.size() == y.size()); - - const uint64_t size = x.size(); - block ret = 0; - - uint64_t i = 0; - for (; i + 1 < size; i += 2) { - // pack - block xb = block(x[i + 1], x[i]); - block yb = block(y[i + 1], y[i]); - // low 64 of xb, low 64 of yb, x[i] * y[i] - block xy0 = _mm_clmulepi64_si128(xb, yb, 0x00); - // high 64 of xb, high 64 of yb, x[i+1] * y[i+1] - block xy1 = _mm_clmulepi64_si128(xb, yb, 0x11); - // xor - ret = _mm_xor_si128(ret, xy0); - ret = _mm_xor_si128(ret, xy1); - } - - for (; i < size; ++i) { - auto temp = block(ClMul64(x[i], y[i])); - ret = _mm_xor_si128(ret, temp); - } - - return toU128(ret); -} - -inline uint64_t GfMul64(absl::Span x, - absl::Span y) { - YACL_ENFORCE(x.size() == y.size()); - return Reduce64(ClMul64(x, y)); -} - -// As of now, f2k only support GF(2^128) and GF(2^64) -// TODO: @wenfan implement GF(2^k) -// // Reduce Z_{2^128} to Galois Field F_{2^k} -// uint64_t Reduce(uint128_t x,uint64_t k); -// // multiplication over Galois Field F_{2^k} -// uint64_t GfMul(uint64_t x, uint64_t y, uint64_t k); - -inline std::array GenGf128Basis() { - std::array basis = {0}; - uint128_t one = yacl::MakeUint128(0, 1); - for (size_t i = 0; i < 128; ++i) { - basis[i] = one << i; - } - return basis; -} - -inline std::array GenGf64Basis() { - std::array basis = {0}; - uint128_t one = yacl::MakeUint128(0, 1); - for (size_t i = 0; i < 64; ++i) { - basis[i] = one << i; - } - return basis; -} - -static std::array gf64_basis = GenGf64Basis(); -static std::array gf128_basis = GenGf128Basis(); - -inline uint128_t PackGf128(absl::Span data) { - const size_t size = data.size(); - YACL_ENFORCE(size <= 128); - // inner product - return GfMul128(data, absl::MakeSpan(gf128_basis.data(), size)); -} - -inline uint64_t PackGf64(absl::Span data) { - const size_t size = data.size(); - YACL_ENFORCE(size <= 64); - // inner product - return GfMul64(data, absl::MakeSpan(gf64_basis.data(), size)); -} -}; // namespace yacl diff --git a/yacl/math/f2k/f2k_bench.cc b/yacl/math/f2k/f2k_bench.cc deleted file mode 100644 index e4603eed..00000000 --- a/yacl/math/f2k/f2k_bench.cc +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "benchmark/benchmark.h" - -#include "yacl/crypto/rand/rand.h" -#include "yacl/math/f2k/f2k.h" - -static void BM_ClMul128_block(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::ClMul128(x[i], y[i])); - } - } -} - -static void BM_GfMul128_block(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::GfMul128(x[i], y[i])); - } - } -} - -static void BM_ClMul128(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::ClMul128(x[i], y[i])); - } - } -} - -static void BM_GfMul128(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::GfMul128(x[i], y[i])); - } - } -} - -static void BM_GfMul128_inner_product(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - auto x_span = absl::MakeSpan(x); - auto y_span = absl::MakeSpan(y); - - state.ResumeTiming(); - yacl::GfMul128(x_span, y_span); - } -} - -static void BM_ClMul64(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::ClMul64(x[i], y[i])); - } - } -} - -static void BM_GfMul64(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - state.ResumeTiming(); - for (size_t i = 0; i < n; i++) { - benchmark::DoNotOptimize(yacl::GfMul64(x[i], y[i])); - } - } -} - -static void BM_GfMul64_inner_product(benchmark::State& state) { - for (auto _ : state) { - state.PauseTiming(); - size_t n = state.range(0); - - auto x = yacl::crypto::RandVec(n); - auto y = yacl::crypto::RandVec(n); - - auto x_span = absl::MakeSpan(x); - auto y_span = absl::MakeSpan(y); - - state.ResumeTiming(); - yacl::GfMul64(x_span, y_span); - } -} - -uint64_t g_interations = 10; - -BENCHMARK(BM_ClMul128_block) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_GfMul128_block) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_ClMul128) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_GfMul128) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_GfMul128_inner_product) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_ClMul64) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_GfMul64) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); - -BENCHMARK(BM_GfMul64_inner_product) - ->Unit(benchmark::kMillisecond) - ->Iterations(g_interations) - ->Arg(1 << 20) - ->Arg(1 << 21) - ->Arg(1 << 22) - ->Arg(1 << 23) - ->Arg(1 << 24) - ->Arg(1 << 25); diff --git a/yacl/math/f2k/f2k_utils.h b/yacl/math/f2k/f2k_utils.h deleted file mode 100644 index 90f42e27..00000000 --- a/yacl/math/f2k/f2k_utils.h +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "yacl/base/aligned_vector.h" -#include "yacl/math/f2k/f2k.h" -#include "yacl/math/gadget.h" - -namespace yacl::math { - -// ------------------------ -// f2k-field operation -// ------------------------ - -// inner-product -uint128_t inline GfMul(absl::Span a, - absl::Span b) { - return GfMul128(a, b); -} - -uint64_t inline GfMul(absl::Span a, - absl::Span b) { - return GfMul64(a, b); -} - -uint128_t inline GfMul(absl::Span a, - absl::Span b) { - UninitAlignedVector tmp(b.size()); - std::transform(b.cbegin(), b.cend(), tmp.begin(), [](const uint64_t& val) { - return static_cast(val); - }); - return GfMul128(a, absl::MakeSpan(tmp)); -} - -uint128_t inline GfMul(absl::Span a, - absl::Span b) { - return GfMul(b, a); -} - -// element-wise -uint128_t inline GfMul(uint128_t a, uint128_t b) { return GfMul128(a, b); } - -uint64_t inline GfMul(uint64_t a, uint64_t b) { return GfMul64(a, b); } - -uint128_t inline GfMul(uint128_t a, uint64_t b) { - return GfMul128(a, static_cast(b)); -} - -uint128_t inline GfMul(uint64_t a, uint128_t b) { - return GfMul128(static_cast(a), b); -} - -// ------------------------ -// f2k-Universal Hash -// ------------------------ - -// see difference between universal hash and collision-resistent hash functions: -// https://crypto.stackexchange.com/a/88247/61581 -template -T UniversalHash(T seed, absl::Span data) { - T ret = 0; - for_each(data.rbegin(), data.rend(), [&ret, &seed](const T& val) { - ret ^= val; - ret = GfMul(seed, ret); - }); - return ret; -} - -template -std::vector ExtractHashCoef(T seed, - absl::Span indexes /*sorted*/) { - std::array buff = {}; - auto max_bits = math::Log2Ceil(indexes.back()); - buff[0] = seed; - for (size_t i = 1; i <= max_bits; ++i) { - buff[i] = GfMul(buff[i - 1], buff[i - 1]); - } - - std::vector ret; - for (const auto& index : indexes) { - auto index_plus_one = index + 1; - uint64_t mask = 1; - T coef = 1; - for (size_t i = 0; i < 64 && mask <= index_plus_one; ++i) { - if (mask & index_plus_one) { - coef = GfMul(coef, buff[i]); - } - mask <<= 1; - } - ret.push_back(coef); - } - return ret; -} - -} // namespace yacl::math \ No newline at end of file diff --git a/yacl/math/galois_field/BUILD.bazel b/yacl/math/galois_field/BUILD.bazel index 01935daf..5aad8e4a 100644 --- a/yacl/math/galois_field/BUILD.bazel +++ b/yacl/math/galois_field/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") +load("//bazel:yacl.bzl", "AES_COPT_FLAGS", "yacl_cc_library", "yacl_cc_test") yacl_cc_library( name = "gf", @@ -43,3 +43,32 @@ yacl_cc_test( ":gf", ], ) + +yacl_cc_library( + name = "gf_intrinsic", + srcs = [ + "gf_intrinsic.cc", + ], + hdrs = [ + "gf_intrinsic.h", + ], + copts = AES_COPT_FLAGS, + visibility = ["//visibility:public"], + deps = [ + "//yacl/base:aligned_vector", + "//yacl/base:block", + "//yacl/base:exception", + "//yacl/base:int128", + "//yacl/math:gadget", + "@com_google_absl//absl/types:span", + ], +) + +yacl_cc_test( + name = "gf_intrinsic_test", + srcs = ["gf_intrinsic_test.cc"], + deps = [ + ":gf_intrinsic", + "//yacl/crypto/rand", + ], +) diff --git a/yacl/math/galois_field/factory/gf_vector.h b/yacl/math/galois_field/factory/gf_vector.h index cc44efe6..bba967cb 100644 --- a/yacl/math/galois_field/factory/gf_vector.h +++ b/yacl/math/galois_field/factory/gf_vector.h @@ -83,54 +83,54 @@ class GFVectorizedSketch : public GaloisField { virtual std::vector DeserializeT(ByteContainerView buffer) const = 0; private: -#define DefineUnaryFunc(FuncName) \ +#define DefineVecUnaryFunc(FuncName) \ auto FuncName(const Item& x) const override { \ return FuncName(x.AsSpan()); \ } -#define DefineUnaryInplaceFunc(FuncName) \ +#define DefineVecUnaryInplaceFunc(FuncName) \ void FuncName(Item* x) const override { return FuncName(x->AsSpan()); } -#define DefineBinaryFunc(FuncName) \ +#define DefineVecBinaryFunc(FuncName) \ auto FuncName(const Item& x, const Item& y) const override { \ return FuncName(x.AsSpan(), y.AsSpan()); \ } -#define DefineBinaryInplaceFunc(FuncName) \ +#define DefineVecBinaryInplaceFunc(FuncName) \ void FuncName(Item* x, const Item& y) const override { \ FuncName(x->AsSpan(), y.AsSpan()); \ } // if x is scalar, returns bool // if x is vectored, returns std::vector - DefineUnaryFunc(IsIdentityOne); - DefineUnaryFunc(IsIdentityZero); - DefineUnaryFunc(IsInField); - DefineBinaryFunc(Equal); + DefineVecUnaryFunc(IsIdentityOne); + DefineVecUnaryFunc(IsIdentityZero); + DefineVecUnaryFunc(IsInField); + DefineVecBinaryFunc(Equal); //==================================// // operations defined on field // //==================================// // get the additive inverse −a for all elements in set - DefineUnaryFunc(Neg); - DefineUnaryInplaceFunc(NegInplace); + DefineVecUnaryFunc(Neg); + DefineVecUnaryInplaceFunc(NegInplace); // get the multiplicative inverse 1/b for every nonzero element in set - DefineUnaryFunc(Inv); - DefineUnaryInplaceFunc(InvInplace); + DefineVecUnaryFunc(Inv); + DefineVecUnaryInplaceFunc(InvInplace); - DefineBinaryFunc(Add); - DefineBinaryInplaceFunc(AddInplace); + DefineVecBinaryFunc(Add); + DefineVecBinaryInplaceFunc(AddInplace); - DefineBinaryFunc(Sub); - DefineBinaryInplaceFunc(SubInplace); + DefineVecBinaryFunc(Sub); + DefineVecBinaryInplaceFunc(SubInplace); - DefineBinaryFunc(Mul); - DefineBinaryInplaceFunc(MulInplace); + DefineVecBinaryFunc(Mul); + DefineVecBinaryInplaceFunc(MulInplace); - DefineBinaryFunc(Div); - DefineBinaryInplaceFunc(DivInplace); + DefineVecBinaryFunc(Div); + DefineVecBinaryInplaceFunc(DivInplace); virtual Item Pow(const Item& x, const MPInt& y) const { return Pow(x.AsSpan(), y); @@ -148,11 +148,11 @@ class GFVectorizedSketch : public GaloisField { // I/O // //================================// - DefineUnaryFunc(DeepCopy); + DefineVecUnaryFunc(DeepCopy); // To human-readable string - DefineUnaryFunc(ToString); - DefineUnaryFunc(Serialize); + DefineVecUnaryFunc(ToString); + DefineVecUnaryFunc(Serialize); // serialize field element(s) to already allocated buffer. // if buf is nullptr, then calc serialize size only diff --git a/yacl/math/galois_field/gf_intrinsic.cc b/yacl/math/galois_field/gf_intrinsic.cc new file mode 100644 index 00000000..bf51fe58 --- /dev/null +++ b/yacl/math/galois_field/gf_intrinsic.cc @@ -0,0 +1,332 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/math/galois_field/gf_intrinsic.h" + +namespace yacl::math { + +// ---------------------------------- +// GF 128 +// ---------------------------------- + +void Gf128Mul(uint128_t x, uint128_t y, uint128_t* out) { + block temp; + Gf128Mul(block(x), block(y), &temp); + *out = toU128(temp); +} + +void Gf128Mul(block x, block y, block* out) { + block high; + block low; + Gf128ClMul(x, y, &high, &low); + Gf128Reduce(high, low, out); +} + +void Gf128Mul(absl::Span x, absl::Span y, + uint128_t* out) { + YACL_ENFORCE(x.size() == y.size()); + uint128_t high; + uint128_t low; + Gf128ClMul(x, y, &high, &low); + Gf128Reduce(high, low, out); +} + +void Gf128ClMul(uint128_t x, uint128_t y, uint128_t* out1, uint128_t* out2) { + block high; + block low; + Gf128ClMul(block(x), block(y), &high, &low); + *out1 = toU128(high); + *out2 = toU128(low); +} + +void Gf128ClMul(block x, block y, block* out1, block* out2) { + block low = _mm_clmulepi64_si128(x, y, 0x00); // low 64 of x, low 64 of y + block high = _mm_clmulepi64_si128(x, y, 0x11); // low 64 of x, low 64 of y + + block mid1 = _mm_clmulepi64_si128(x, y, 0x10); // low 64 of x, high 64 of y + block mid2 = _mm_clmulepi64_si128(x, y, 0x01); // high 64 of x, low 64 of y + block mid = _mm_xor_si128(mid1, mid2); + + mid1 = _mm_srli_si128(mid, 8); // mid1 = mid >> 64 + mid2 = _mm_slli_si128(mid, 8); // mid2 = mid << 64 + + *out1 = _mm_xor_si128(high, mid1); // high ^ (mid >> 64) + *out2 = _mm_xor_si128(low, mid2); // low ^ (mid << 64) +} + +void Gf128ClMul(absl::Span x, absl::Span y, + uint128_t* out1, uint128_t* out2) { + YACL_ENFORCE(x.size() == y.size()); + + const uint64_t size = x.size(); + block ret_high = 0; + block ret_low = 0; + + for (uint64_t i = 0; i < size; ++i) { + block high; + block low; + Gf128ClMul(block(x[i]), block(y[i]), &high, &low); + ret_high = _mm_xor_si128(ret_high, high); + ret_low = _mm_xor_si128(ret_low, low); + } + *out1 = toU128(ret_high); + *out2 = toU128(ret_low); +} + +void Gf128Reduce(block high, block low, block* out) { + const block modulo = block(kGfMod128); + block upper; + block carry0; + block carry1; + block zero; + Gf128ClMul(high, modulo, &upper, &carry0); + low = _mm_xor_si128(low, carry0); + Gf128ClMul(upper, modulo, &zero, &carry1); + *out = _mm_xor_si128(low, carry1); +} + +void Gf128Reduce(uint128_t high, uint128_t low, uint128_t* out) { + block temp; + Gf128Reduce(block(high), block(low), &temp); + *out = toU128(temp); +} + +void Gf128Pack(absl::Span data, uint128_t* out) { + const size_t size = data.size(); + YACL_ENFORCE(size <= 128); + Gf128Mul(data, absl::MakeSpan(kGf128Basis().data(), size), out); +} + +// ---------------------------------- +// GF 64 +// ---------------------------------- + +void Gf64Mul(uint64_t x, uint64_t y, uint64_t* out) { + uint128_t temp; + Gf64ClMul(x, y, &temp); + Gf64Reduce(temp, out); +} + +void Gf64Mul(absl::Span x, absl::Span y, + uint64_t* out) { + YACL_ENFORCE(x.size() == y.size()); + uint128_t temp; + Gf64ClMul(x, y, &temp); + Gf64Reduce(temp, out); +} + +void Gf64ClMul(uint64_t x, uint64_t y, uint128_t* out) { + *out = + toU128(_mm_clmulepi64_si128(_mm_loadl_epi64((const __m128i*)&(x)), + _mm_loadl_epi64((const __m128i*)&(y)), 0x00)); +} + +void Gf64ClMul(absl::Span x, absl::Span y, + uint128_t* out) { + YACL_ENFORCE(x.size() == y.size()); + + const uint64_t size = x.size(); + block ret = 0; + + uint64_t i = 0; + for (; i + 1 < size; i += 2) { + // pack + block xb = block(x[i + 1], x[i]); + block yb = block(y[i + 1], y[i]); + // low 64 of xb, low 64 of yb, x[i] * y[i] + block xy0 = _mm_clmulepi64_si128(xb, yb, 0x00); + // high 64 of xb, high 64 of yb, x[i+1] * y[i+1] + block xy1 = _mm_clmulepi64_si128(xb, yb, 0x11); + // xor + ret = _mm_xor_si128(ret, xy0); + ret = _mm_xor_si128(ret, xy1); + } + + for (; i < size; ++i) { + uint128_t temp; + Gf64ClMul(x[i], y[i], &temp); + ret = _mm_xor_si128(ret, block(temp)); + } + + *out = toU128(ret); +} + +void Gf64Reduce(uint128_t x, uint64_t* out) { + const block modulo = block(0, kGfMod64); + auto xb = block(x); + + // low 64 of modulo, high 64 of x + // output is 96 bits, since modulo < 2^32 + auto temp = _mm_clmulepi64_si128(modulo, xb, 0x10); + xb = _mm_xor_si128(xb, temp); + + // low 64 of modulo, high 64 of temp + // output is 64 bits, since modulo < 2^32 && high 64 of temp < 2^32 + temp = _mm_clmulepi64_si128(modulo, temp, 0x10); + xb = _mm_xor_si128(xb, temp); + *out = xb.as()[0]; // low 64 bit +} + +void Gf64Inv(uint64_t x, uint64_t* out) { + uint64_t t0 = x; + uint64_t t1; + uint64_t t2; + Gf64Mul(t0, t0, &t1); + Gf64Mul(t1, t0, &t2); + Gf64Mul(t2, t2, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t1, t0, &t1); + Gf64Mul(t2, t0, &t2); + Gf64Mul(t2, t2, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t1, t0, &t1); + Gf64Mul(t2, t0, &t2); + Gf64Mul(t2, t2, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t1, t0, &t1); + Gf64Mul(t2, t0, &t2); + Gf64Mul(t2, t2, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t0, t0, &t0); + Gf64Mul(t1, t0, &t1); + Gf64Mul(t0, t2, &t0); + for (int i = 0; i < 32; i++) { + Gf64Mul(t0, t0, &t0); + } + Gf64Mul(t0, t1, out); +} + +// ------------------------ +// GF Function Alias +// ------------------------ + +uint64_t Gf64Mul(uint64_t x, uint64_t y) { + uint64_t ret; + Gf64Mul(x, y, &ret); + return ret; +} + +uint64_t Gf64Mul(absl::Span x, absl::Span y) { + uint64_t ret; + Gf64Mul(x, y, &ret); + return ret; +} + +uint64_t Gf64Pack(absl::Span data) { + uint64_t ret; + Gf64Pack(data, &ret); + return ret; +} +void Gf64Pack(absl::Span data, uint64_t* out) { + const size_t size = data.size(); + YACL_ENFORCE(size <= 64); + Gf64Mul(data, absl::MakeSpan(kGf64Basis().data(), size), out); +} + +uint64_t Gf64Inv(uint64_t x) { + uint64_t ret; + Gf64Inv(x, &ret); + return ret; +} + +uint64_t Gf64Reduce(uint128_t x) { + uint64_t ret; + Gf64Reduce(x, &ret); + return ret; +} + +uint128_t Gf64ClMul(uint64_t x, uint64_t y) { + uint128_t ret; + Gf64ClMul(x, y, &ret); + return ret; +} + +uint128_t Gf64ClMul(absl::Span x, + absl::Span y) { + uint128_t ret; + Gf64ClMul(x, y, &ret); + return ret; +} + +uint128_t Gf128Mul(uint128_t x, uint128_t y) { + uint128_t ret; + Gf128Mul(x, y, &ret); + return ret; +} + +block Gf128Mul(block x, block y) { + block ret; + Gf128Mul(x, y, &ret); + return ret; +} + +uint128_t Gf128Reduce(uint128_t high, uint128_t low) { + uint128_t ret; + Gf128Reduce(high, low, &ret); + return ret; +} + +block Gf128Reduce(block high, block low) { + block ret; + Gf128Reduce(high, low, &ret); + return ret; +} + +uint128_t Gf128Mul(absl::Span x, + absl::Span y) { + uint128_t ret; + Gf128Mul(x, y, &ret); + return ret; +} + +uint128_t Gf128Pack(absl::Span data) { + uint128_t ret; + Gf128Pack(data, &ret); + return ret; +} + +uint128_t GfMul(absl::Span a, absl::Span b) { + UninitAlignedVector tmp(b.size()); + std::transform(b.cbegin(), b.cend(), tmp.begin(), [](const uint64_t& val) { + return static_cast(val); + }); + return Gf128Mul(a, absl::MakeSpan(tmp)); +} + +uint128_t GfMul(absl::Span a, absl::Span b) { + return GfMul(b, a); +} + +} // namespace yacl::math diff --git a/yacl/math/galois_field/gf_intrinsic.h b/yacl/math/galois_field/gf_intrinsic.h new file mode 100644 index 00000000..5f533d6b --- /dev/null +++ b/yacl/math/galois_field/gf_intrinsic.h @@ -0,0 +1,190 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "yacl/base/aligned_vector.h" +#include "yacl/base/block.h" +#include "yacl/base/exception.h" +#include "yacl/base/int128.h" +#include "yacl/math/gadget.h" + +namespace yacl::math { + +// Galois Field GF(2^n) implmentation +// (As of now, only support GF(2^64) & GF(2^128)) +// +// Galois Field GF(2^n) could be viewed as GF(2)[X]/(P), +// where P is an irreducible polynomial in GF(2)[X] of degree n. +// +// NOTE To achieve multiplication over GF(2^n): +// 1. Perform polynomial multiplication over GF(2)[X], as known as, carry-less +// multiplication. +// 2. Reduce the product modulo the irreducible polynomial. + +// Irreducible Polynomials of degree 128 and 64. +constexpr uint64_t kGfMod128 = (1 << 7) | (1 << 2) | (1 << 1) | 1; +constexpr uint64_t kGfMod64 = (1 << 4) | (1 << 3) | (1 << 1) | 1; + +constexpr auto kGf64Basis = []() constexpr { + std::array basis = {0}; + uint128_t one = yacl::MakeUint128(0, 1); + for (size_t i = 0; i < 64; ++i) { + basis[i] = one << i; + } + return basis; +}; + +constexpr auto kGf128Basis = []() constexpr { + std::array basis = {0}; + uint128_t one = yacl::MakeUint128(0, 1); + for (size_t i = 0; i < 128; ++i) { + basis[i] = one << i; + } + return basis; +}; + +// ---------------------------------- +// GF 128 +// ---------------------------------- +void Gf128Mul(uint128_t x, uint128_t y, uint128_t* out); +uint128_t Gf128Mul(uint128_t x, uint128_t y); + +void Gf128Mul(block x, block y, block* out); +block Gf128Mul(block x, block y); + +void Gf128Mul(absl::Span x, absl::Span y, + uint128_t* out); +uint128_t Gf128Mul(absl::Span x, + absl::Span y); + +void Gf128ClMul(uint128_t x, uint128_t y, uint128_t* out1, uint128_t* out2); +void Gf128ClMul(block x, block y, block* out1, block* out2); +void Gf128ClMul(absl::Span x, absl::Span y, + uint128_t* out1, uint128_t* out2); + +void Gf128Reduce(uint128_t high, uint128_t low, uint128_t* out); +uint128_t Gf128Reduce(uint128_t high, uint128_t low); + +void Gf128Reduce(block high, block low, block* out); +block Gf128Reduce(block high, block low); + +void Gf128Pack(absl::Span data, uint128_t* out); +uint128_t Gf128Pack(absl::Span data); + +// ---------------------------------- +// GF 64 +// ---------------------------------- +void Gf64Mul(uint64_t x, uint64_t y, uint64_t* out); +uint64_t Gf64Mul(uint64_t x, uint64_t y); + +void Gf64Mul(absl::Span x, absl::Span y, + uint64_t* out); +uint64_t Gf64Mul(absl::Span x, absl::Span y); + +void Gf64ClMul(uint64_t x, uint64_t y, uint128_t* out); +uint128_t Gf64ClMul(uint64_t x, uint64_t y); + +void Gf64ClMul(absl::Span x, absl::Span y, + uint128_t* out); +uint128_t Gf64ClMul(absl::Span x, absl::Span y); + +void Gf64Reduce(uint128_t x, uint64_t* out); +uint64_t Gf64Reduce(uint128_t x); + +void Gf64Inv(uint64_t x, uint64_t* out); +uint64_t Gf64Inv(uint64_t x); + +void Gf64Pack(absl::Span data, uint64_t* out); +uint64_t Gf64Pack(absl::Span data); + +// ------------------------ +// Generic Multiplication +// ------------------------ + +inline uint128_t GfMul(uint128_t x, uint128_t y) { return Gf128Mul(x, y); } +inline uint128_t GfMul(absl::Span a, + absl::Span b) { + return Gf128Mul(a, b); +} + +inline uint64_t GfMul(uint64_t x, uint64_t y) { return Gf64Mul(x, y); } +inline uint64_t GfMul(absl::Span a, + absl::Span b) { + return Gf64Mul(a, b); +} + +// NOTE The subfield (a.k.a GF(2^64)) is mapped to the larger field (a.k.a +// GF(2^128)) to proceed with arithmatic operations. Therefore, all subfield ops +// such as multiplications and additions are defined in GF(2^128) +// +uint128_t GfMul(absl::Span a, absl::Span b); +uint128_t GfMul(absl::Span a, absl::Span b); +inline uint128_t GfMul(uint128_t a, uint64_t b) { + return Gf128Mul(a, MakeUint128(0, b)); +} + +inline uint128_t GfMul(uint64_t a, uint128_t b) { + return Gf128Mul(MakeUint128(0, a), b); +} + +// ------------------------ +// GF Universal Hash +// ------------------------ + +// see difference between universal hash and collision-resistent hash functions: +// https://crypto.stackexchange.com/a/88247/61581 +template +T UniversalHash(T seed, absl::Span data) { + T ret = 0; + for_each(data.rbegin(), data.rend(), [&ret, &seed](const T& val) { + ret ^= val; + ret = GfMul(seed, ret); + }); + return ret; +} + +template +std::vector ExtractHashCoef(T seed, + absl::Span indexes /*sorted*/) { + std::array buff = {}; + auto max_bits = math::Log2Ceil(indexes.back()); + buff[0] = seed; + for (size_t i = 1; i <= max_bits; ++i) { + buff[i] = GfMul(buff[i - 1], buff[i - 1]); + } + + std::vector ret; + for (const auto& index : indexes) { + auto index_plus_one = index + 1; + uint64_t mask = 1; + T coef = 1; + for (size_t i = 0; i < 64 && mask <= index_plus_one; ++i) { + if (mask & index_plus_one) { + coef = GfMul(coef, buff[i]); + } + mask <<= 1; + } + ret.push_back(coef); + } + return ret; +} + +} // namespace yacl::math diff --git a/yacl/math/f2k/f2k_test.cc b/yacl/math/galois_field/gf_intrinsic_test.cc similarity index 51% rename from yacl/math/f2k/f2k_test.cc rename to yacl/math/galois_field/gf_intrinsic_test.cc index 2331d9da..dfb66d02 100644 --- a/yacl/math/f2k/f2k_test.cc +++ b/yacl/math/galois_field/gf_intrinsic_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "yacl/math/f2k/f2k.h" +#include "yacl/math/galois_field/gf_intrinsic.h" #include #include @@ -25,6 +25,8 @@ #include "yacl/base/int128.h" #include "yacl/crypto/rand/rand.h" +namespace yacl::math { + namespace { template bool operator==(const std::pair& lhs, const std::pair& rhs) { @@ -43,54 +45,47 @@ std::pair operator^(const std::pair& lhs, } // check commutative property over F2k -#define MULTEST(type, range, x0, x1, y0, y1) \ - auto x = x0 ^ x1; \ - auto y = y0 ^ y1; \ - auto xy = yacl::type##Mul##range(x, y); \ - auto yx = yacl::type##Mul##range(y, x); \ - auto zero = xy ^ xy; \ - EXPECT_EQ(xy, yx); \ - auto xy0 = yacl::type##Mul##range(x, y0); \ - auto xy1 = yacl::type##Mul##range(x, y1); \ - EXPECT_EQ(xy, xy0 ^ xy1); \ - auto x0y = yacl::type##Mul##range(x0, y); \ - auto x1y = yacl::type##Mul##range(x1, y); \ - EXPECT_EQ(xy, x0y ^ x1y); \ - EXPECT_NE(zero, xy); \ - EXPECT_NE(zero, yx); +#define GF_MUL_TEST(FUNC, T) \ + { \ + auto x = yacl::crypto::RandVec(2); \ + auto y = yacl::crypto::RandVec(2); \ + auto x_sum = x[0] ^ x[1]; \ + auto y_sum = y[0] ^ y[1]; \ + T xy; \ + T yx; \ + { \ + FUNC(x_sum, y_sum, &xy); \ + FUNC(y_sum, x_sum, &yx); \ + EXPECT_EQ(xy, yx); \ + } \ + auto zero = xy ^ xy; \ + { \ + T xy0; \ + T xy1; \ + FUNC(x_sum, y[0], &xy0); \ + FUNC(x_sum, y[1], &xy1); \ + EXPECT_EQ(xy, xy0 ^ xy1); \ + } \ + { \ + T x0y; \ + T x1y; \ + FUNC(x[0], y_sum, &x0y); \ + FUNC(x[1], y_sum, &x1y); \ + EXPECT_EQ(xy, x0y ^ x1y); \ + EXPECT_NE(zero, xy); \ + EXPECT_NE(zero, yx); \ + } \ + } } // namespace -TEST(F2kTest, ClMul128_block) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Cl, 128, t[0], t[1], t[2], t[3]); -} - -TEST(F2kTest, ClMul128) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Cl, 128, t[0], t[1], t[2], t[3]); +TEST(GFTest, Mul128) { + GF_MUL_TEST(Gf128Mul, block); + GF_MUL_TEST(Gf128Mul, uint128_t); } -TEST(F2kTest, ClMul64) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Cl, 64, t[0], t[1], t[2], t[3]); -} - -TEST(F2kTest, GfMul128_block) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Gf, 128, t[0], t[1], t[2], t[3]); -} +TEST(GFTest, Mul64) { GF_MUL_TEST(Gf64Mul, uint64_t); } -TEST(F2kTest, GfMul128) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Gf, 128, t[0], t[1], t[2], t[3]); -} - -TEST(F2kTest, GfMul64) { - auto t = yacl::crypto::RandVec(4); - MULTEST(Gf, 64, t[0], t[1], t[2], t[3]); -} - -TEST(F2kTest, GfMul128_inner_product) { +TEST(GFTest, Gf128_inner_product) { const uint64_t size = 1001; auto zero = uint128_t(0); @@ -98,19 +93,22 @@ TEST(F2kTest, GfMul128_inner_product) { auto y = yacl::crypto::RandVec(size); auto x_span = absl::MakeSpan(x); auto y_span = absl::MakeSpan(y); + uint128_t ret; - auto ret = yacl::GfMul128(x_span, y_span); + Gf128Mul(x_span, y_span, &ret); uint128_t check = 0; for (uint64_t i = 0; i < size; ++i) { - check ^= yacl::GfMul128(x[i], y[i]); + uint128_t temp; + Gf128Mul(x[i], y[i], &temp); + check ^= temp; } EXPECT_EQ(ret, check); EXPECT_NE(ret, zero); } -TEST(F2kTest, GfMul64_inner_product) { +TEST(GFTest, Gf64_inner_product) { const uint64_t size = 1001; uint64_t zero = 0; @@ -119,36 +117,31 @@ TEST(F2kTest, GfMul64_inner_product) { auto x_span = absl::MakeSpan(x); auto y_span = absl::MakeSpan(y); - auto ret = yacl::GfMul64(x_span, y_span); + uint64_t ret; + Gf64Mul(x_span, y_span, &ret); uint64_t check = 0; for (uint64_t i = 0; i < size; ++i) { - check ^= yacl::GfMul64(x[i], y[i]); + uint64_t temp; + Gf64Mul(x[i], y[i], &temp); + check ^= temp; } EXPECT_EQ(ret, check); EXPECT_NE(ret, zero); } -TEST(F2kTest, GfInv64_inner_product) { +TEST(GFTest, GfInv64_inner_product) { const uint64_t size = 1001; auto x = yacl::crypto::RandVec(size); for (uint64_t i = 0; i < size; ++i) { - auto inv = yacl::GfInv64(x[i]); - auto check = yacl::GfMul64(x[i], inv); + uint64_t x_inv; + Gf64Inv(x[i], &x_inv); + uint64_t check; + Gf64Mul(x[i], x_inv, &check); EXPECT_EQ(uint64_t(1), check); } } -// test for the inverse of 128-bit field -TEST(F2kTest, GfInv128_inner_product) { - const uint64_t size = 1001; - - auto x = yacl::crypto::RandVec(size); - for (uint128_t i = 0; i < size; ++i) { - auto inv = yacl::GfInv128(x[i]); - auto check = yacl::GfMul128(x[i], inv); - EXPECT_EQ(uint128_t(1), check); - } -} \ No newline at end of file +} // namespace yacl::math diff --git a/yacl/utils/serializer_adapter.h b/yacl/utils/serializer_adapter.h index 9460929e..d25028c0 100644 --- a/yacl/utils/serializer_adapter.h +++ b/yacl/utils/serializer_adapter.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once + #include "yacl/base/int128.h" #include "yacl/utils/serializer.h" diff --git a/yacl/utils/spi/argument/arg_k.h b/yacl/utils/spi/argument/arg_k.h index 816df18e..5e6a6857 100644 --- a/yacl/utils/spi/argument/arg_k.h +++ b/yacl/utils/spi/argument/arg_k.h @@ -27,7 +27,10 @@ class SpiArgKey { public: using ValueType = T; - explicit SpiArgKey(const std::string &key) : key_(util::ToSnakeCase(key)) {} + explicit SpiArgKey(const std::string &key) : key_(util::ToSnakeCase(key)) { + YACL_ENFORCE(!key_.empty(), "Empty arg name is not allowed. raw_key={}", + key); + } const std::string &Key() const & { return key_; } diff --git a/yacl/utils/spi/argument/arg_kv.h b/yacl/utils/spi/argument/arg_kv.h index 5cd9e7c9..90c74854 100644 --- a/yacl/utils/spi/argument/arg_kv.h +++ b/yacl/utils/spi/argument/arg_kv.h @@ -32,11 +32,14 @@ namespace yacl { class SpiArg { public: - explicit SpiArg(const std::string &key) : key_(util::ToSnakeCase(key)) {} + explicit SpiArg(const std::string &key) : key_(util::ToSnakeCase(key)) { + YACL_ENFORCE(!key_.empty(), "Arg key is empty. raw_key={}", key); + } // If value is a string, it will be automatically converted to lowercase template SpiArg(const std::string &key, T &&value) : key_(util::ToSnakeCase(key)) { + YACL_ENFORCE(!key_.empty(), "Arg key is empty. raw_key={}", key); operator=(std::forward(value)); }