1
//! Helper types for framing JSON objects over async readers and writers.
2

            
3
use std::marker::PhantomData;
4

            
5
use asynchronous_codec::JsonCodec;
6
use bytes::BytesMut;
7
use serde::Serialize;
8

            
9
use crate::msgs::BoxedResponse;
10
use crate::msgs::FlexibleRequest;
11

            
12
/// A stream of [`Request`](crate::msgs::Request)
13
/// taken from `T` (an `AsyncRead`) and deserialized from Json.
14
#[allow(dead_code)] // TODO RPC
15
pub(crate) type RequestStream<T> =
16
    asynchronous_codec::FramedRead<T, JsonCodec<(), FlexibleRequest>>;
17

            
18
/// As [`JsonCodec`], but only supports encoding, and places a newline after
/// every object.
///
/// Each encoded item therefore occupies exactly one `\n`-terminated line
/// ("JSON Lines" framing).
pub(crate) struct JsonLinesEncoder<T> {
    /// We consume objects of type T.
    ///
    /// `fn(T)` is used rather than `T` so that the encoder does not behave as
    /// if it owned a `T` (for auto traits and drop checking), whatever `T` is.
    _phantom: PhantomData<fn(T)>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound, even
// though the encoder never stores a `T`, making e.g.
// `JsonLinesEncoder<BoxedResponse>` un-cloneable for no reason.
impl<T> Clone for JsonLinesEncoder<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}
25

            
26
impl<T> Default for JsonLinesEncoder<T> {
27
2
    fn default() -> Self {
28
2
        Self {
29
2
            _phantom: PhantomData,
30
2
        }
31
2
    }
32
}
33

            
34
impl<T> asynchronous_codec::Encoder for JsonLinesEncoder<T>
35
where
36
    T: Serialize + 'static,
37
{
38
    type Item<'a> = T;
39

            
40
    type Error = asynchronous_codec::JsonCodecError;
41

            
42
6
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
43
        use std::fmt::Write as _;
44
6
        let j = serde_json::to_string(&item)?;
45
        // The jsonlines format won't work if serde_json starts adding newlines in the middle.
46
6
        debug_assert!(!j.contains('\n'));
47
6
        writeln!(dst, "{}", j).expect("write! of string on BytesMut failed");
48
6
        Ok(())
49
6
    }
50
}
51

            
52
/// A stream of [`BoxedResponse`] serialized as newline-terminated json objects
53
/// onto an `AsyncWrite.`
54
#[allow(dead_code)] // TODO RPC
55
pub(crate) type ResponseSink<T> =
56
    asynchronous_codec::FramedWrite<T, JsonLinesEncoder<BoxedResponse>>;
57

            
58
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::msgs::*;
    use futures::sink::SinkExt as _;
    use futures_await_test::async_test;
    use tor_rpcbase as rpc;

    /// A trivially serializable payload for building test responses.
    #[derive(serde::Serialize)]
    struct Empty {}

    #[async_test]
    async fn check_sink_basics() {
        // Sanity-checking for our sink type.
        let mut buf = Vec::new();
        let r1 = BoxedResponse {
            id: Some(RequestId::Int(7)),
            body: ResponseBody::Update(Box::new(Empty {})),
        };
        let r2 = BoxedResponse {
            id: Some(RequestId::Int(8)),
            body: ResponseBody::Error(Box::new(rpc::RpcError::from(
                crate::connection::RequestCancelled,
            ))),
        };
        let r3 = BoxedResponse {
            id: Some(RequestId::Int(9)),
            body: ResponseBody::Success(Box::new(Empty {})),
        };

        // These should get serialized as follows: one JSON object per line.
        // `unwrap` so that a serialization failure is reported here, rather
        // than silently contributing nothing to the expected output.
        let expect: String = [&r1, &r2, &r3]
            .iter()
            .map(|r| serde_json::to_string(r).unwrap() + "\n")
            .collect();

        {
            let mut sink = ResponseSink::new(&mut buf, JsonLinesEncoder::default());
            sink.send(r1).await.unwrap();
            sink.send(r2).await.unwrap();
            sink.send(r3).await.unwrap();
        }
        // Exactly 3 messages means exactly 3 newlines.
        assert_eq!(buf.iter().filter(|c| **c == b'\n').count(), 3);
        // Make sure that the output is what we expected.
        assert_eq!(std::str::from_utf8(&buf).unwrap(), &expect);
    }
}