arti_testing/rt/
count.rs

1//! Support for counting various TCP stats for a Runtime.
2
3use futures::Stream;
4use tor_rtcompat::{NetStreamListener, NetStreamProvider, StreamOps};
5
6use async_trait::async_trait;
7use futures::io::{AsyncRead, AsyncWrite};
8use pin_project::pin_project;
9use std::io::Result as IoResult;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14
/// Running TCP statistics recorded by a [`Counting`] wrapper.
///
/// Every field only ever increases, so a value is the cumulative total since
/// the counters were created. `PartialEq`/`Eq` are derived so that snapshots
/// returned by `Counting::counts` can be compared directly with `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct TcpCount {
    /// Number of outbound TCP connections we have started to open.
    pub(crate) n_connect_attempt: usize,
    /// Number of outbound TCP connections that were opened successfully.
    pub(crate) n_connect_ok: usize,
    /// Number of inbound TCP connections we have accepted.
    pub(crate) n_accept: usize,
    /// Total number of bytes successfully written to our streams.
    pub(crate) n_bytes_send: usize,
    /// Total number of bytes successfully read from our streams.
    pub(crate) n_bytes_recv: usize,
}
29
30/// A "Counting" wrapper around various objects, keeping running counts of TCP
31/// events.
32///
33/// This can wrap most Tcp-related Runtime types.
34#[pin_project]
35pub(crate) struct Counting<R> {
36    /// The inner object that we're instrumenting
37    #[pin]
38    inner: R,
39    /// A shared mutable set of counts.
40    count: Arc<Mutex<TcpCount>>,
41}
42
43impl<R> Clone for Counting<R>
44where
45    R: Clone,
46{
47    fn clone(&self) -> Self {
48        // TODO: Use educe instead.
49        Self {
50            inner: self.inner.clone(),
51            count: self.count.clone(),
52        }
53    }
54}
55
56impl<R> Counting<R> {
57    /// Return a new wrapper around a NetStreamProvider with a new set of statistics
58    pub(crate) fn new_zeroed(inner: R) -> Self
59    where
60        R: NetStreamProvider,
61    {
62        Self {
63            inner,
64            count: Default::default(),
65        }
66    }
67
68    /// Return a copy of our current statistics
69    pub(crate) fn counts(&self) -> TcpCount {
70        self.count.lock().expect("lock poisoned").clone()
71    }
72}
73
74#[async_trait]
75impl<R: NetStreamProvider + Send + Sync> NetStreamProvider for Counting<R> {
76    type Stream = Counting<R::Stream>;
77
78    type Listener = Counting<R::Listener>;
79
80    async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
81        {
82            self.count.lock().expect("lock poisoned").n_connect_attempt += 1;
83        }
84
85        let inner = self.inner.connect(addr).await?;
86
87        {
88            self.count.lock().expect("lock poisoned").n_connect_ok += 1;
89        }
90
91        Ok(Counting {
92            inner,
93            count: self.count.clone(),
94        })
95    }
96
97    async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
98        let inner = self.inner.listen(addr).await?;
99        Ok(Counting {
100            inner,
101            count: self.count.clone(),
102        })
103    }
104}
105
106impl<S: AsyncRead> AsyncRead for Counting<S> {
107    fn poll_read(
108        self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        buf: &mut [u8],
111    ) -> Poll<IoResult<usize>> {
112        let this = self.project();
113        let outcome = this.inner.poll_read(cx, buf);
114
115        if let Poll::Ready(Ok(n)) = outcome {
116            this.count.lock().expect("poisoned lock").n_bytes_recv += n;
117        }
118        outcome
119    }
120}
121
122impl<S: AsyncWrite> AsyncWrite for Counting<S> {
123    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
124        let this = self.project();
125        let outcome = this.inner.poll_write(cx, buf);
126
127        if let Poll::Ready(Ok(n)) = outcome {
128            this.count.lock().expect("poisoned lock").n_bytes_send += n;
129        }
130        outcome
131    }
132    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
133        self.project().inner.poll_flush(cx)
134    }
135    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
136        self.project().inner.poll_close(cx)
137    }
138}
139
140impl<S: StreamOps> StreamOps for Counting<S> {
141    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
142        self.inner.set_tcp_notsent_lowat(notsent_lowat)
143    }
144
145    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
146        self.inner.new_handle()
147    }
148}
149
150impl<S: NetStreamListener + Send + Sync> NetStreamListener for Counting<S> {
151    type Stream = Counting<S::Stream>;
152    type Incoming = Counting<S::Incoming>;
153
154    fn incoming(self) -> Self::Incoming {
155        Counting {
156            inner: self.inner.incoming(),
157            count: self.count,
158        }
159    }
160
161    fn local_addr(&self) -> IoResult<SocketAddr> {
162        self.inner.local_addr()
163    }
164}
165
166impl<S, T> Stream for Counting<S>
167where
168    S: Stream<Item = IoResult<(T, SocketAddr)>>,
169{
170    type Item = IoResult<(Counting<T>, SocketAddr)>;
171
172    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
173        let this = self.project();
174        let outcome = this.inner.poll_next(cx);
175
176        match outcome {
177            Poll::Ready(Some(Ok((inner, addr)))) => {
178                {
179                    this.count.lock().expect("lock poisoned").n_accept += 1;
180                }
181                Poll::Ready(Some(Ok((
182                    Counting {
183                        inner,
184                        count: this.count.clone(),
185                    },
186                    addr,
187                ))))
188            }
189            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
190            Poll::Ready(None) => Poll::Ready(None),
191            Poll::Pending => Poll::Pending,
192        }
193    }
194}