diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmark.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmark.java new file mode 100644 index 000000000..1b67d4117 --- /dev/null +++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmark.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.rng.examples.jmh.sampling; + +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +/** + * Executes benchmark to compare the speed of shuffling an array. + * + *
Batched shuffle samples have been adapted from the blog post: + * + * Daniel Lemire: Faster random integer generation with batching. + * The samples provided in the blog and the referenced paper are for a 64-bit + * source of randomness which requires native support for 128-bit multiplication. + * These have been modified for a 32-bit source of randomness. + * + *
Note: The 32-bit based shuffle2 method has a size threshold of 2^15 + * (32768) for creating two samples from each 32-bit random value. + * Speed-up is most obvious for arrays below this size. + */ + @Param({"4", "16", "64", "256", "1024", "4096", "8192", "16384", "32768", "65536", "262148", "1048592"}) + private int size; + + /** The data. */ + private int[] data; + + /** + * @return the data + */ + public int[] getData() { + return data; + } + + /** + * Create the data. + */ + @Setup + public void setup() { + data = IntStream.range(0, size).toArray(); + } + } + + /** + * Defines the {@link RandomSource} for testing. + */ + @State(Scope.Benchmark) + public static class RngSource { + /** + * RNG providers. + * + *
Use different speeds.
+ * + * @see + * Commons RNG user guide + */ + @Param({"XO_RO_SHI_RO_128_PP", + //"MWC_256", + //"JDK" + }) + private String randomSourceName; + + /** RNG. */ + private UniformRandomProvider rng; + + /** + * Gets the source of randomness. + * + * @return RNG + */ + public UniformRandomProvider getRNG() { + return rng; + } + + /** + * Look-up the {@link RngSource} from the name and instantiates the generator. + */ + @Setup + public void setup() { + rng = RandomSource.valueOf(randomSourceName).create(); + } + } + + /** + * Defines the shuffle method. + */ + @State(Scope.Benchmark) + public static class ShuffleMethod { + /** + * Method name. + */ + @Param({"shuffle", "shuffle2"}) + private String method; + + /** Shuffle function. */ + private BiConsumerThe product bound can be any positive integer {@code >= range1*range2}. + * It may be updated to become {@code range1*range2}. + * + * @param range1 Range 1. + * @param range2 Range 2. + * @param productBound Product bound. + * @param rng Source of randomness. + * @return [i1, i2] + */ + static int[] randomBounded2(int range1, int range2, int[] productBound, UniformRandomProvider rng) { + long m = (rng.nextInt() & MASK_32) * range1; + // result1 and result2 are the top 32-bits of the long + long r1 = m; + // Leftover bits * range2 + m = (m & MASK_32) * range2; + long r2 = m; + // Leftover bits must be unsigned + long l = m & MASK_32; + if (l < productBound[0]) { + final int bound = range1 * range2; + productBound[0] = bound; + if (l < bound) { + // 2^32 % bound + long t = POW_32 % bound; + while (l < t) { + m = (rng.nextInt() & MASK_32) * range1; + r1 = m; + m = (m & MASK_32) * range2; + r2 = m; + l = m & MASK_32; + } + } + } + // Convert to [0, range1), [0, range2) + return new int[] {(int) (r1 >> 32), (int) (r2 >> 32)}; + } + + /** + * Shuffles the entries of the given array. + * + * @param rng Source of randomness. + * @param array Array whose entries will be shuffled (in-place). + * @return a reference to the given array + */ + static int[] shuffle2(UniformRandomProvider rng, int[] array) { + int i = array.length; + // The threshold provided in the Brackett-Rozinsky and Lemire paper + // is the power of 2 below 20724. Note that the product 2^15*2^15 + // is representable using signed integers. + for (; i > POW_15; i--) { + swap(array, i - 1, rng.nextInt(i)); + } + // Batches of 2 for sizes up to 2^15 elements + final int[] productBound = {i * (i - 1)}; + for (; i > 1; i -= 2) { + final int[] indices = randomBounded2(i, i - 1, productBound, rng); + final int index1 = indices[0]; + final int index2 = indices[1]; + swap(array, i - 1, index1); + swap(array, i - 2, index2); + } + return array; + } + + /** + * Performs a shuffle. + * + * @param data Shuffle data. + * @param source Source of randomness. + * @param method Shuffle method. + * @return the shuffled data + */ + @Benchmark + public Object shuffle(ShuffleData data, RngSource source, ShuffleMethod method) { + final int[] a = data.getData(); + method.getMethod().accept(source.getRNG(), a); + return a; + } +} diff --git a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmarkTest.java b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmarkTest.java new file mode 100644 index 000000000..b2187a7fa --- /dev/null +++ b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/ArrayShuffleBenchmarkTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.examples.jmh.sampling; + +import org.apache.commons.math3.stat.inference.ChiSquareTest; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.ArraySampler; +import org.apache.commons.rng.sampling.PermutationSampler; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Test for array shuffle samplers in the {@link ArrayShuffleBenchmark} class. + */ +class ArrayShuffleBenchmarkTest { + + /** + * The seed for the RNG used in the sampling tests. + * + *
This has been chosen to allow the test to pass with all generators. + * Set to null test with a random seed. When using a random + * seed re-run the test multiple times. Systematic failure of the same test + * should be investigated further. + */ + private static final Long SEED = 0xd1342543de82ef95L; + + @ParameterizedTest + @CsvSource({ + "42, 257", + "1356, 8073", + }) + void testBoundedRandom2(int range1, int range2) { + Assertions.assertTrue((long) range1 * range2 < 1L << 31, "Product must be less than 2^31"); + + final int samples = 1000000; + final int bins = 8; + final long[][] observed = new long[bins][bins]; + final UniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(SEED); + final int[] productBound = {range1 * range2}; + final int width1 = (int) Math.ceil((double) range1 / bins); + final int width2 = (int) Math.ceil((double) range2 / bins); + for (int i = 0; i < samples; i++) { + final int[] indices = ArrayShuffleBenchmark.randomBounded2(range1, range2, productBound, rng); + final int index1 = indices[0] / width1; + final int index2 = indices[1] / width2; + observed[index1][index2]++; + } + + final double p = new ChiSquareTest().chiSquareTest(observed); + Assertions.assertFalse(p < 1e-3, () -> "p-value too small: " + p); + } + + @ParameterizedTest + @CsvSource({ + "257", + "8073", + // Above the bounded random threshold of 2^15 + "31548", + }) + void testShuffle(int length) { + final int[] array = PermutationSampler.natural(length); + final UniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(SEED); + final int samples = 1000000 / length; + final int bins = 8; + final long[][] observed = new long[bins][bins]; + final int width = (int) Math.ceil((double) length / bins); + for (int j = 0; j < samples; j++) { + ArraySampler.shuffle(rng, array); + for (int i = 0; i < length; i++) { + observed[i / width][array[i] / width]++; + } + } + final double p = new ChiSquareTest().chiSquareTest(observed); + Assertions.assertFalse(p < 1e-3, () -> "p-value too small: " + p); + } +}