-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbenchmark.h
125 lines (103 loc) · 3.29 KB
/
benchmark.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright (c) 2015, The Regents of the University of California (Regents)
// See LICENSE.txt for license details
#ifndef BENCHMARK_H_
#define BENCHMARK_H_
#include <algorithm>
#include <cinttypes>
#include <functional>
#include <random>
#include <utility>
#include <vector>
#include "builder.h"
#include "graph.h"
#include "timer.h"
#include "util.h"
#include "writer.h"
/*
GAP Benchmark Suite
File: Benchmark
Author: Scott Beamer
Various helper functions to ease writing of kernels
*/
// Default type signatures for commonly used types
typedef int32_t NodeID;
typedef int32_t WeightT;
typedef NodeWeight<NodeID, WeightT> WNode;
typedef CSRGraph<NodeID> Graph;
typedef CSRGraph<NodeID, WNode> WGraph;
typedef BuilderBase<NodeID, NodeID, WeightT> Builder;
typedef BuilderBase<NodeID, WNode, WeightT> WeightedBuilder;
typedef WriterBase<NodeID, NodeID> Writer;
typedef WriterBase<NodeID, WNode> WeightedWriter;
// Used to pick random non-zero degree starting points for search algorithms
template<typename GraphT_>
class SourcePicker {
public:
explicit SourcePicker(const GraphT_ &g, NodeID given_source = -1)
: given_source(given_source), rng(kRandSeed), udist(0, g.num_nodes()-1),
g_(g) {}
NodeID PickNext() {
if (given_source != -1)
return given_source;
NodeID source;
do {
source = udist(rng);
} while (g_.out_degree(source) == 0);
return source;
}
private:
NodeID given_source;
std::mt19937 rng;
std::uniform_int_distribution<NodeID> udist;
const GraphT_ &g_;
};
// Returns k pairs with largest values from list of key-value pairs
template<typename KeyT, typename ValT>
std::vector<std::pair<ValT, KeyT>> TopK(
const std::vector<std::pair<KeyT, ValT>> &to_sort, size_t k) {
std::vector<std::pair<ValT, KeyT>> top_k;
ValT min_so_far = 0;
for (auto kvp : to_sort) {
if ((top_k.size() < k) || (kvp.second > min_so_far)) {
top_k.push_back(std::make_pair(kvp.second, kvp.first));
std::sort(top_k.begin(), top_k.end(),
std::greater<std::pair<ValT, KeyT>>());
if (top_k.size() > k)
top_k.resize(k);
min_so_far = top_k.back().first;
}
}
return top_k;
}
bool VerifyUnimplemented(...) {
std::cout << "** verify unimplemented **" << std::endl;
return false;
}
// Calls (and times) kernel according to command line arguments
template<typename GraphT_, typename GraphFunc, typename AnalysisFunc,
typename VerifierFunc>
void BenchmarkKernel(const CLApp &cli, const GraphT_ &g,
GraphFunc kernel, AnalysisFunc stats,
VerifierFunc verify) {
g.PrintStats();
double total_seconds = 0;
Timer trial_timer;
for (int iter=0; iter < cli.num_trials(); iter++) {
trial_timer.Start();
auto result = kernel(g);
trial_timer.Stop();
PrintTime("Trial Time", trial_timer.Seconds());
total_seconds += trial_timer.Seconds();
if (cli.do_analysis() && (iter == (cli.num_trials()-1)))
stats(g, result);
if (cli.do_verify()) {
trial_timer.Start();
PrintLabel("Verification",
verify(std::ref(g), std::ref(result)) ? "PASS" : "FAIL");
trial_timer.Stop();
PrintTime("Verification Time", trial_timer.Seconds());
}
}
PrintTime("Average Time", total_seconds / cli.num_trials());
}
#endif // BENCHMARK_H_