diff --git a/yacl/math/mpint/tommath_ext_features.cc b/yacl/math/mpint/tommath_ext_features.cc index f148dddb..79df14fb 100644 --- a/yacl/math/mpint/tommath_ext_features.cc +++ b/yacl/math/mpint/tommath_ext_features.cc @@ -75,7 +75,7 @@ inline bool is_prime_candidate(const mp_int *p) { // Let P>1 be an integer, and suppose there exist natural numbers A and Q such // that // * A^{P-1} = 1 mod P -// * Q is prime, Q|N−1 and Q > sqrt(P) - 1 +// * Q is prime, Q|P−1 and Q > sqrt(P) - 1 // * gcd(A^{(P-1)/Q} - 1, P) = 1 // Then P is prime // @@ -134,7 +134,7 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { int maskOR_msb_offset; bool res; mp_int q; - uint64_t mod; + uint64_t mod, original_m; /* sanity check the input */ YACL_ENFORCE(psize > 1 && t > 0, "with psize={}, t={}", psize, t); @@ -177,16 +177,16 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { /* read it in */ /* TODO: casting only for now until all lengths have been changed to the * type "size_t"*/ - MPINT_ENFORCE_OK(mp_from_ubin(&q, tmp, (size_t)bsize)); + mp_ext_from_mag_bytes(&q, tmp, (size_t)bsize, Endian::big); // Find a odd number `q` among q, q+2, .... , (1 << 20) satisfy: // 1. co-prime to `small_primes`. // 2. `q = 1 mod 3` (p = 2q+1). - MPINT_ENFORCE_OK(mp_mod_d(&q, small_prime_prod, &mod)); + MPINT_ENFORCE_OK(mp_mod_d(&q, small_prime_prod, &original_m)); uint64_t last_delta = 0; for (uint64_t delta = 0; delta < (1 << 20); delta += 2) { - uint64_t m = mod + delta; + uint64_t m = original_m + delta; if (!is_co_prime(m, small_primes, std::size(small_primes))) { continue; } @@ -204,9 +204,6 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { MPINT_ENFORCE_OK(mp_mul_2(&q, p)); MPINT_ENFORCE_OK(mp_incr(p)); - if (mp_ext_count_bits_fast(*p) != psize) { - continue; - } if (is_prime_candidate(p)) { break; } @@ -217,16 +214,16 @@ void mp_ext_safe_prime_rand(mp_int *p, int t, int psize) { continue; } // test Pocklington Criterion - if (!is_pocklington_criterion_satisfied(&q)) { + if (!is_pocklington_criterion_satisfied(p)) { continue; } - // final check + // final check, if q is prime, + // then p is 100% prime since Pocklington is deterministic MPINT_ENFORCE_OK(mp_prime_is_prime(&q, t, &res)); - if (!res) { - continue; + if (res) { + return; } - MPINT_ENFORCE_OK(mp_prime_is_prime(p, t, &res)); - } while (!res); + } while (true); } void mp_ext_rand_bits(mp_int *out, int64_t bits) { @@ -469,9 +466,9 @@ uint8_t mp_ext_get_bit(const mp_int &a, int index) { void mp_ext_set_bit(mp_int *a, int index, uint8_t value) { int limb = index / MP_DIGIT_BIT; - if (limb > a->alloc) { + if (limb >= a->alloc) { MPINT_ENFORCE_OK(mp_grow(a, limb + 1)); - for (int i = a->used + 1; i <= limb; ++i) { + for (int i = a->used; i <= limb; ++i) { a->dp[i] = 0; } }