1
//! Implementation logic for RpcConn.
2
//!
3
//! Except for [`RpcConn`] itself, nothing in this module is a public API.
4
//! This module exists so that we can more easily audit the code that
5
//! touches the members of `RpcConn`.
6
//!
7
//! NOTE that many of the types and fields here have documented invariants.
8
//! Except if noted otherwise, these invariants only hold when nobody
9
//! is holding the lock on [`RequestState`].
10
use std::{
11
    collections::{HashMap, VecDeque},
12
    sync::{Arc, Condvar, Mutex, MutexGuard},
13
};
14

            
15
use crate::{
16
    llconn,
17
    msgs::{
18
        request::{IdGenerator, ValidatedRequest},
19
        response::ValidatedResponse,
20
        AnyRequestId, ObjectId,
21
    },
22
};
23

            
24
use super::{ProtoError, ShutdownError};
25

            
26
/// State held by the [`RpcConn`] for a single request ID.
27
#[derive(Default)]
28
struct RequestState {
29
    /// A queue of replies received with this request's identity.
30
    queue: VecDeque<ValidatedResponse>,
31
    /// A condition variable used to wake a thread waiting for this request
32
    /// to have messages.
33
    ///
34
    /// We `notify` this condvar thread under one of three circumstances:
35
    ///
36
    /// * When we queue a response for this request.
37
    /// * When we store a fatal error affecting all requests in the RpcConn.
38
    /// * When the thread currently reading from the [`llconn::Reader`] for this
39
    ///   RpcConn stops doing so, and the request waiting
40
    ///   on this thread has been chosen to take responsibility for reading.
41
    ///
42
    /// Invariants:
43
    /// * The condvar is Some if (and only if) some thread is waiting
44
    ///   on it.
45
    waiter: Option<Arc<Condvar>>,
46
}
47

            
48
impl RequestState {
49
    /// Helper: Pop and return the next message for this request.
50
    ///
51
    /// If there are no queued messages, but a fatal error has occurred on the connection,
52
    /// return that.
53
    ///
54
    /// If there are no queued messages and no fatal error, return None.
55
16338
    fn pop_next_msg(
56
16338
        &mut self,
57
16338
        fatal: &Option<ShutdownError>,
58
16338
    ) -> Option<Result<ValidatedResponse, ShutdownError>> {
59
16338
        if let Some(m) = self.queue.pop_front() {
60
7090
            Some(Ok(m))
61
        } else {
62
9289
            fatal.as_ref().map(|f| Err(f.clone()))
63
        }
64
16338
    }
65
}
66

            
67
/// Mutable state to implement receiving replies on an RpcConn.
68
struct ReceiverState {
69
    /// Helper to assign connection- unique IDs to any requests without them.
70
    id_gen: IdGenerator,
71
    /// A fatal error, if any has occurred.
72
    fatal: Option<ShutdownError>,
73
    /// A map from request ID to the corresponding state.
74
    ///
75
    /// There is an entry in this map for every request that we have sent,
76
    /// unless we have received a final response for that request,
77
    /// or we have cancelled that request.
78
    ///
79
    /// (TODO: We might handle cancelling differently.)
80
    pending: HashMap<AnyRequestId, RequestState>,
81
    /// A reader that we use to receive replies from Arti.
82
    ///
83
    /// Invariants:
84
    ///
85
    /// * If this is None, a thread is reading and will take responsibility
86
    ///   for liveness.
87
    /// * If this is Some, no-one is reading and anyone who cares about liveness
88
    ///   must take on the reader role.
89
    ///
90
    /// (Therefore, when it becomes Some, we must signal a cv, if any is set.)
91
    reader: Option<crate::llconn::Reader>,
92
}
93

            
94
impl ReceiverState {
95
    /// Notify an arbitrarily chosen request's condvar.
96
1132
    fn alert_anybody(&self) {
97
        // TODO: This is O(n) in the worst case.
98
        //
99
        // But with luck, nobody will make a million requests and
100
        // then wait on them one at a time?
101
1442
        for ent in self.pending.values() {
102
1442
            if let Some(cv) = &ent.waiter {
103
1106
                cv.notify_one();
104
1106
                return;
105
336
            }
106
        }
107
1132
    }
108

            
109
    /// Notify the condvar for every request.
110
6
    fn alert_everybody(&self) {
111
82
        for ent in self.pending.values() {
112
82
            if let Some(cv) = &ent.waiter {
113
80
                // By our rules, each condvar is waited on by precisely one thread.
114
80
                // So we call `notify_one` even though we are trying to wake up everyone.
115
80
                cv.notify_one();
116
80
            }
117
        }
118
6
    }
119
}
120

            
121
/// Object to receive messages on an RpcConn.
122
///
123
/// This is a crate-internal abstraction.
124
/// It's separate from RpcConn for a few reasons:
125
///
126
/// - So we can keep the reading side of the channel open while the RpcConn has
127
///   been dropped.
128
/// - So we can hold the lock on this part without being blocked on threads writing.
129
/// - Because this is the only part that for which
130
///   `RequestHandle` needs to keep a reference.
131
pub(super) struct Receiver {
132
    /// Mutable state.
133
    ///
134
    /// This lock should only be held briefly, and never while reading from the
135
    /// `llconn::Reader`.
136
    state: Mutex<ReceiverState>,
137
}
138

            
139
/// An open RPC connection to Arti.
140
#[derive(educe::Educe)]
141
#[educe(Debug)]
142
pub struct RpcConn {
143
    /// The receiver object for this conn.
144
    ///
145
    /// It's in an `Arc<>` so that we can share it with the RequestHandles.
146
    #[educe(Debug(ignore))]
147
    receiver: Arc<Receiver>,
148

            
149
    /// A writer that we use to send requests to Arti.
150
    ///
151
    /// This has its own lock so that we do not have to lock the Receiver
152
    /// just in order to write.
153
    ///
154
    /// This lock does not nest with the`receiver` lock.  You must never hold
155
    /// both at the same time.
156
    ///
157
    /// (For now, this lock is _ONLY_ held in the send_request method.)
158
    #[educe(Debug(ignore))]
159
    writer: Mutex<llconn::Writer>,
160

            
161
    /// If set, we are authenticated and we have negotiated a session that has
162
    /// this ObjectID.
163
    pub(super) session: Option<ObjectId>,
164
}
165

            
166
/// Which condvar(s), if any, must be notified before the receiver lock is
/// released and control returns to the caller.
///
/// Whoever is handed one of these must either act on it or pass it along;
/// eventually it is acted on in [`Receiver::wait_on_message_for`].
#[derive(Debug)]
#[must_use]
enum AlertWhom {
    /// No alert is needed: we neither took the reader nor registered our own
    /// condvar, so nobody is relying on us to hand the reader on.
    Nobody,
    /// We took the reader, or were woken through our condvar, so it is our job
    /// to make sure that _somebody_ ends up holding the reader.
    ///
    /// Concretely: if nobody currently has the reader, wake some waiter.
    Anybody,
    /// We were the first to observe a fatal error, so _everybody_ must be told.
    Everybody,
}
187

            
188
impl RpcConn {
189
    /// Construct a new RpcConn with a given reader and writer.
190
10
    pub(super) fn new(reader: llconn::Reader, writer: llconn::Writer) -> Self {
191
10
        Self {
192
10
            receiver: Arc::new(Receiver {
193
10
                state: Mutex::new(ReceiverState {
194
10
                    id_gen: IdGenerator::default(),
195
10
                    fatal: None,
196
10
                    pending: HashMap::new(),
197
10
                    reader: Some(reader),
198
10
                }),
199
10
            }),
200
10
            writer: Mutex::new(writer),
201
10
            session: None,
202
10
        }
203
10
    }
204

            
205
    /// Send the request in `msg` on this connection, and return a RequestHandle
206
    /// to wait for a reply.
207
    ///
208
    /// We validate `msg` before sending it out, and reject it if it doesn't
209
    /// make sense. If `msg` has no `id` field, we allocate a new one
210
    /// according to the rules in [`IdGenerator`].
211
    ///
212
    /// Limitation: We don't preserved unrecognized fields in the framing and meta
213
    /// parts of `msg`.  See notes in `request.rs`.
214
4194
    pub(super) fn send_request(&self, msg: &str) -> Result<super::RequestHandle, ProtoError> {
215
        use std::collections::hash_map::Entry::*;
216

            
217
4194
        let mut state = self.receiver.state.lock().expect("poisoned");
218
4194
        if let Some(f) = &state.fatal {
219
            // If there's been a fatal error we don't even try to send the request.
220
8
            return Err(f.clone().into());
221
4186
        }
222

            
223
        // Convert this request into validated form (with an ID) and re-encode it.
224
4186
        let valid: ValidatedRequest =
225
6279
            ValidatedRequest::from_string_loose(msg, || state.id_gen.next_id())?;
226

            
227
        // Do the necessary housekeeping before we send the request, so that
228
        // we'll be able to understand the replies.
229
4186
        let id = valid.id().clone();
230
4186
        match state.pending.entry(id.clone()) {
231
            Occupied(_) => return Err(ProtoError::RequestIdInUse),
232
4186
            Vacant(v) => {
233
4186
                v.insert(RequestState::default());
234
4186
            }
235
4186
        }
236
4186
        // Release the lock on the ReceiverState here; the two locks must not overlap.
237
4186
        drop(state);
238
4186

            
239
4186
        // NOTE: This is the only block of code that holds the writer lock!
240
4186
        let write_outcome = { self.writer.lock().expect("poisoned").send_valid(&valid) };
241
4186

            
242
4186
        match write_outcome {
243
            Err(e) => {
244
                // A failed write is a fatal error for everybody.
245
                let e = ShutdownError::Write(Arc::new(e));
246
                let mut state = self.receiver.state.lock().expect("poisoned");
247
                if state.fatal.is_none() {
248
                    state.fatal = Some(e.clone());
249
                    state.alert_everybody();
250
                }
251
                Err(e.into())
252
            }
253

            
254
4186
            Ok(()) => Ok(super::RequestHandle {
255
4186
                id,
256
4186
                conn: Mutex::new(Arc::clone(&self.receiver)),
257
4186
            }),
258
        }
259
4194
    }
260
}
261

            
262
impl Receiver {
263
    /// Wait until there is either a fatal error on this connection,
264
    /// _or_ there is a new message for the request with the provided `id`.
265
    /// Return that message, or a copy of the fatal error.
266
8230
    pub(super) fn wait_on_message_for(
267
8230
        &self,
268
8230
        id: &AnyRequestId,
269
8230
    ) -> Result<ValidatedResponse, ProtoError> {
270
8230
        // Here in wait_on_message_for_impl, we do the the actual work
271
8230
        // of waiting for the message.
272
8230
        let state = self.state.lock().expect("poisoned");
273
8230
        let (result, mut state, should_alert) = self.wait_on_message_for_impl(state, id);
274

            
275
        // Great; we have a message or a fatal error.  All we need to do now
276
        // is to restore our invariants before we drop state_lock.
277
        //
278
        // (It would be a bug to return early without restoring the invariants,
279
        // so we'll use an IEFE pattern to prevent "?" and "return Err".)
280
        #[allow(clippy::redundant_closure_call)]
281
8230
        (|| {
282
            // "final" in this case means that we are not expecting any more
283
            // replies for this request.
284
8230
            let is_final = match &result {
285
88
                Err(_) => true,
286
8142
                Ok(r) => r.is_final(),
287
            };
288

            
289
8230
            if is_final {
290
4186
                // Note 1: It might be cleaner to use Entry::remove(), but Entry is not
291
4186
                // exactly the right shape for us; see note in
292
4186
                // wait_on_message_for_impl.
293
4186

            
294
4186
                // Note 2: This remove isn't necessary if `result` is
295
4186
                // RequestCancelled, but it won't hurt.
296
4186

            
297
4186
                // Note 3: On DuplicateWait, it is not totally clear whether we should
298
4186
                // remove or not.  But that's an internal error that should never occur,
299
4186
                // so it is probably okay if we let the _other_ waiter keep on trying.
300
4186
                state.pending.remove(id);
301
4186
            }
302

            
303
8132
            match should_alert {
304
92
                AlertWhom::Nobody => {}
305
8132
                AlertWhom::Anybody if state.reader.is_none() => {}
306
1132
                AlertWhom::Anybody => state.alert_anybody(),
307
6
                AlertWhom::Everybody => state.alert_everybody(),
308
            }
309
8230
        })();
310
8230

            
311
8230
        result
312
8230
    }
313

            
314
    /// Helper to implement [`wait_on_message_for`](Self::wait_on_message_for).
315
    ///
316
    /// Takes a `MutexGuard` as one of its arguments, and returns an equivalent
317
    /// `MutexGuard` on completion.
318
    ///
319
    /// The caller is responsible for:
320
    ///
321
    /// - Removing the appropriate entry from `pending`, if the result
322
    ///   indicates that no more messages will be received for this request.
323
    /// - Possibly, notifying one or more condvars,
324
    ///   depending on the resulting `AlertWhom`.
325
    ///
326
    /// The caller must not drop the `MutexGuard` until it has done the above.
327
8230
    fn wait_on_message_for_impl<'a>(
328
8230
        &'a self,
329
8230
        mut state_lock: MutexGuard<'a, ReceiverState>,
330
8230
        id: &AnyRequestId,
331
8230
    ) -> (
332
8230
        Result<ValidatedResponse, ProtoError>,
333
8230
        MutexGuard<'a, ReceiverState>,
334
8230
        AlertWhom,
335
8230
    ) {
336
8230
        // At this point, we have not registered on a condvar, and we have not
337
8230
        // taken the reader.
338
8230
        // Therefore, we do not yet need to ensure that anybody else takes the reader.
339
8230
        //
340
8230
        // TODO: It is possibly too easy to forget to set this,
341
8230
        // or to set it to a less "alerty" value.  Refactoring might help;
342
8230
        // see discussion at
343
8230
        // https://gitlab.torproject.org/tpo/core/arti/-/merge_requests/2258#note_3047267
344
8230
        let mut should_alert = AlertWhom::Nobody;
345
8230

            
346
8230
        let mut state: &mut ReceiverState = &mut state_lock;
347

            
348
        // Initialize `this_ent` to our own entry in the pending table.
349
8230
        let Some(mut this_ent) = state.pending.get_mut(id) else {
350
            return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
351
        };
352

            
353
1058
        let mut reader = loop {
354
            // Note: It might be nice to use a hash_map::Entry here, but it
355
            // doesn't really work the way we want.  The `entry()` API is always
356
            // ready to insert, and requires that we clone `id`.  But what we
357
            // want in this case is something that would give us a .remove()able
358
            // Entry only if one is present.
359
16338
            if this_ent.waiter.is_some() {
360
                // This is an internal error; nobody should be able to cause this.
361
                return (Err(ProtoError::DuplicateWait), state_lock, should_alert);
362
16338
            }
363

            
364
16338
            if let Some(ready) = this_ent.pop_next_msg(&state.fatal) {
365
                // There is a reply for us, or a fatal error.
366
7172
                return (ready.map_err(ProtoError::from), state_lock, should_alert);
367
9166
            }
368
9166

            
369
9166
            // If we reach this point, we are about to either take the reader or
370
9166
            // register a cv.  This means that when we return, we need to make
371
9166
            // sure that at least one other cv gets notified.
372
9166
            should_alert = AlertWhom::Anybody;
373

            
374
9166
            if let Some(r) = state.reader.take() {
375
                // Nobody else is reading; we have to do it.
376
1058
                break r;
377
8108
            }
378
8108

            
379
8108
            // Somebody else is reading; register a condvar.
380
8108
            let cv = Arc::new(Condvar::new());
381
8108
            this_ent.waiter = Some(Arc::clone(&cv));
382
8108

            
383
8108
            state_lock = cv.wait(state_lock).expect("poisoned lock");
384
8108
            state = &mut state_lock;
385
            // Restore `this_ent`...
386
8108
            let Some(e) = state.pending.get_mut(id) else {
387
                return (Err(ProtoError::RequestCompleted), state_lock, should_alert);
388
            };
389
8108
            this_ent = e;
390
8108
            // ... And un-register our condvar.
391
8108
            this_ent.waiter = None;
392

            
393
            // We have been notified: either there is a reply or us,
394
            // or we are supposed to take the reader.  We'll find out on our
395
            // next time through the loop.
396
        };
397

            
398
1058
        let (result, mut state_lock, should_alert) =
399
1058
            self.read_until_message_for(state_lock, &mut reader, id);
400
1058
        // Put the reader back.
401
1058
        state_lock.reader = Some(reader);
402
1058

            
403
1058
        (result.map_err(ProtoError::from), state_lock, should_alert)
404
8230
    }
405

            
406
    /// Read messages, delivering them as appropriate, until we find one for `id`,
407
    /// or a fatal error occurs.
408
    ///
409
    /// Return that message or error, along with a `MutexGuard`.
410
    ///
411
    /// The caller is responsible for restoring the following state before
412
    /// dropping the `MutexGuard`:
413
    ///
414
    /// - Putting `reader` back into the `reader` field.
415
    /// - Other invariants as discussed in wait_on_message_for_impl.
416
1058
    fn read_until_message_for<'a>(
417
1058
        &'a self,
418
1058
        mut state_lock: MutexGuard<'a, ReceiverState>,
419
1058
        reader: &mut llconn::Reader,
420
1058
        id: &AnyRequestId,
421
1058
    ) -> (
422
1058
        Result<ValidatedResponse, ShutdownError>,
423
1058
        MutexGuard<'a, ReceiverState>,
424
1058
        AlertWhom,
425
1058
    ) {
426
        loop {
427
            // Importantly, we drop the state lock while we are reading.
428
            // This is okay, since all our invariants should hold at this point.
429
8148
            drop(state_lock);
430

            
431
8148
            let result: Result<ValidatedResponse, _> = match reader.read_msg() {
432
2
                Err(e) => Err(ShutdownError::Read(Arc::new(e))),
433
                Ok(None) => Err(ShutdownError::ConnectionClosed),
434
8146
                Ok(Some(m)) => m.try_validate().map_err(ShutdownError::from),
435
            };
436

            
437
8148
            state_lock = self.state.lock().expect("poisoned lock");
438
8148
            let state = &mut state_lock;
439

            
440
8142
            match result {
441
8142
                Ok(m) if m.id() == id => {
442
1052
                    // This only is for us, so there's no need to alert anybody
443
1052
                    // or queue it.
444
1052
                    return (Ok(m), state_lock, AlertWhom::Anybody);
445
                }
446
6
                Err(e) => {
447
6
                    // This is a fatal error on the whole connection.
448
6
                    //
449
6
                    // If it's the first one encountered, queue the error, and
450
6
                    // return it.
451
6
                    if state.fatal.is_none() {
452
6
                        state.fatal = Some(e.clone());
453
6
                    }
454
6
                    return (Err(e), state_lock, AlertWhom::Everybody);
455
                }
456
7090
                Ok(m) => {
457
                    // This is a message for exactly one ID, that isn't us.
458
                    // Queue it and notify them.
459
7090
                    if let Some(ent) = state.pending.get_mut(m.id()) {
460
7090
                        ent.queue.push_back(m);
461
7090
                        if let Some(cv) = &ent.waiter {
462
7038
                            cv.notify_one();
463
7038
                        }
464
                    } else {
465
                        // Nothing wanted this response any longer.
466
                        // _Probably_ this means that we decided to cancel the
467
                        // request but Arti sent this response before it handled
468
                        // our cancellation.
469
                    }
470
                }
471
            };
472
        }
473
1058
    }
474
}