arti_rpc_client_core/conn/connimpl.rs
1//! Implementation logic for RpcConn.
2//!
3//! Except for [`RpcConn`] itself, nothing in this module is a public API.
4//! This module exists so that we can more easily audit the code that
5//! touches the members of `RpcConn`.
6//!
//! NOTE that many of the types and fields here have documented invariants.
//! Except if noted otherwise, these invariants only hold when nobody
//! is holding the lock on [`ReceiverState`].
10use std::{
11 collections::{HashMap, VecDeque},
12 sync::{Arc, Condvar, Mutex, MutexGuard},
13};
14
15use crate::{
16 llconn,
17 msgs::{
18 request::{IdGenerator, ValidatedRequest},
19 response::ValidatedResponse,
20 AnyRequestId, ObjectId,
21 },
22};
23
24use super::{ProtoError, ShutdownError};
25
26/// State held by the [`RpcConn`] for a single request ID.
27#[derive(Default)]
28struct RequestState {
29 /// A queue of replies received with this request's identity.
30 queue: VecDeque<ValidatedResponse>,
31 /// A condition variable used to wake a thread waiting for this request
32 /// to have messages.
33 ///
34 /// We `notify` this condvar thread under one of three circumstances:
35 ///
36 /// * When we queue a response for this request.
37 /// * When we store a fatal error affecting all requests in the RpcConn.
38 /// * When the thread currently reading from the [`llconn::Reader`] for this
39 /// RpcConn stops doing so, and the request waiting
40 /// on this thread has been chosen to take responsibility for reading.
41 ///
42 /// Invariants:
43 /// * The condvar is Some if (and only if) some thread is waiting
44 /// on it.
45 waiter: Option<Arc<Condvar>>,
46}
47
48impl RequestState {
49 /// Helper: Pop and return the next message for this request.
50 ///
51 /// If there are no queued messages, but a fatal error has occurred on the connection,
52 /// return that.
53 ///
54 /// If there are no queued messages and no fatal error, return None.
55 fn pop_next_msg(
56 &mut self,
57 fatal: &Option<ShutdownError>,
58 ) -> Option<Result<ValidatedResponse, ShutdownError>> {
59 if let Some(m) = self.queue.pop_front() {
60 Some(Ok(m))
61 } else {
62 fatal.as_ref().map(|f| Err(f.clone()))
63 }
64 }
65}
66
67/// Mutable state to implement receiving replies on an RpcConn.
68struct ReceiverState {
69 /// Helper to assign connection- unique IDs to any requests without them.
70 id_gen: IdGenerator,
71 /// A fatal error, if any has occurred.
72 fatal: Option<ShutdownError>,
73 /// A map from request ID to the corresponding state.
74 ///
75 /// There is an entry in this map for every request that we have sent,
76 /// unless we have received a final response for that request,
77 /// or we have cancelled that request.
78 ///
79 /// (TODO: We might handle cancelling differently.)
80 pending: HashMap<AnyRequestId, RequestState>,
81 /// A reader that we use to receive replies from Arti.
82 ///
83 /// Invariants:
84 ///
85 /// * If this is None, a thread is reading and will take responsibility
86 /// for liveness.
87 /// * If this is Some, no-one is reading and anyone who cares about liveness
88 /// must take on the reader role.
89 ///
90 /// (Therefore, when it becomes Some, we must signal a cv, if any is set.)
91 reader: Option<crate::llconn::Reader>,
92}
93
94impl ReceiverState {
95 /// Notify an arbitrarily chosen request's condvar.
96 fn alert_anybody(&self) {
97 // TODO: This is O(n) in the worst case.
98 //
99 // But with luck, nobody will make a million requests and
100 // then wait on them one at a time?
101 for ent in self.pending.values() {
102 if let Some(cv) = &ent.waiter {
103 cv.notify_one();
104 return;
105 }
106 }
107 }
108
109 /// Notify the condvar for every request.
110 fn alert_everybody(&self) {
111 for ent in self.pending.values() {
112 if let Some(cv) = &ent.waiter {
113 // By our rules, each condvar is waited on by precisely one thread.
114 // So we call `notify_one` even though we are trying to wake up everyone.
115 cv.notify_one();
116 }
117 }
118 }
119}
120
121/// Object to receive messages on an RpcConn.
122///
123/// This is a crate-internal abstraction.
124/// It's separate from RpcConn for a few reasons:
125///
126/// - So we can keep the reading side of the channel open while the RpcConn has
127/// been dropped.
128/// - So we can hold the lock on this part without being blocked on threads writing.
129/// - Because this is the only part that for which
130/// `RequestHandle` needs to keep a reference.
131pub(super) struct Receiver {
132 /// Mutable state.
133 ///
134 /// This lock should only be held briefly, and never while reading from the
135 /// `llconn::Reader`.
136 state: Mutex<ReceiverState>,
137}
138
139/// An open RPC connection to Arti.
140#[derive(educe::Educe)]
141#[educe(Debug)]
142pub struct RpcConn {
143 /// The receiver object for this conn.
144 ///
145 /// It's in an `Arc<>` so that we can share it with the RequestHandles.
146 #[educe(Debug(ignore))]
147 receiver: Arc<Receiver>,
148
149 /// A writer that we use to send requests to Arti.
150 ///
151 /// This has its own lock so that we do not have to lock the Receiver
152 /// just in order to write.
153 ///
154 /// This lock does not nest with the`receiver` lock. You must never hold
155 /// both at the same time.
156 ///
157 /// (For now, this lock is _ONLY_ held in the send_request method.)
158 #[educe(Debug(ignore))]
159 writer: Mutex<llconn::Writer>,
160
161 /// If set, we are authenticated and we have negotiated a session that has
162 /// this ObjectID.
163 pub(super) session: Option<ObjectId>,
164}
165
/// An instruction saying which additional condvar(s) must be alerted before
/// we release our lock and return.
///
/// Code that receives one of these may not drop it on the floor: it has to
/// hand the instruction on to someone else, until it is finally acted upon in
/// [`Receiver::wait_on_message_for`].
#[must_use]
#[derive(Debug)]
enum AlertWhom {
    /// There is nobody to alert.
    ///
    /// We have neither taken the reader nor registered a condvar of our own,
    /// so nobody is expecting us to take the reader.
    Nobody,
    /// We have taken the reader, or we were alerted through our condvar, and
    /// so it is our job to make sure that _somebody_ takes the reader.
    ///
    /// Consequently, somebody should be alerted if nobody currently has the
    /// reader.
    Anybody,
    /// We were the first to run into a fatal error, and so _everybody_
    /// should be informed.
    Everybody,
}
187
188impl RpcConn {
189 /// Construct a new RpcConn with a given reader and writer.
190 pub(super) fn new(reader: llconn::Reader, writer: llconn::Writer) -> Self {
191 Self {
192 receiver: Arc::new(Receiver {
193 state: Mutex::new(ReceiverState {
194 id_gen: IdGenerator::default(),
195 fatal: None,
196 pending: HashMap::new(),
197 reader: Some(reader),
198 }),
199 }),
200 writer: Mutex::new(writer),
201 session: None,
202 }
203 }
204
205 /// Send the request in `msg` on this connection, and return a RequestHandle
206 /// to wait for a reply.
207 ///
208 /// We validate `msg` before sending it out, and reject it if it doesn't
209 /// make sense. If `msg` has no `id` field, we allocate a new one
210 /// according to the rules in [`IdGenerator`].
211 ///
212 /// Limitation: We don't preserved unrecognized fields in the framing and meta
213 /// parts of `msg`. See notes in `request.rs`.
214 pub(super) fn send_request(&self, msg: &str) -> Result<super::RequestHandle, ProtoError> {
215 use std::collections::hash_map::Entry::*;
216
217 let mut state = self.receiver.state.lock().expect("poisoned");
218 if let Some(f) = &state.fatal {
219 // If there's been a fatal error we don't even try to send the request.
220 return Err(f.clone().into());
221 }
222
223 // Convert this request into validated form (with an ID) and re-encode it.
224 let valid: ValidatedRequest =
225 ValidatedRequest::from_string_loose(msg, || state.id_gen.next_id())?;
226
227 // Do the necessary housekeeping before we send the request, so that
228 // we'll be able to understand the replies.
229 let id = valid.id().clone();
230 match state.pending.entry(id.clone()) {
231 Occupied(_) => return Err(ProtoError::RequestIdInUse),
232 Vacant(v) => {
233 v.insert(RequestState::default());
234 }
235 }
236 // Release the lock on the ReceiverState here; the two locks must not overlap.
237 drop(state);
238
239 // NOTE: This is the only block of code that holds the writer lock!
240 let write_outcome = { self.writer.lock().expect("poisoned").send_valid(&valid) };
241
242 match write_outcome {
243 Err(e) => {
244 // A failed write is a fatal error for everybody.
245 let e = ShutdownError::Write(Arc::new(e));
246 let mut state = self.receiver.state.lock().expect("poisoned");
247 if state.fatal.is_none() {
248 state.fatal = Some(e.clone());
249 state.alert_everybody();
250 }
251 Err(e.into())
252 }
253
254 Ok(()) => Ok(super::RequestHandle {
255 id,
256 conn: Mutex::new(Arc::clone(&self.receiver)),
257 }),
258 }
259 }
260}
261
262impl Receiver {
263 /// Wait until there is either a fatal error on this connection,
264 /// _or_ there is a new message for the request with the provided `id`.
265 /// Return that message, or a copy of the fatal error.
266 pub(super) fn wait_on_message_for(
267 &self,
268 id: &AnyRequestId,
269 ) -> Result<ValidatedResponse, ProtoError> {
270 // Here in wait_on_message_for_impl, we do the the actual work
271 // of waiting for the message.
272 let state = self.state.lock().expect("poisoned");
273 let (result, mut state, should_alert) = self.wait_on_message_for_impl(state, id);
274
275 // Great; we have a message or a fatal error. All we need to do now
276 // is to restore our invariants before we drop state_lock.
277 //
278 // (It would be a bug to return early without restoring the invariants,
279 // so we'll use an IEFE pattern to prevent "?" and "return Err".)
280 #[allow(clippy::redundant_closure_call)]
281 (|| {
282 // "final" in this case means that we are not expecting any more
283 // replies for this request.
284 let is_final = match &result {
285 Err(_) => true,
286 Ok(r) => r.is_final(),
287 };
288
289 if is_final {
290 // Note 1: It might be cleaner to use Entry::remove(), but Entry is not
291 // exactly the right shape for us; see note in
292 // wait_on_message_for_impl.
293
294 // Note 2: This remove isn't necessary if `result` is
295 // RequestCancelled, but it won't hurt.
296
297 // Note 3: On DuplicateWait, it is not totally clear whether we should
298 // remove or not. But that's an internal error that should never occur,
299 // so it is probably okay if we let the _other_ waiter keep on trying.
300 state.pending.remove(id);
301 }
302
303 match should_alert {
304 AlertWhom::Nobody => {}
305 AlertWhom::Anybody if state.reader.is_none() => {}
306 AlertWhom::Anybody => state.alert_anybody(),
307 AlertWhom::Everybody => state.alert_everybody(),
308 }
309 })();
310
311 result
312 }
313
314 /// Helper to implement [`wait_on_message_for`](Self::wait_on_message_for).
315 ///
316 /// Takes a `MutexGuard` as one of its arguments, and returns an equivalent
317 /// `MutexGuard` on completion.
318 ///
319 /// The caller is responsible for:
320 ///
321 /// - Removing the appropriate entry from `pending`, if the result
322 /// indicates that no more messages will be received for this request.
323 /// - Possibly, notifying one or more condvars,
324 /// depending on the resulting `AlertWhom`.
325 ///
326 /// The caller must not drop the `MutexGuard` until it has done the above.
327 fn wait_on_message_for_impl<'a>(
328 &'a self,
329 mut state_lock: MutexGuard<'a, ReceiverState>,
330 id: &AnyRequestId,
331 ) -> (
332 Result<ValidatedResponse, ProtoError>,
333 MutexGuard<'a, ReceiverState>,
334 AlertWhom,
335 ) {
336 // At this point, we have not registered on a condvar, and we have not
337 // taken the reader.
338 // Therefore, we do not yet need to ensure that anybody else takes the reader.
339 //
340 // TODO: It is possibly too easy to forget to set this,
341 // or to set it to a less "alerty" value. Refactoring might help;
342 // see discussion at
343 // https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2258#note_3047267
344 let mut should_alert = AlertWhom::Nobody;
345
346 let mut state: &mut ReceiverState = &mut state_lock;
347
348 // Initialize `this_ent` to our own entry in the pending table.
349 let Some(mut this_ent) = state.pending.get_mut(id) else {
350 return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
351 };
352
353 let mut reader = loop {
354 // Note: It might be nice to use a hash_map::Entry here, but it
355 // doesn't really work the way we want. The `entry()` API is always
356 // ready to insert, and requires that we clone `id`. But what we
357 // want in this case is something that would give us a .remove()able
358 // Entry only if one is present.
359 if this_ent.waiter.is_some() {
360 // This is an internal error; nobody should be able to cause this.
361 return (Err(ProtoError::DuplicateWait), state_lock, should_alert);
362 }
363
364 if let Some(ready) = this_ent.pop_next_msg(&state.fatal) {
365 // There is a reply for us, or a fatal error.
366 return (ready.map_err(ProtoError::from), state_lock, should_alert);
367 }
368
369 // If we reach this point, we are about to either take the reader or
370 // register a cv. This means that when we return, we need to make
371 // sure that at least one other cv gets notified.
372 should_alert = AlertWhom::Anybody;
373
374 if let Some(r) = state.reader.take() {
375 // Nobody else is reading; we have to do it.
376 break r;
377 }
378
379 // Somebody else is reading; register a condvar.
380 let cv = Arc::new(Condvar::new());
381 this_ent.waiter = Some(Arc::clone(&cv));
382
383 state_lock = cv.wait(state_lock).expect("poisoned lock");
384 state = &mut state_lock;
385 // Restore `this_ent`...
386 let Some(e) = state.pending.get_mut(id) else {
387 return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
388 };
389 this_ent = e;
390 // ... And un-register our condvar.
391 this_ent.waiter = None;
392
393 // We have been notified: either there is a reply or us,
394 // or we are supposed to take the reader. We'll find out on our
395 // next time through the loop.
396 };
397
398 let (result, mut state_lock, should_alert) =
399 self.read_until_message_for(state_lock, &mut reader, id);
400 // Put the reader back.
401 state_lock.reader = Some(reader);
402
403 (result.map_err(ProtoError::from), state_lock, should_alert)
404 }
405
406 /// Read messages, delivering them as appropriate, until we find one for `id`,
407 /// or a fatal error occurs.
408 ///
409 /// Return that message or error, along with a `MutexGuard`.
410 ///
411 /// The caller is responsible for restoring the following state before
412 /// dropping the `MutexGuard`:
413 ///
414 /// - Putting `reader` back into the `reader` field.
415 /// - Other invariants as discussed in wait_on_message_for_impl.
416 fn read_until_message_for<'a>(
417 &'a self,
418 mut state_lock: MutexGuard<'a, ReceiverState>,
419 reader: &mut llconn::Reader,
420 id: &AnyRequestId,
421 ) -> (
422 Result<ValidatedResponse, ShutdownError>,
423 MutexGuard<'a, ReceiverState>,
424 AlertWhom,
425 ) {
426 loop {
427 // Importantly, we drop the state lock while we are reading.
428 // This is okay, since all our invariants should hold at this point.
429 drop(state_lock);
430
431 let result: Result<ValidatedResponse, _> = match reader.read_msg() {
432 Err(e) => Err(ShutdownError::Read(Arc::new(e))),
433 Ok(None) => Err(ShutdownError::ConnectionClosed),
434 Ok(Some(m)) => m.try_validate().map_err(ShutdownError::from),
435 };
436
437 state_lock = self.state.lock().expect("poisoned lock");
438 let state = &mut state_lock;
439
440 match result {
441 Ok(m) if m.id() == id => {
442 // This only is for us, so there's no need to alert anybody
443 // or queue it.
444 return (Ok(m), state_lock, AlertWhom::Anybody);
445 }
446 Err(e) => {
447 // This is a fatal error on the whole connection.
448 //
449 // If it's the first one encountered, queue the error, and
450 // return it.
451 if state.fatal.is_none() {
452 state.fatal = Some(e.clone());
453 }
454 return (Err(e), state_lock, AlertWhom::Everybody);
455 }
456 Ok(m) => {
457 // This is a message for exactly one ID, that isn't us.
458 // Queue it and notify them.
459 if let Some(ent) = state.pending.get_mut(m.id()) {
460 ent.queue.push_back(m);
461 if let Some(cv) = &ent.waiter {
462 cv.notify_one();
463 }
464 } else {
465 // Nothing wanted this response any longer.
466 // _Probably_ this means that we decided to cancel the
467 // request but Arti sent this response before it handled
468 // our cancellation.
469 }
470 }
471 };
472 }
473 }
474}