Tor 0.4.9.2-alpha-dev
All Data Structures Files Functions Variables Typedefs Enumerations Enumerator Macros Modules Pages
solver.c
1/* Copyright (c) 2020 tevador <tevador@gmail.com> */
2/* See LICENSE for licensing information */
3
4#include "solver.h"
5#include "context.h"
6#include "solver_heap.h"
7#include <hashx_endian.h>
8#include <string.h>
9#include <stdbool.h>
10#include <assert.h>
11#include <stdio.h>
12
13#ifdef _MSC_VER
14#pragma warning (disable : 4146) /* unary minus applied to unsigned type */
15#endif
16
/* Zero an object in place. `x` must be an lvalue (here: a counts array). */
#define CLEAR(x) memset(&(x), 0, sizeof(x))

/*
 * Index items pack three fields into a single integer:
 *   bits 17 and up : position of the left parent inside its bucket
 *   bits 8..16     : position of the right parent (9 bits, hence the 511 mask)
 *   low bits       : bucket number of the left parent
 * ITEM_BUCKET() recovers the bucket with a modulo, so this layout presumably
 * relies on NUM_COARSE_BUCKETS being 256 (defined outside this file --
 * TODO confirm).
 *
 * Every expansion is fully parenthesized so the macros can be used inside
 * larger expressions (e.g. `ITEM_RIGHT_IDX(x) == n`) without operator
 * precedence changing their meaning.
 */
#define MAKE_ITEM(bucket, left, right) \
    ((left) << 17 | (right) << 8 | (bucket))
#define ITEM_BUCKET(item) ((item) % NUM_COARSE_BUCKETS)
#define ITEM_LEFT_IDX(item) ((item) >> 17)
#define ITEM_RIGHT_IDX(item) (((item) >> 8) & 511)

/*
 * Complement of a coarse / fine bucket number, computed as the unsigned
 * negation reduced modulo the bucket count. `idx` is expected to be an
 * unsigned value (see the MSVC warning suppression above).
 */
#define INVERT_BUCKET(idx) (-(idx) % NUM_COARSE_BUCKETS)
#define INVERT_SCRATCH(idx) (-(idx) % NUM_FINE_BUCKETS)

/* Accessors for the per-stage tables. All of them expect a local `heap`. */
#define STAGE1_IDX(buck, pos) heap->stage1_indices.buckets[buck].items[pos]
#define STAGE2_IDX(buck, pos) heap->stage2_indices.buckets[buck].items[pos]
#define STAGE3_IDX(buck, pos) heap->stage3_indices.buckets[buck].items[pos]
#define STAGE1_DATA(buck, pos) heap->stage1_data.buckets[buck].items[pos]
#define STAGE2_DATA(buck, pos) heap->stage2_data.buckets[buck].items[pos]
#define STAGE3_DATA(buck, pos) heap->stage3_data.buckets[buck].items[pos]
#define STAGE1_SIZE(buck) heap->stage1_indices.counts[buck]
#define STAGE2_SIZE(buck) heap->stage2_indices.counts[buck]
#define STAGE3_SIZE(buck) heap->stage3_indices.counts[buck]
#define SCRATCH(buck, pos) heap->scratch_ht.buckets[buck].items[pos]
#define SCRATCH_SIZE(buck) heap->scratch_ht.counts[buck]

/* Exchange two equix_idx lvalues. */
#define SWAP_IDX(a, b) \
    do { \
        equix_idx swap_idx_tmp_ = (a); \
        (a) = (b); \
        (b) = swap_idx_tmp_; \
    } while (0)

/*
 * 1 when the enclosing loop's `bucket_idx` is non-zero, 0 otherwise. It is
 * added to the left-hand value before pairing: a non-zero bucket number and
 * its complement add up to the bucket count rather than to zero, which
 * carries one into the remaining (already divided) bits.
 */
#define CARRY (bucket_idx != 0)

/*
 * Range of coarse buckets visited by each stage. Every visited bucket is
 * matched against its complement, so only buckets 0 .. NUM_COARSE_BUCKETS/2
 * are iterated.
 */
#define BUCK_START 0
#define BUCK_END (NUM_COARSE_BUCKETS / 2 + 1)
44
45typedef uint32_t u32;
46typedef stage1_idx_item s1_idx;
47typedef stage2_idx_item s2_idx;
48typedef stage3_idx_item s3_idx;
49
50static FORCE_INLINE bool hash_value(hashx_ctx* hash_func, equix_idx index, uint64_t *value_out) {
51 char hash[HASHX_SIZE];
52 hashx_result result = hashx_exec(hash_func, index, hash);
53 if (result == HASHX_OK) {
54 *value_out = load64(hash);
55 return true;
56 } else {
57 assert(false);
58 return false;
59 }
60}
61
62static void build_solution_stage1(equix_idx* output, solver_heap* heap, s2_idx root) {
63 u32 bucket = ITEM_BUCKET(root);
64 u32 bucket_inv = INVERT_BUCKET(bucket);
65 u32 left_parent_idx = ITEM_LEFT_IDX(root);
66 u32 right_parent_idx = ITEM_RIGHT_IDX(root);
67 s1_idx left_parent = STAGE1_IDX(bucket, left_parent_idx);
68 s1_idx right_parent = STAGE1_IDX(bucket_inv, right_parent_idx);
69 output[0] = left_parent;
70 output[1] = right_parent;
71 if (!tree_cmp1(&output[0], &output[1])) {
72 SWAP_IDX(output[0], output[1]);
73 }
74}
75
76static void build_solution_stage2(equix_idx* output, solver_heap* heap, s3_idx root) {
77 u32 bucket = ITEM_BUCKET(root);
78 u32 bucket_inv = INVERT_BUCKET(bucket);
79 u32 left_parent_idx = ITEM_LEFT_IDX(root);
80 u32 right_parent_idx = ITEM_RIGHT_IDX(root);
81 s2_idx left_parent = STAGE2_IDX(bucket, left_parent_idx);
82 s2_idx right_parent = STAGE2_IDX(bucket_inv, right_parent_idx);
83 build_solution_stage1(&output[0], heap, left_parent);
84 build_solution_stage1(&output[2], heap, right_parent);
85 if (!tree_cmp2(&output[0], &output[2])) {
86 SWAP_IDX(output[0], output[2]);
87 SWAP_IDX(output[1], output[3]);
88 }
89}
90
91static void build_solution(equix_solution* solution, solver_heap* heap, s3_idx left, s3_idx right) {
92 build_solution_stage2(&solution->idx[0], heap, left);
93 build_solution_stage2(&solution->idx[4], heap, right);
94 if (!tree_cmp4(&solution->idx[0], &solution->idx[4])) {
95 SWAP_IDX(solution->idx[0], solution->idx[4]);
96 SWAP_IDX(solution->idx[1], solution->idx[5]);
97 SWAP_IDX(solution->idx[2], solution->idx[6]);
98 SWAP_IDX(solution->idx[3], solution->idx[7]);
99 }
100}
101
102static void solve_stage0(hashx_ctx* hash_func, solver_heap* heap) {
103 CLEAR(heap->stage1_indices.counts);
104 for (u32 i = 0; i < INDEX_SPACE; ++i) {
105 uint64_t value;
106 if (!hash_value(hash_func, i, &value))
107 break;
108 u32 bucket_idx = value % NUM_COARSE_BUCKETS;
109 u32 item_idx = STAGE1_SIZE(bucket_idx);
110 if (item_idx >= COARSE_BUCKET_ITEMS)
111 continue;
112 STAGE1_SIZE(bucket_idx) = item_idx + 1;
113 STAGE1_IDX(bucket_idx, item_idx) = i;
114 STAGE1_DATA(bucket_idx, item_idx) = value / NUM_COARSE_BUCKETS; /* 52 bits */
115 }
116}
117
118#define MAKE_PAIRS1 \
119 stage1_data_item value = STAGE1_DATA(bucket_idx, item_idx) + CARRY; \
120 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
121 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
122 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
123 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
124 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
125 stage1_data_item cpl_value = STAGE1_DATA(cpl_bucket, cpl_index); \
126 stage1_data_item sum = value + cpl_value; \
127 assert((sum % NUM_FINE_BUCKETS) == 0); \
128 sum /= NUM_FINE_BUCKETS; /* 45 bits */ \
129 u32 s2_buck_id = sum % NUM_COARSE_BUCKETS; \
130 u32 s2_item_id = STAGE2_SIZE(s2_buck_id); \
131 if (s2_item_id >= COARSE_BUCKET_ITEMS) \
132 continue; \
133 STAGE2_SIZE(s2_buck_id) = s2_item_id + 1; \
134 STAGE2_IDX(s2_buck_id, s2_item_id) = \
135 MAKE_ITEM(bucket_idx, item_idx, cpl_index); \
136 STAGE2_DATA(s2_buck_id, s2_item_id) = \
137 sum / NUM_COARSE_BUCKETS; /* 37 bits */ \
138 } \
139
140static void solve_stage1(solver_heap* heap) {
141 CLEAR(heap->stage2_indices.counts);
142 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
143 u32 cpl_bucket = INVERT_BUCKET(bucket_idx);
144 CLEAR(heap->scratch_ht.counts);
145 u32 cpl_buck_size = STAGE1_SIZE(cpl_bucket);
146 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
147 {
148 stage1_data_item value = STAGE1_DATA(cpl_bucket, item_idx);
149 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
150 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
151 if (fine_item_idx >= FINE_BUCKET_ITEMS)
152 continue;
153 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
154 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
155 }
156 if (cpl_bucket == bucket_idx) {
157 MAKE_PAIRS1
158 }
159 }
160 if (cpl_bucket != bucket_idx) {
161 u32 buck_size = STAGE1_SIZE(bucket_idx);
162 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
163 MAKE_PAIRS1
164 }
165 }
166 }
167}
168
169#define MAKE_PAIRS2 \
170 stage2_data_item value = STAGE2_DATA(bucket_idx, item_idx) + CARRY; \
171 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
172 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
173 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
174 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
175 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
176 stage2_data_item cpl_value = STAGE2_DATA(cpl_bucket, cpl_index); \
177 stage2_data_item sum = value + cpl_value; \
178 assert((sum % NUM_FINE_BUCKETS) == 0); \
179 sum /= NUM_FINE_BUCKETS; /* 30 bits */ \
180 u32 s3_buck_id = sum % NUM_COARSE_BUCKETS; \
181 u32 s3_item_id = STAGE3_SIZE(s3_buck_id); \
182 if (s3_item_id >= COARSE_BUCKET_ITEMS) \
183 continue; \
184 STAGE3_SIZE(s3_buck_id) = s3_item_id + 1; \
185 STAGE3_IDX(s3_buck_id, s3_item_id) = \
186 MAKE_ITEM(bucket_idx, item_idx, cpl_index); \
187 STAGE3_DATA(s3_buck_id, s3_item_id) = \
188 (stage3_data_item)(sum / NUM_COARSE_BUCKETS); /* 22 bits */ \
189 } \
190
191static void solve_stage2(solver_heap* heap) {
192 CLEAR(heap->stage3_indices.counts);
193 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
194 u32 cpl_bucket = INVERT_BUCKET(bucket_idx);
195 CLEAR(heap->scratch_ht.counts);
196 u32 cpl_buck_size = STAGE2_SIZE(cpl_bucket);
197 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
198 {
199 stage2_data_item value = STAGE2_DATA(cpl_bucket, item_idx);
200 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
201 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
202 if (fine_item_idx >= FINE_BUCKET_ITEMS)
203 continue;
204 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
205 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
206 }
207 if (cpl_bucket == bucket_idx) {
208 MAKE_PAIRS2
209 }
210 }
211 if (cpl_bucket != bucket_idx) {
212 u32 buck_size = STAGE2_SIZE(bucket_idx);
213 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
214 MAKE_PAIRS2
215 }
216 }
217 }
218}
219
220#define MAKE_PAIRS3 \
221 stage3_data_item value = STAGE3_DATA(bucket_idx, item_idx) + CARRY; \
222 u32 fine_buck_idx = value % NUM_FINE_BUCKETS; \
223 u32 fine_cpl_bucket = INVERT_SCRATCH(fine_buck_idx); \
224 u32 fine_cpl_size = SCRATCH_SIZE(fine_cpl_bucket); \
225 for (u32 fine_idx = 0; fine_idx < fine_cpl_size; ++fine_idx) { \
226 u32 cpl_index = SCRATCH(fine_cpl_bucket, fine_idx); \
227 stage3_data_item cpl_value = STAGE3_DATA(cpl_bucket, cpl_index); \
228 stage3_data_item sum = value + cpl_value; \
229 assert((sum % NUM_FINE_BUCKETS) == 0); \
230 sum /= NUM_FINE_BUCKETS; /* 15 bits */ \
231 if ((sum & EQUIX_STAGE1_MASK) == 0) { \
232 /* we have a solution */ \
233 s3_idx item_left = STAGE3_IDX(bucket_idx, item_idx); \
234 s3_idx item_right = STAGE3_IDX(cpl_bucket, cpl_index); \
235 build_solution(&output[sols_found], heap, item_left, item_right); \
236 if (++(sols_found) >= EQUIX_MAX_SOLS) { \
237 return sols_found; \
238 } \
239 } \
240 } \
241
242static int solve_stage3(solver_heap* heap, equix_solution output[EQUIX_MAX_SOLS]) {
243 int sols_found = 0;
244
245 for (u32 bucket_idx = BUCK_START; bucket_idx < BUCK_END; ++bucket_idx) {
246 u32 cpl_bucket = -bucket_idx & (NUM_COARSE_BUCKETS - 1);
247 CLEAR(heap->scratch_ht.counts);
248 u32 cpl_buck_size = STAGE3_SIZE(cpl_bucket);
249 for (u32 item_idx = 0; item_idx < cpl_buck_size; ++item_idx) {
250 {
251 stage3_data_item value = STAGE3_DATA(cpl_bucket, item_idx);
252 u32 fine_buck_idx = value % NUM_FINE_BUCKETS;
253 u32 fine_item_idx = SCRATCH_SIZE(fine_buck_idx);
254 if (fine_item_idx >= FINE_BUCKET_ITEMS)
255 continue;
256 SCRATCH_SIZE(fine_buck_idx) = fine_item_idx + 1;
257 SCRATCH(fine_buck_idx, fine_item_idx) = item_idx;
258 }
259 if (cpl_bucket == bucket_idx) {
260 MAKE_PAIRS3
261 }
262 }
263 if (cpl_bucket != bucket_idx) {
264 u32 buck_size = STAGE3_SIZE(bucket_idx);
265 for (u32 item_idx = 0; item_idx < buck_size; ++item_idx) {
266 MAKE_PAIRS3
267 }
268 }
269 }
270
271 return sols_found;
272}
273
274int equix_solver_solve(
275 hashx_ctx* hash_func,
276 solver_heap* heap,
277 equix_solution output[EQUIX_MAX_SOLS])
278{
279 solve_stage0(hash_func, heap);
280 solve_stage1(heap);
281 solve_stage2(heap);
282 return solve_stage3(heap, output);
283}