arti_testing/rt/
count.rs
1use futures::Stream;
4use tor_rtcompat::{NetStreamListener, NetStreamProvider, StreamOps};
5
6use async_trait::async_trait;
7use futures::io::{AsyncRead, AsyncWrite};
8use pin_project::pin_project;
9use std::io::Result as IoResult;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14
/// A snapshot of the connection and traffic counters kept by a [`Counting`] wrapper.
///
/// All counters start at zero (see the derived `Default`).  `PartialEq`/`Eq`
/// are derived so that two snapshots can be compared directly, e.g. with
/// `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct TcpCount {
    /// Number of times `connect` was called, whether or not it succeeded.
    pub(crate) n_connect_attempt: usize,
    /// Number of `connect` calls that completed successfully.
    pub(crate) n_connect_ok: usize,
    /// Number of incoming connections yielded successfully by a listener stream.
    pub(crate) n_accept: usize,
    /// Total bytes reported as written by successful `poll_write` calls.
    pub(crate) n_bytes_send: usize,
    /// Total bytes reported as read by successful `poll_read` calls.
    pub(crate) n_bytes_recv: usize,
}
29
30#[pin_project]
35pub(crate) struct Counting<R> {
36 #[pin]
38 inner: R,
39 count: Arc<Mutex<TcpCount>>,
41}
42
43impl<R> Clone for Counting<R>
44where
45 R: Clone,
46{
47 fn clone(&self) -> Self {
48 Self {
50 inner: self.inner.clone(),
51 count: self.count.clone(),
52 }
53 }
54}
55
56impl<R> Counting<R> {
57 pub(crate) fn new_zeroed(inner: R) -> Self
59 where
60 R: NetStreamProvider,
61 {
62 Self {
63 inner,
64 count: Default::default(),
65 }
66 }
67
68 pub(crate) fn counts(&self) -> TcpCount {
70 self.count.lock().expect("lock poisoned").clone()
71 }
72}
73
74#[async_trait]
75impl<R: NetStreamProvider + Send + Sync> NetStreamProvider for Counting<R> {
76 type Stream = Counting<R::Stream>;
77
78 type Listener = Counting<R::Listener>;
79
80 async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
81 {
82 self.count.lock().expect("lock poisoned").n_connect_attempt += 1;
83 }
84
85 let inner = self.inner.connect(addr).await?;
86
87 {
88 self.count.lock().expect("lock poisoned").n_connect_ok += 1;
89 }
90
91 Ok(Counting {
92 inner,
93 count: self.count.clone(),
94 })
95 }
96
97 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
98 let inner = self.inner.listen(addr).await?;
99 Ok(Counting {
100 inner,
101 count: self.count.clone(),
102 })
103 }
104}
105
106impl<S: AsyncRead> AsyncRead for Counting<S> {
107 fn poll_read(
108 self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 buf: &mut [u8],
111 ) -> Poll<IoResult<usize>> {
112 let this = self.project();
113 let outcome = this.inner.poll_read(cx, buf);
114
115 if let Poll::Ready(Ok(n)) = outcome {
116 this.count.lock().expect("poisoned lock").n_bytes_recv += n;
117 }
118 outcome
119 }
120}
121
122impl<S: AsyncWrite> AsyncWrite for Counting<S> {
123 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
124 let this = self.project();
125 let outcome = this.inner.poll_write(cx, buf);
126
127 if let Poll::Ready(Ok(n)) = outcome {
128 this.count.lock().expect("poisoned lock").n_bytes_send += n;
129 }
130 outcome
131 }
132 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
133 self.project().inner.poll_flush(cx)
134 }
135 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
136 self.project().inner.poll_close(cx)
137 }
138}
139
140impl<S: StreamOps> StreamOps for Counting<S> {
141 fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
142 self.inner.set_tcp_notsent_lowat(notsent_lowat)
143 }
144
145 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
146 self.inner.new_handle()
147 }
148}
149
150impl<S: NetStreamListener + Send + Sync> NetStreamListener for Counting<S> {
151 type Stream = Counting<S::Stream>;
152 type Incoming = Counting<S::Incoming>;
153
154 fn incoming(self) -> Self::Incoming {
155 Counting {
156 inner: self.inner.incoming(),
157 count: self.count,
158 }
159 }
160
161 fn local_addr(&self) -> IoResult<SocketAddr> {
162 self.inner.local_addr()
163 }
164}
165
166impl<S, T> Stream for Counting<S>
167where
168 S: Stream<Item = IoResult<(T, SocketAddr)>>,
169{
170 type Item = IoResult<(Counting<T>, SocketAddr)>;
171
172 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
173 let this = self.project();
174 let outcome = this.inner.poll_next(cx);
175
176 match outcome {
177 Poll::Ready(Some(Ok((inner, addr)))) => {
178 {
179 this.count.lock().expect("lock poisoned").n_accept += 1;
180 }
181 Poll::Ready(Some(Ok((
182 Counting {
183 inner,
184 count: this.count.clone(),
185 },
186 addr,
187 ))))
188 }
189 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
190 Poll::Ready(None) => Poll::Ready(None),
191 Poll::Pending => Poll::Pending,
192 }
193 }
194}