arti_rpc_client_core/conn/
connimpl.rs

//! Implementation logic for RpcConn.
//!
//! Except for [`RpcConn`] itself, nothing in this module is a public API.
//! This module exists so that we can more easily audit the code that
//! touches the members of `RpcConn`.
//!
//! NOTE that many of the types and fields here have documented invariants.
//! Unless noted otherwise, these invariants only hold when nobody
//! is holding the lock on [`ReceiverState`].
use std::{
    collections::{HashMap, VecDeque},
    sync::{Arc, Condvar, Mutex, MutexGuard},
};

use crate::{
    llconn,
    msgs::{
        request::{IdGenerator, ValidatedRequest},
        response::ValidatedResponse,
        AnyRequestId, ObjectId,
    },
};

use super::{ProtoError, ShutdownError};

26/// State held by the [`RpcConn`] for a single request ID.
27#[derive(Default)]
28struct RequestState {
29    /// A queue of replies received with this request's identity.
30    queue: VecDeque<ValidatedResponse>,
31    /// A condition variable used to wake a thread waiting for this request
32    /// to have messages.
33    ///
34    /// We `notify` this condvar thread under one of three circumstances:
35    ///
36    /// * When we queue a response for this request.
37    /// * When we store a fatal error affecting all requests in the RpcConn.
38    /// * When the thread currently reading from the [`llconn::Reader`] for this
39    ///   RpcConn stops doing so, and the request waiting
40    ///   on this thread has been chosen to take responsibility for reading.
41    ///
42    /// Invariants:
43    /// * The condvar is Some if (and only if) some thread is waiting
44    ///   on it.
45    waiter: Option<Arc<Condvar>>,
46}
47
48impl RequestState {
49    /// Helper: Pop and return the next message for this request.
50    ///
51    /// If there are no queued messages, but a fatal error has occurred on the connection,
52    /// return that.
53    ///
54    /// If there are no queued messages and no fatal error, return None.
55    fn pop_next_msg(
56        &mut self,
57        fatal: &Option<ShutdownError>,
58    ) -> Option<Result<ValidatedResponse, ShutdownError>> {
59        if let Some(m) = self.queue.pop_front() {
60            Some(Ok(m))
61        } else {
62            fatal.as_ref().map(|f| Err(f.clone()))
63        }
64    }
65}
66
67/// Mutable state to implement receiving replies on an RpcConn.
68struct ReceiverState {
69    /// Helper to assign connection- unique IDs to any requests without them.
70    id_gen: IdGenerator,
71    /// A fatal error, if any has occurred.
72    fatal: Option<ShutdownError>,
73    /// A map from request ID to the corresponding state.
74    ///
75    /// There is an entry in this map for every request that we have sent,
76    /// unless we have received a final response for that request,
77    /// or we have cancelled that request.
78    ///
79    /// (TODO: We might handle cancelling differently.)
80    pending: HashMap<AnyRequestId, RequestState>,
81    /// A reader that we use to receive replies from Arti.
82    ///
83    /// Invariants:
84    ///
85    /// * If this is None, a thread is reading and will take responsibility
86    ///   for liveness.
87    /// * If this is Some, no-one is reading and anyone who cares about liveness
88    ///   must take on the reader role.
89    ///
90    /// (Therefore, when it becomes Some, we must signal a cv, if any is set.)
91    reader: Option<crate::llconn::Reader>,
92}
93
94impl ReceiverState {
95    /// Notify an arbitrarily chosen request's condvar.
96    fn alert_anybody(&self) {
97        // TODO: This is O(n) in the worst case.
98        //
99        // But with luck, nobody will make a million requests and
100        // then wait on them one at a time?
101        for ent in self.pending.values() {
102            if let Some(cv) = &ent.waiter {
103                cv.notify_one();
104                return;
105            }
106        }
107    }
108
109    /// Notify the condvar for every request.
110    fn alert_everybody(&self) {
111        for ent in self.pending.values() {
112            if let Some(cv) = &ent.waiter {
113                // By our rules, each condvar is waited on by precisely one thread.
114                // So we call `notify_one` even though we are trying to wake up everyone.
115                cv.notify_one();
116            }
117        }
118    }
119}
120
121/// Object to receive messages on an RpcConn.
122///
123/// This is a crate-internal abstraction.
124/// It's separate from RpcConn for a few reasons:
125///
126/// - So we can keep the reading side of the channel open while the RpcConn has
127///   been dropped.
128/// - So we can hold the lock on this part without being blocked on threads writing.
129/// - Because this is the only part that for which
130///   `RequestHandle` needs to keep a reference.
131pub(super) struct Receiver {
132    /// Mutable state.
133    ///
134    /// This lock should only be held briefly, and never while reading from the
135    /// `llconn::Reader`.
136    state: Mutex<ReceiverState>,
137}
138
139/// An open RPC connection to Arti.
140#[derive(educe::Educe)]
141#[educe(Debug)]
142pub struct RpcConn {
143    /// The receiver object for this conn.
144    ///
145    /// It's in an `Arc<>` so that we can share it with the RequestHandles.
146    #[educe(Debug(ignore))]
147    receiver: Arc<Receiver>,
148
149    /// A writer that we use to send requests to Arti.
150    ///
151    /// This has its own lock so that we do not have to lock the Receiver
152    /// just in order to write.
153    ///
154    /// This lock does not nest with the`receiver` lock.  You must never hold
155    /// both at the same time.
156    ///
157    /// (For now, this lock is _ONLY_ held in the send_request method.)
158    #[educe(Debug(ignore))]
159    writer: Mutex<llconn::Writer>,
160
161    /// If set, we are authenticated and we have negotiated a session that has
162    /// this ObjectID.
163    pub(super) session: Option<ObjectId>,
164}
165
/// Instruction to alert some additional condvar(s) before releasing our lock and returning.
///
/// Any code which receives one of these must pass the instruction on to someone else,
/// until, eventually, the instruction is acted on in [`Receiver::wait_on_message_for`].
#[must_use]
#[derive(Debug)]
enum AlertWhom {
    /// We don't need to alert anybody;
    /// we have not taken the reader, or registered our own condvar:
    /// therefore nobody expects us to take the reader.
    Nobody,
    /// We have taken the reader or been alerted via our condvar:
    /// therefore, we are responsible for making sure
    /// that _somebody_ takes the reader.
    ///
    /// We should therefore alert somebody if nobody currently has the reader.
    Anybody,
    /// We have been the first to encounter a fatal error.
    /// Therefore, we should inform _everybody_.
    Everybody,
}

188impl RpcConn {
189    /// Construct a new RpcConn with a given reader and writer.
190    pub(super) fn new(reader: llconn::Reader, writer: llconn::Writer) -> Self {
191        Self {
192            receiver: Arc::new(Receiver {
193                state: Mutex::new(ReceiverState {
194                    id_gen: IdGenerator::default(),
195                    fatal: None,
196                    pending: HashMap::new(),
197                    reader: Some(reader),
198                }),
199            }),
200            writer: Mutex::new(writer),
201            session: None,
202        }
203    }
204
205    /// Send the request in `msg` on this connection, and return a RequestHandle
206    /// to wait for a reply.
207    ///
208    /// We validate `msg` before sending it out, and reject it if it doesn't
209    /// make sense. If `msg` has no `id` field, we allocate a new one
210    /// according to the rules in [`IdGenerator`].
211    ///
212    /// Limitation: We don't preserved unrecognized fields in the framing and meta
213    /// parts of `msg`.  See notes in `request.rs`.
214    pub(super) fn send_request(&self, msg: &str) -> Result<super::RequestHandle, ProtoError> {
215        use std::collections::hash_map::Entry::*;
216
217        let mut state = self.receiver.state.lock().expect("poisoned");
218        if let Some(f) = &state.fatal {
219            // If there's been a fatal error we don't even try to send the request.
220            return Err(f.clone().into());
221        }
222
223        // Convert this request into validated form (with an ID) and re-encode it.
224        let valid: ValidatedRequest =
225            ValidatedRequest::from_string_loose(msg, || state.id_gen.next_id())?;
226
227        // Do the necessary housekeeping before we send the request, so that
228        // we'll be able to understand the replies.
229        let id = valid.id().clone();
230        match state.pending.entry(id.clone()) {
231            Occupied(_) => return Err(ProtoError::RequestIdInUse),
232            Vacant(v) => {
233                v.insert(RequestState::default());
234            }
235        }
236        // Release the lock on the ReceiverState here; the two locks must not overlap.
237        drop(state);
238
239        // NOTE: This is the only block of code that holds the writer lock!
240        let write_outcome = { self.writer.lock().expect("poisoned").send_valid(&valid) };
241
242        match write_outcome {
243            Err(e) => {
244                // A failed write is a fatal error for everybody.
245                let e = ShutdownError::Write(Arc::new(e));
246                let mut state = self.receiver.state.lock().expect("poisoned");
247                if state.fatal.is_none() {
248                    state.fatal = Some(e.clone());
249                    state.alert_everybody();
250                }
251                Err(e.into())
252            }
253
254            Ok(()) => Ok(super::RequestHandle {
255                id,
256                conn: Mutex::new(Arc::clone(&self.receiver)),
257            }),
258        }
259    }
260}
261
262impl Receiver {
263    /// Wait until there is either a fatal error on this connection,
264    /// _or_ there is a new message for the request with the provided `id`.
265    /// Return that message, or a copy of the fatal error.
266    pub(super) fn wait_on_message_for(
267        &self,
268        id: &AnyRequestId,
269    ) -> Result<ValidatedResponse, ProtoError> {
270        // Here in wait_on_message_for_impl, we do the the actual work
271        // of waiting for the message.
272        let state = self.state.lock().expect("poisoned");
273        let (result, mut state, should_alert) = self.wait_on_message_for_impl(state, id);
274
275        // Great; we have a message or a fatal error.  All we need to do now
276        // is to restore our invariants before we drop state_lock.
277        //
278        // (It would be a bug to return early without restoring the invariants,
279        // so we'll use an IEFE pattern to prevent "?" and "return Err".)
280        #[allow(clippy::redundant_closure_call)]
281        (|| {
282            // "final" in this case means that we are not expecting any more
283            // replies for this request.
284            let is_final = match &result {
285                Err(_) => true,
286                Ok(r) => r.is_final(),
287            };
288
289            if is_final {
290                // Note 1: It might be cleaner to use Entry::remove(), but Entry is not
291                // exactly the right shape for us; see note in
292                // wait_on_message_for_impl.
293
294                // Note 2: This remove isn't necessary if `result` is
295                // RequestCancelled, but it won't hurt.
296
297                // Note 3: On DuplicateWait, it is not totally clear whether we should
298                // remove or not.  But that's an internal error that should never occur,
299                // so it is probably okay if we let the _other_ waiter keep on trying.
300                state.pending.remove(id);
301            }
302
303            match should_alert {
304                AlertWhom::Nobody => {}
305                AlertWhom::Anybody if state.reader.is_none() => {}
306                AlertWhom::Anybody => state.alert_anybody(),
307                AlertWhom::Everybody => state.alert_everybody(),
308            }
309        })();
310
311        result
312    }
313
314    /// Helper to implement [`wait_on_message_for`](Self::wait_on_message_for).
315    ///
316    /// Takes a `MutexGuard` as one of its arguments, and returns an equivalent
317    /// `MutexGuard` on completion.
318    ///
319    /// The caller is responsible for:
320    ///
321    /// - Removing the appropriate entry from `pending`, if the result
322    ///   indicates that no more messages will be received for this request.
323    /// - Possibly, notifying one or more condvars,
324    ///   depending on the resulting `AlertWhom`.
325    ///
326    /// The caller must not drop the `MutexGuard` until it has done the above.
327    fn wait_on_message_for_impl<'a>(
328        &'a self,
329        mut state_lock: MutexGuard<'a, ReceiverState>,
330        id: &AnyRequestId,
331    ) -> (
332        Result<ValidatedResponse, ProtoError>,
333        MutexGuard<'a, ReceiverState>,
334        AlertWhom,
335    ) {
336        // At this point, we have not registered on a condvar, and we have not
337        // taken the reader.
338        // Therefore, we do not yet need to ensure that anybody else takes the reader.
339        //
340        // TODO: It is possibly too easy to forget to set this,
341        // or to set it to a less "alerty" value.  Refactoring might help;
342        // see discussion at
343        // https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2258#note_3047267
344        let mut should_alert = AlertWhom::Nobody;
345
346        let mut state: &mut ReceiverState = &mut state_lock;
347
348        // Initialize `this_ent` to our own entry in the pending table.
349        let Some(mut this_ent) = state.pending.get_mut(id) else {
350            return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
351        };
352
353        let mut reader = loop {
354            // Note: It might be nice to use a hash_map::Entry here, but it
355            // doesn't really work the way we want.  The `entry()` API is always
356            // ready to insert, and requires that we clone `id`.  But what we
357            // want in this case is something that would give us a .remove()able
358            // Entry only if one is present.
359            if this_ent.waiter.is_some() {
360                // This is an internal error; nobody should be able to cause this.
361                return (Err(ProtoError::DuplicateWait), state_lock, should_alert);
362            }
363
364            if let Some(ready) = this_ent.pop_next_msg(&state.fatal) {
365                // There is a reply for us, or a fatal error.
366                return (ready.map_err(ProtoError::from), state_lock, should_alert);
367            }
368
369            // If we reach this point, we are about to either take the reader or
370            // register a cv.  This means that when we return, we need to make
371            // sure that at least one other cv gets notified.
372            should_alert = AlertWhom::Anybody;
373
374            if let Some(r) = state.reader.take() {
375                // Nobody else is reading; we have to do it.
376                break r;
377            }
378
379            // Somebody else is reading; register a condvar.
380            let cv = Arc::new(Condvar::new());
381            this_ent.waiter = Some(Arc::clone(&cv));
382
383            state_lock = cv.wait(state_lock).expect("poisoned lock");
384            state = &mut state_lock;
385            // Restore `this_ent`...
386            let Some(e) = state.pending.get_mut(id) else {
387                return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
388            };
389            this_ent = e;
390            // ... And un-register our condvar.
391            this_ent.waiter = None;
392
393            // We have been notified: either there is a reply or us,
394            // or we are supposed to take the reader.  We'll find out on our
395            // next time through the loop.
396        };
397
398        let (result, mut state_lock, should_alert) =
399            self.read_until_message_for(state_lock, &mut reader, id);
400        // Put the reader back.
401        state_lock.reader = Some(reader);
402
403        (result.map_err(ProtoError::from), state_lock, should_alert)
404    }
405
406    /// Read messages, delivering them as appropriate, until we find one for `id`,
407    /// or a fatal error occurs.
408    ///
409    /// Return that message or error, along with a `MutexGuard`.
410    ///
411    /// The caller is responsible for restoring the following state before
412    /// dropping the `MutexGuard`:
413    ///
414    /// - Putting `reader` back into the `reader` field.
415    /// - Other invariants as discussed in wait_on_message_for_impl.
416    fn read_until_message_for<'a>(
417        &'a self,
418        mut state_lock: MutexGuard<'a, ReceiverState>,
419        reader: &mut llconn::Reader,
420        id: &AnyRequestId,
421    ) -> (
422        Result<ValidatedResponse, ShutdownError>,
423        MutexGuard<'a, ReceiverState>,
424        AlertWhom,
425    ) {
426        loop {
427            // Importantly, we drop the state lock while we are reading.
428            // This is okay, since all our invariants should hold at this point.
429            drop(state_lock);
430
431            let result: Result<ValidatedResponse, _> = match reader.read_msg() {
432                Err(e) => Err(ShutdownError::Read(Arc::new(e))),
433                Ok(None) => Err(ShutdownError::ConnectionClosed),
434                Ok(Some(m)) => m.try_validate().map_err(ShutdownError::from),
435            };
436
437            state_lock = self.state.lock().expect("poisoned lock");
438            let state = &mut state_lock;
439
440            match result {
441                Ok(m) if m.id() == id => {
442                    // This only is for us, so there's no need to alert anybody
443                    // or queue it.
444                    return (Ok(m), state_lock, AlertWhom::Anybody);
445                }
446                Err(e) => {
447                    // This is a fatal error on the whole connection.
448                    //
449                    // If it's the first one encountered, queue the error, and
450                    // return it.
451                    if state.fatal.is_none() {
452                        state.fatal = Some(e.clone());
453                    }
454                    return (Err(e), state_lock, AlertWhom::Everybody);
455                }
456                Ok(m) => {
457                    // This is a message for exactly one ID, that isn't us.
458                    // Queue it and notify them.
459                    if let Some(ent) = state.pending.get_mut(m.id()) {
460                        ent.queue.push_back(m);
461                        if let Some(cv) = &ent.waiter {
462                            cv.notify_one();
463                        }
464                    } else {
465                        // Nothing wanted this response any longer.
466                        // _Probably_ this means that we decided to cancel the
467                        // request but Arti sent this response before it handled
468                        // our cancellation.
469                    }
470                }
471            };
472        }
473    }
474}