1use futures::lock::Mutex;
7use futures::stream::StreamExt;
8use futures::task::SpawnExt;
9use hickory_proto::op::{
10 header::MessageType, op_code::OpCode, response_code::ResponseCode, Message, Query,
11};
12use hickory_proto::rr::{rdata, DNSClass, Name, RData, Record, RecordType};
13use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
14use std::collections::HashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::sync::Arc;
17use tracing::{debug, error, info, warn};
18
19use arti_client::{Error, HasKind, StreamPrefs, TorClient};
20use safelog::sensitive as sv;
21use tor_config::Listen;
22use tor_error::{error_report, warn_report};
23use tor_rtcompat::{Runtime, UdpSocket};
24
25use anyhow::{anyhow, Result};
26
/// Size, in bytes, of the buffer each incoming DNS datagram is received into
/// (one such buffer is created per `recv` in `run_dns_resolver`).
const MAX_DATAGRAM_SIZE: usize = 1536;
29
/// Isolation key attached to every DNS lookup we perform.
///
/// * `.0` is the index of the listener socket that received the request.
/// * `.1` is the IP address of the client that sent the request.
///
/// It is handed to `StreamPrefs::set_isolation`, and two keys are only
/// compatible with each other when they are equal (see the `IsolationHelper`
/// impl), so lookups from different listeners or different client addresses
/// are kept isolated from one another.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DnsIsolationKey(usize, IpAddr);
36
37impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
38 fn compatible_same_type(&self, other: &Self) -> bool {
39 self == other
40 }
41
42 fn join_same_type(&self, other: &Self) -> Option<Self> {
43 if self == other {
44 Some(self.clone())
45 } else {
46 None
47 }
48 }
49}
50
/// Identity of an in-flight DNS request: the requester's isolation key plus the
/// exact list of questions asked.
///
/// `handle_dns_req` uses it as the key of the table of pending requests, so that
/// identical requests arriving while one is already being resolved are answered
/// from that single lookup instead of triggering a new one.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DnsCacheKey(DnsIsolationKey, Vec<Query>);
54
/// A client waiting for the answer to a (possibly shared) DNS request.
#[derive(Debug, Clone)]
struct DnsResponseTarget<U> {
    /// Message ID of the client's request; copied into the reply sent to it.
    id: u16,
    /// Address the reply is sent to (the request's source address).
    addr: SocketAddr,
    /// Listener socket the request arrived on, used to send the reply.
    socket: Arc<U>,
}
65
66async fn do_query<R>(
68 tor_client: TorClient<R>,
69 queries: &[Query],
70 prefs: &StreamPrefs,
71) -> Result<Vec<Record>, ResponseCode>
72where
73 R: Runtime,
74{
75 let mut answers = Vec::new();
76
77 let err_conv = |error: Error| {
78 if tor_error::ErrorKind::RemoteHostNotFound == error.kind() {
79 ResponseCode::NoError
81 } else {
82 ResponseCode::ServFail
83 }
84 };
85 for query in queries {
86 let mut a = Vec::new();
87 let mut ptr = Vec::new();
88
89 match query.query_class() {
92 DNSClass::IN => {
93 match query.query_type() {
94 typ @ RecordType::A | typ @ RecordType::AAAA => {
95 let mut name = query.name().clone();
96 name.set_fqdn(false);
98 let res = tor_client
99 .resolve_with_prefs(&name.to_utf8(), prefs)
100 .await
101 .map_err(err_conv)?;
102 for ip in res {
103 a.push((query.name().clone(), ip, typ));
104 }
105 }
106 RecordType::PTR => {
107 let addr = query
108 .name()
109 .parse_arpa_name()
110 .map_err(|_| ResponseCode::FormErr)?
111 .addr();
112 let res = tor_client
113 .resolve_ptr_with_prefs(addr, prefs)
114 .await
115 .map_err(err_conv)?;
116 for domain in res {
117 let domain =
118 Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?;
119 ptr.push((query.name().clone(), domain));
120 }
121 }
122 _ => {
123 return Err(ResponseCode::NotImp);
124 }
125 }
126 }
127 _ => {
128 return Err(ResponseCode::NotImp);
129 }
130 }
131 for (name, ip, typ) in a {
132 match (ip, typ) {
133 (IpAddr::V4(v4), RecordType::A) => {
134 answers.push(Record::from_rdata(name, 3600, RData::A(rdata::A(v4))));
135 }
136 (IpAddr::V6(v6), RecordType::AAAA) => {
137 answers.push(Record::from_rdata(name, 3600, RData::AAAA(rdata::AAAA(v6))));
138 }
139 _ => (),
140 }
141 }
142 for (ptr, name) in ptr {
143 answers.push(Record::from_rdata(ptr, 3600, RData::PTR(rdata::PTR(name))));
144 }
145 }
146
147 Ok(answers)
148}
149
150async fn handle_dns_req<R, U>(
153 tor_client: TorClient<R>,
154 socket_id: usize,
155 packet: &[u8],
156 addr: SocketAddr,
157 socket: Arc<U>,
158 current_requests: &Mutex<HashMap<DnsCacheKey, Vec<DnsResponseTarget<U>>>>,
159) -> Result<()>
160where
161 R: Runtime,
162 U: UdpSocket,
163{
164 let mut query = Message::from_bytes(packet)?;
166 let id = query.id();
167 let queries = query.queries();
168 let isolation = DnsIsolationKey(socket_id, addr.ip());
169
170 let request_id = {
171 let request_id = DnsCacheKey(isolation.clone(), queries.to_vec());
172
173 let response_target = DnsResponseTarget { id, addr, socket };
174
175 let mut current_requests = current_requests.lock().await;
176
177 let req = current_requests.entry(request_id.clone()).or_default();
178 req.push(response_target);
179
180 if req.len() > 1 {
181 debug!("Received a query already being served");
182 return Ok(());
183 }
184 debug!("Received a new query");
185
186 request_id
187 };
188
189 let mut prefs = StreamPrefs::new();
190 prefs.set_isolation(isolation);
191
192 let mut response = match do_query(tor_client, queries, &prefs).await {
193 Ok(answers) => {
194 let mut response = Message::new();
195 response
196 .set_message_type(MessageType::Response)
197 .set_op_code(OpCode::Query)
198 .set_recursion_desired(query.recursion_desired())
199 .set_recursion_available(true)
200 .add_queries(query.take_queries())
201 .add_answers(answers);
202 response
204 }
205 Err(error_type) => Message::error_msg(id, OpCode::Query, error_type),
206 };
207
208 let targets = current_requests
210 .lock()
211 .await
212 .remove(&request_id)
213 .unwrap_or_default();
214
215 for target in targets {
216 response.set_id(target.id);
217 let response = match response.to_bytes() {
219 Ok(r) => r,
220 Err(e) => {
221 error_report!(e, "Failed to serialize DNS packet: {:?}", sv(&response));
226 continue;
227 }
228 };
229 let _ = target.socket.send(&response, &target.addr).await;
230 }
231 Ok(())
232}
233
234#[cfg_attr(feature = "experimental-api", visibility::make(pub))]
236pub(crate) async fn run_dns_resolver<R: Runtime>(
237 runtime: R,
238 tor_client: TorClient<R>,
239 listen: Listen,
240) -> Result<()> {
241 if !listen.is_localhost_only() {
242 warn!("Configured to listen for DNS on non-local addresses. This is usually insecure! We recommend listening on localhost only.");
243 }
244
245 let mut listeners = Vec::new();
246
247 match listen.ip_addrs() {
249 Ok(addrgroups) => {
250 for addrgroup in addrgroups {
251 for addr in addrgroup {
252 match runtime.bind(&addr).await {
255 Ok(listener) => {
256 info!("Listening on {:?}.", addr);
257 listeners.push(listener);
258 }
259 #[cfg(unix)]
260 Err(ref e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => {
261 warn_report!(e, "Address family not supported {}", addr);
262 }
263 Err(ref e) => {
264 return Err(anyhow!("Can't listen on {}: {e}", addr));
265 }
266 }
267 }
268 }
269 }
270 Err(e) => warn_report!(e, "Invalid listen spec"),
271 }
272 if listeners.is_empty() {
274 error!("Couldn't open any DNS listeners.");
275 return Err(anyhow!("Couldn't open any DNS listeners"));
276 }
277
278 let mut incoming = futures::stream::select_all(
279 listeners
280 .into_iter()
281 .map(|socket| {
282 futures::stream::unfold(Arc::new(socket), |socket| async {
283 let mut packet = [0; MAX_DATAGRAM_SIZE];
284 let packet = socket
285 .recv(&mut packet)
286 .await
287 .map(|(size, remote)| (packet, size, remote, socket.clone()));
288 Some((packet, socket))
289 })
290 })
291 .enumerate()
292 .map(|(listener_id, incoming_packet)| {
293 Box::pin(incoming_packet.map(move |packet| (packet, listener_id)))
294 }),
295 );
296
297 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
298 while let Some((packet, id)) = incoming.next().await {
299 let (packet, size, addr, socket) = match packet {
300 Ok(packet) => packet,
301 Err(err) => {
302 warn_report!(err, "Incoming datagram failed");
304 continue;
305 }
306 };
307
308 let client_ref = tor_client.clone();
309 runtime.spawn({
310 let pending_requests = pending_requests.clone();
311 async move {
312 let res = handle_dns_req(
313 client_ref,
314 id,
315 &packet[..size],
316 addr,
317 socket,
318 &pending_requests,
319 )
320 .await;
321 if let Err(e) = res {
322 warn!("connection exited with error: {}", tor_error::Report(e));
324 }
325 }
326 })?;
327 }
328
329 Ok(())
330}