6#include "solver_heap.h"
7#include <hashx_endian.h>
14#pragma warning (disable : 4146)
/*
 * Small helper macros for the solver.
 *
 * CLEAR(x) zero-fills the object x (pass the object itself, not a pointer).
 *
 * MAKE_ITEM packs a parent reference into one integer:
 *   low bits  : coarse bucket of the left parent
 *   bits 8-16 : position of the right parent (9 bits, see the 511 mask)
 *   bits 17+  : position of the left parent
 * ITEM_BUCKET / ITEM_LEFT_IDX / ITEM_RIGHT_IDX unpack those fields again.
 * NOTE(review): ITEM_BUCKET only round-trips if NUM_COARSE_BUCKETS is 256
 * (the bucket field is 8 bits wide) - confirm against solver_heap.h.
 *
 * INVERT_BUCKET / INVERT_SCRATCH return the additive complement of idx modulo
 * the bucket count. They rely on unsigned wrap-around of the negation, so idx
 * must have an unsigned type; a signed operand would yield a non-positive
 * remainder.
 *
 * Every expansion is fully parenthesized so the macros can be used inside
 * larger expressions (comparisons, sums, products) without precedence
 * surprises.
 */
#define CLEAR(x) memset(&(x), 0, sizeof(x))
#define MAKE_ITEM(bucket, left, right) (((left) << 17) | ((right) << 8) | (bucket))
#define ITEM_BUCKET(item) ((item) % NUM_COARSE_BUCKETS)
#define ITEM_LEFT_IDX(item) ((item) >> 17)
#define ITEM_RIGHT_IDX(item) (((item) >> 8) & 511)
#define INVERT_BUCKET(idx) (-(idx) % NUM_COARSE_BUCKETS)
#define INVERT_SCRATCH(idx) (-(idx) % NUM_FINE_BUCKETS)
/*
 * Shorthand accessors into the solver heap. Every macro below expands to an
 * expression on a variable named `heap`, which must be in scope wherever the
 * macro is used.
 *   STAGEn_IDX(buck, pos)  - items[pos] of bucket `buck` in the stage-n index table
 *   STAGEn_DATA(buck, pos) - items[pos] of bucket `buck` in the stage-n data table
 *   STAGEn_SIZE(buck)      - per-bucket counter stored with the stage-n index table
 *   SCRATCH / SCRATCH_SIZE - the same two views of heap->scratch_ht
 * The expansions are lvalues and are both read and assigned by the stage code.
 */
24#define STAGE1_IDX(buck, pos) heap->stage1_indices.buckets[buck].items[pos]
25#define STAGE2_IDX(buck, pos) heap->stage2_indices.buckets[buck].items[pos]
26#define STAGE3_IDX(buck, pos) heap->stage3_indices.buckets[buck].items[pos]
27#define STAGE1_DATA(buck, pos) heap->stage1_data.buckets[buck].items[pos]
28#define STAGE2_DATA(buck, pos) heap->stage2_data.buckets[buck].items[pos]
29#define STAGE3_DATA(buck, pos) heap->stage3_data.buckets[buck].items[pos]
30#define STAGE1_SIZE(buck) heap->stage1_indices.counts[buck]
31#define STAGE2_SIZE(buck) heap->stage2_indices.counts[buck]
32#define STAGE3_SIZE(buck) heap->stage3_indices.counts[buck]
33#define SCRATCH(buck, pos) heap->scratch_ht.buckets[buck].items[pos]
34#define SCRATCH_SIZE(buck) heap->scratch_ht.counts[buck]
35#define SWAP_IDX(a, b) \
41#define CARRY (bucket_idx != 0) /* 1 when the in-scope `bucket_idx` is nonzero, else 0; the pairing steps add it to a data value before matching. NOTE(review): presumably the carry out of the stripped bucket bits, since a nonzero bucket and its complement sum to NUM_COARSE_BUCKETS - confirm. */
/* Exclusive upper bound of the bucket_idx loops: half of the coarse buckets plus one. */
43#define BUCK_END (NUM_COARSE_BUCKETS / 2 + 1)
/* Short local aliases for the per-stage index item types. */
46typedef stage1_idx_item s1_idx;
47typedef stage2_idx_item s2_idx;
48typedef stage3_idx_item s3_idx;
/*
 * Evaluates the hash function for `index` into a local HASHX_SIZE-byte buffer
 * and, when hashx_exec() reports HASHX_OK, stores load64() of that buffer in
 * *value_out.
 * NOTE(review): the return statements and closing braces of this function
 * are not visible in this excerpt; callers test the result with `!`.
 */
50static FORCE_INLINE
bool hash_value(
hashx_ctx* hash_func, equix_idx index, uint64_t *value_out) {
51 char hash[HASHX_SIZE];
52 hashx_result result = hashx_exec(hash_func, index, hash);
53 if (result == HASHX_OK) {
54 *value_out = load64(hash);
/*
 * Expands one stage-2 index item into its two stage-1 parents, written to
 * output[0] and output[1].
 * `root` is unpacked into a bucket, a left position and a right position. The
 * left parent is read from that bucket, the right parent from the inverted
 * (complementary) bucket. The two outputs are swapped when tree_cmp1()
 * returns false.
 * NOTE(review): presumably tree_cmp1 defines the canonical order of the pair
 * - confirm against its definition. The closing braces are not visible here.
 */
62static void build_solution_stage1(equix_idx* output,
solver_heap* heap, s2_idx root) {
63 u32 bucket = ITEM_BUCKET(root);
64 u32 bucket_inv = INVERT_BUCKET(bucket);
65 u32 left_parent_idx = ITEM_LEFT_IDX(root);
66 u32 right_parent_idx = ITEM_RIGHT_IDX(root);
67 s1_idx left_parent = STAGE1_IDX(bucket, left_parent_idx);
68 s1_idx right_parent = STAGE1_IDX(bucket_inv, right_parent_idx);
69 output[0] = left_parent;
70 output[1] = right_parent;
71 if (!tree_cmp1(&output[0], &output[1])) {
72 SWAP_IDX(output[0], output[1]);
/*
 * Expands one stage-3 index item into four indices, output[0..3].
 * `root` is unpacked the same way as in build_solution_stage1; its left
 * stage-2 parent (from `bucket`) is expanded into output[0..1] and its right
 * stage-2 parent (from the inverted bucket) into output[2..3]. The two pairs
 * are swapped element by element when tree_cmp2() returns false.
 * NOTE(review): the closing braces are not visible in this excerpt.
 */
76static void build_solution_stage2(equix_idx* output,
solver_heap* heap, s3_idx root) {
77 u32 bucket = ITEM_BUCKET(root);
78 u32 bucket_inv = INVERT_BUCKET(bucket);
79 u32 left_parent_idx = ITEM_LEFT_IDX(root);
80 u32 right_parent_idx = ITEM_RIGHT_IDX(root);
81 s2_idx left_parent = STAGE2_IDX(bucket, left_parent_idx);
82 s2_idx right_parent = STAGE2_IDX(bucket_inv, right_parent_idx);
83 build_solution_stage1(&output[0], heap, left_parent);
84 build_solution_stage1(&output[2], heap, right_parent);
85 if (!tree_cmp2(&output[0], &output[2])) {
86 SWAP_IDX(output[0], output[2]);
87 SWAP_IDX(output[1], output[3]);
/*
 * Expands `left` into solution->idx[0..3] and `right` into
 * solution->idx[4..7] via build_solution_stage2(), then swaps the two halves
 * element by element when tree_cmp4() returns false.
 * NOTE(review): the enclosing function header is not visible in this excerpt;
 * presumably this is the body of build_solution() called from the stage-3
 * pairing step - confirm.
 */
92 build_solution_stage2(&solution->idx[0], heap, left);
93 build_solution_stage2(&solution->idx[4], heap, right);
94 if (!tree_cmp4(&solution->idx[0], &solution->idx[4])) {
95 SWAP_IDX(solution->idx[0], solution->idx[4]);
96 SWAP_IDX(solution->idx[1], solution->idx[5]);
97 SWAP_IDX(solution->idx[2], solution->idx[6]);
98 SWAP_IDX(solution->idx[3], solution->idx[7]);
/*
 * Stage 0 fill: resets the stage-1 bucket counters, then for every i in
 * [0, INDEX_SPACE) hashes i and files it under coarse bucket
 * value % NUM_COARSE_BUCKETS. The index table receives i itself and the data
 * table receives value / NUM_COARSE_BUCKETS (the bucket bits are stripped).
 * NOTE(review): the function header, the declaration of `value`, and the
 * statements guarded by the two `if` checks (hash failure; bucket already
 * holding COARSE_BUCKET_ITEMS entries) are not visible in this excerpt.
 */
103 CLEAR(heap->stage1_indices.counts);
104 for (u32 i = 0; i < INDEX_SPACE; ++i) {
106 if (!hash_value(hash_func, i, &value))
108 u32 bucket_idx = value % NUM_COARSE_BUCKETS;
109 u32 item_idx = STAGE1_SIZE(bucket_idx);
110 if (item_idx >= COARSE_BUCKET_ITEMS)
112 STAGE1_SIZE(bucket_idx) = item_idx + 1;
113 STAGE1_IDX(bucket_idx, item_idx) = i;
114 STAGE1_DATA(bucket_idx, item_idx) = value / NUM_COARSE_BUCKETS;
/*
 * Stage-1 pairing step (body of a multi-line macro; it uses `heap`,
 * `bucket_idx`, `item_idx` and `cpl_bucket` from the expansion site).
 * Takes the stage-1 data item at (bucket_idx, item_idx) plus CARRY, selects
 * the scratch bucket that is the complement of the item's fine bucket, and
 * for every position stored there:
 *   - adds the two data values and asserts the sum is a multiple of
 *     NUM_FINE_BUCKETS, then divides that factor out;
 *   - appends to stage-2 bucket sum % NUM_COARSE_BUCKETS an index item
 *     packing (bucket_idx, item_idx, cpl_index) and the data item
 *     sum / NUM_COARSE_BUCKETS.
 * NOTE(review): the #define line and the statement guarded by the
 * COARSE_BUCKET_ITEMS capacity check are not visible in this excerpt.
 */
119 stage1_data_item value = STAGE1_DATA(bucket_idx, item_idx) + CARRY; \
120 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
121 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
122 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
123 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
124 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
125 stage1_data_item cpl_value = STAGE1_DATA(cpl_bucket, cpl_index); \
126 stage1_data_item sum = value + cpl_value; \
127 assert((sum % NUM_FINE_BUCKETS) == 0); \
128 sum /= NUM_FINE_BUCKETS; \
129 u32 s2_buck_id = sum % NUM_COARSE_BUCKETS; \
130 u32 s2_item_id = STAGE2_SIZE(s2_buck_id); \
131 if (s2_item_id >= COARSE_BUCKET_ITEMS) \
133 STAGE2_SIZE(s2_buck_id) = s2_item_id + 1; \
134 STAGE2_IDX(s2_buck_id, s2_item_id) = \
135 MAKE_ITEM(bucket_idx, item_idx, cpl_index); \
136 STAGE2_DATA(s2_buck_id, s2_item_id) = \
137 sum / NUM_COARSE_BUCKETS; \
/*
 * Stage 1 driver: resets the stage-2 bucket counters, then walks bucket_idx
 * over [BUCK_START, BUCK_END). For each bucket it takes the complementary
 * bucket, clears the scratch counters, and files every item position of the
 * complementary bucket into the scratch table under
 * value % NUM_FINE_BUCKETS of its stage-1 data. When the bucket is not its
 * own complement, a second loop runs over the items of bucket_idx itself.
 * NOTE(review): the function header, the statement guarded by the
 * FINE_BUCKET_ITEMS capacity check, and the bodies of the
 * `cpl_bucket == bucket_idx` branch and of the final loop are not visible in
 * this excerpt (presumably they invoke the stage-1 pairing step - confirm).
 */
141 CLEAR(heap->stage2_indices.counts);
142 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
143 u32 cpl_bucket = INVERT_BUCKET(bucket_idx);
144 CLEAR(heap->scratch_ht.counts);
145 u32 cpl_buck_size = STAGE1_SIZE(cpl_bucket);
146 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
148 stage1_data_item value = STAGE1_DATA(cpl_bucket, item_idx);
149 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
150 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
151 if (fine_item_idx >= FINE_BUCKET_ITEMS)
153 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
154 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
156 if (cpl_bucket == bucket_idx) {
160 if (cpl_bucket != bucket_idx) {
161 u32 buck_size = STAGE1_SIZE(bucket_idx);
162 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
/*
 * Stage-2 pairing step (body of a multi-line macro; it uses `heap`,
 * `bucket_idx`, `item_idx` and `cpl_bucket` from the expansion site).
 * Same shape as the stage-1 step, one level up: takes the stage-2 data item
 * plus CARRY, scans the complementary scratch bucket, and for every candidate
 * adds the two data values, asserts divisibility by NUM_FINE_BUCKETS, divides
 * it out, and appends to stage-3 bucket sum % NUM_COARSE_BUCKETS an index
 * item packing (bucket_idx, item_idx, cpl_index) plus the data item
 * sum / NUM_COARSE_BUCKETS, cast to stage3_data_item.
 * NOTE(review): the #define line and the statement guarded by the
 * COARSE_BUCKET_ITEMS capacity check are not visible in this excerpt.
 */
170 stage2_data_item value = STAGE2_DATA(bucket_idx, item_idx) + CARRY; \
171 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
172 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
173 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
174 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
175 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
176 stage2_data_item cpl_value = STAGE2_DATA(cpl_bucket, cpl_index); \
177 stage2_data_item sum = value + cpl_value; \
178 assert((sum % NUM_FINE_BUCKETS) == 0); \
179 sum /= NUM_FINE_BUCKETS; \
180 u32 s3_buck_id = sum % NUM_COARSE_BUCKETS; \
181 u32 s3_item_id = STAGE3_SIZE(s3_buck_id); \
182 if (s3_item_id >= COARSE_BUCKET_ITEMS) \
184 STAGE3_SIZE(s3_buck_id) = s3_item_id + 1; \
185 STAGE3_IDX(s3_buck_id, s3_item_id) = \
186 MAKE_ITEM(bucket_idx, item_idx, cpl_index); \
187 STAGE3_DATA(s3_buck_id, s3_item_id) = \
188 (stage3_data_item)(sum / NUM_COARSE_BUCKETS); \
/*
 * Stage 2 driver: resets the stage-3 bucket counters, then walks bucket_idx
 * over [BUCK_START, BUCK_END). For each bucket it takes the complementary
 * bucket, clears the scratch counters, and files every item position of the
 * complementary bucket into the scratch table under
 * value % NUM_FINE_BUCKETS of its stage-2 data. When the bucket is not its
 * own complement, a second loop runs over the items of bucket_idx itself.
 * NOTE(review): the function header, the statement guarded by the
 * FINE_BUCKET_ITEMS capacity check, and the bodies of the
 * `cpl_bucket == bucket_idx` branch and of the final loop are not visible in
 * this excerpt.
 */
192 CLEAR(heap->stage3_indices.counts);
193 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
194 u32 cpl_bucket = INVERT_BUCKET(bucket_idx);
195 CLEAR(heap->scratch_ht.counts);
196 u32 cpl_buck_size = STAGE2_SIZE(cpl_bucket);
197 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
199 stage2_data_item value = STAGE2_DATA(cpl_bucket, item_idx);
200 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
201 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
202 if (fine_item_idx >= FINE_BUCKET_ITEMS)
204 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
205 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
207 if (cpl_bucket == bucket_idx) {
211 if (cpl_bucket != bucket_idx) {
212 u32 buck_size = STAGE2_SIZE(bucket_idx);
213 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
/*
 * Stage-3 pairing step (body of a multi-line macro; it uses `heap`,
 * `bucket_idx`, `item_idx`, `cpl_bucket`, `output` and `sols_found` from the
 * expansion site).
 * Takes the stage-3 data item plus CARRY, scans the complementary scratch
 * bucket, and for every candidate adds the two data values, asserts
 * divisibility by NUM_FINE_BUCKETS and divides it out. When the remaining
 * bits selected by EQUIX_STAGE1_MASK are all zero, the two stage-3 index
 * items are expanded into output[sols_found] via build_solution() and the
 * solution counter is incremented and compared against EQUIX_MAX_SOLS.
 * NOTE(review): the #define line and the action taken once EQUIX_MAX_SOLS is
 * reached are not visible in this excerpt.
 */
221 stage3_data_item value = STAGE3_DATA(bucket_idx, item_idx) + CARRY; \
222 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
223 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
224 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
225 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
226 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
227 stage3_data_item cpl_value = STAGE3_DATA(cpl_bucket, cpl_index); \
228 stage3_data_item sum = value + cpl_value; \
229 assert((sum % NUM_FINE_BUCKETS) == 0); \
230 sum /= NUM_FINE_BUCKETS; \
231 if ((sum & EQUIX_STAGE1_MASK) == 0) { \
233 s3_idx item_left = STAGE3_IDX(bucket_idx, item_idx); \
234 s3_idx item_right = STAGE3_IDX(cpl_bucket, cpl_index); \
235 build_solution(&output[sols_found], heap, item_left, item_right); \
236 if (++(sols_found) >= EQUIX_MAX_SOLS) { \
/*
 * Stage 3 driver: walks bucket_idx over [BUCK_START, BUCK_END). For each
 * bucket it takes the complementary bucket, clears the scratch counters, and
 * files every item position of the complementary bucket into the scratch
 * table under value % NUM_FINE_BUCKETS of its stage-3 data. When the bucket
 * is not its own complement, a second loop runs over the items of bucket_idx
 * itself.
 * The complement is computed here as -bucket_idx & (NUM_COARSE_BUCKETS - 1)
 * rather than through INVERT_BUCKET as in the earlier stages; the two agree
 * for unsigned operands when NUM_COARSE_BUCKETS is a power of two.
 * NOTE(review): the function header, the statement guarded by the
 * FINE_BUCKET_ITEMS capacity check, and the bodies of the
 * `cpl_bucket == bucket_idx` branch and of the final loop are not visible in
 * this excerpt.
 */
245 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
246 u32 cpl_bucket = -bucket_idx & (NUM_COARSE_BUCKETS - 1);
247 CLEAR(heap->scratch_ht.counts);
248 u32 cpl_buck_size = STAGE3_SIZE(cpl_bucket);
249 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
251 stage3_data_item value = STAGE3_DATA(cpl_bucket, item_idx);
252 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
253 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
254 if (fine_item_idx >= FINE_BUCKET_ITEMS)
256 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
257 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
259 if (cpl_bucket == bucket_idx) {
263 if (cpl_bucket != bucket_idx) {
264 u32 buck_size = STAGE3_SIZE(bucket_idx);
265 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
274int equix_solver_solve(
279 solve_stage0(hash_func, heap);
282 return solve_stage3(heap, output);