1use futures::lock::Mutex;
7use futures::stream::StreamExt;
8use hickory_proto::op::{
9 Message, Query, header::MessageType, op_code::OpCode, response_code::ResponseCode,
10};
11use hickory_proto::rr::{DNSClass, Name, RData, Record, RecordType, rdata};
12use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
13use std::collections::HashMap;
14use std::net::{IpAddr, SocketAddr};
15use std::sync::Arc;
16use tor_rtcompat::SpawnExt;
17use tracing::{debug, error, info, warn};
18
19use arti_client::{Error, HasKind, StreamPrefs, TorClient};
20use safelog::sensitive as sv;
21use tor_config::Listen;
22use tor_error::{error_report, warn_report};
23use tor_rtcompat::{Runtime, UdpSocket};
24
25use anyhow::{Result, anyhow};
26
27use crate::proxy::port_info;
28
/// Size, in bytes, of the buffer each incoming UDP datagram is received into.
const MAX_DATAGRAM_SIZE: usize = 1_536;
31
/// Stream-isolation key attached to DNS lookups.
///
/// Holds the index of the listener a request arrived on and the IP address of the
/// client that sent it; keys compare equal only when both parts match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct DnsIsolationKey(usize, IpAddr);
38
39impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
40 fn compatible_same_type(&self, other: &Self) -> bool {
41 self == other
42 }
43
44 fn join_same_type(&self, other: &Self) -> Option<Self> {
45 if self == other {
46 Some(self.clone())
47 } else {
48 None
49 }
50 }
51
52 fn enables_long_lived_circuits(&self) -> bool {
53 false
54 }
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Hash)]
59struct DnsCacheKey(DnsIsolationKey, Vec<Query>);
60
/// Everything needed to send a reply to one requester.
#[derive(Clone, Debug)]
struct DnsResponseTarget<U> {
    /// Message id of the request, copied into the reply.
    id: u16,
    /// Address the request came from, and where the reply is sent.
    addr: SocketAddr,
    /// Socket the request arrived on, used to send the reply.
    socket: Arc<U>,
}
71
72async fn do_query<R>(
74 tor_client: TorClient<R>,
75 queries: &[Query],
76 prefs: &StreamPrefs,
77) -> Result<Vec<Record>, ResponseCode>
78where
79 R: Runtime,
80{
81 let mut answers = Vec::new();
82
83 let err_conv = |error: Error| {
84 if tor_error::ErrorKind::RemoteHostNotFound == error.kind() {
85 ResponseCode::NoError
87 } else {
88 ResponseCode::ServFail
89 }
90 };
91 for query in queries {
92 let mut a = Vec::new();
93 let mut ptr = Vec::new();
94
95 match query.query_class() {
98 DNSClass::IN => {
99 match query.query_type() {
100 typ @ RecordType::A | typ @ RecordType::AAAA => {
101 let mut name = query.name().clone();
102 name.set_fqdn(false);
104 let res = tor_client
105 .resolve_with_prefs(&name.to_utf8(), prefs)
106 .await
107 .map_err(err_conv)?;
108 for ip in res {
109 a.push((query.name().clone(), ip, typ));
110 }
111 }
112 RecordType::PTR => {
113 let addr = query
114 .name()
115 .parse_arpa_name()
116 .map_err(|_| ResponseCode::FormErr)?
117 .addr();
118 let res = tor_client
119 .resolve_ptr_with_prefs(addr, prefs)
120 .await
121 .map_err(err_conv)?;
122 for domain in res {
123 let domain =
124 Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?;
125 ptr.push((query.name().clone(), domain));
126 }
127 }
128 _ => {
129 return Err(ResponseCode::NotImp);
130 }
131 }
132 }
133 _ => {
134 return Err(ResponseCode::NotImp);
135 }
136 }
137 for (name, ip, typ) in a {
138 match (ip, typ) {
139 (IpAddr::V4(v4), RecordType::A) => {
140 answers.push(Record::from_rdata(name, 3600, RData::A(rdata::A(v4))));
141 }
142 (IpAddr::V6(v6), RecordType::AAAA) => {
143 answers.push(Record::from_rdata(name, 3600, RData::AAAA(rdata::AAAA(v6))));
144 }
145 _ => (),
146 }
147 }
148 for (ptr, name) in ptr {
149 answers.push(Record::from_rdata(ptr, 3600, RData::PTR(rdata::PTR(name))));
150 }
151 }
152
153 Ok(answers)
154}
155
156#[allow(clippy::cognitive_complexity)] async fn handle_dns_req<R, U>(
160 tor_client: TorClient<R>,
161 socket_id: usize,
162 packet: &[u8],
163 addr: SocketAddr,
164 socket: Arc<U>,
165 current_requests: &Mutex<HashMap<DnsCacheKey, Vec<DnsResponseTarget<U>>>>,
166) -> Result<()>
167where
168 R: Runtime,
169 U: UdpSocket,
170{
171 let mut query = Message::from_bytes(packet)?;
173 let id = query.id();
174 let queries = query.queries();
175 let isolation = DnsIsolationKey(socket_id, addr.ip());
176
177 let request_id = {
178 let request_id = DnsCacheKey(isolation.clone(), queries.to_vec());
179
180 let response_target = DnsResponseTarget { id, addr, socket };
181
182 let mut current_requests = current_requests.lock().await;
183
184 let req = current_requests.entry(request_id.clone()).or_default();
185 req.push(response_target);
186
187 if req.len() > 1 {
188 debug!("Received a query already being served");
189 return Ok(());
190 }
191 debug!("Received a new query");
192
193 request_id
194 };
195
196 let mut prefs = StreamPrefs::new();
197 prefs.set_isolation(isolation);
198
199 let mut response = match do_query(tor_client, queries, &prefs).await {
200 Ok(answers) => {
201 let mut response = Message::new();
202 response
203 .set_message_type(MessageType::Response)
204 .set_op_code(OpCode::Query)
205 .set_recursion_desired(query.recursion_desired())
206 .set_recursion_available(true)
207 .add_queries(query.take_queries())
208 .add_answers(answers);
209 response
211 }
212 Err(error_type) => Message::error_msg(id, OpCode::Query, error_type),
213 };
214
215 let targets = current_requests
217 .lock()
218 .await
219 .remove(&request_id)
220 .unwrap_or_default();
221
222 for target in targets {
223 response.set_id(target.id);
224 let response = match response.to_bytes() {
226 Ok(r) => r,
227 Err(e) => {
228 error_report!(e, "Failed to serialize DNS packet: {:?}", sv(&response));
233 continue;
234 }
235 };
236 let _ = target.socket.send(&response, &target.addr).await;
237 }
238 Ok(())
239}
240
241#[cfg_attr(feature = "experimental-api", visibility::make(pub))]
245#[allow(clippy::cognitive_complexity)] pub(crate) async fn launch_dns_resolver<R: Runtime>(
247 runtime: R,
248 tor_client: TorClient<R>,
249 listen: Listen,
250) -> Result<(impl Future<Output = Result<()>>, Vec<port_info::Port>)> {
251 if !listen.is_loopback_only() {
252 warn!(
253 "Configured to listen for DNS on non-local addresses. This is usually insecure! We recommend listening on localhost only."
254 );
255 }
256
257 let mut listeners = Vec::new();
258 let mut listening_on = Vec::new();
259
260 match listen.ip_addrs() {
262 Ok(addrgroups) => {
263 for addrgroup in addrgroups {
264 for addr in addrgroup {
265 match runtime.bind(&addr).await {
268 Ok(listener) => {
269 let bound_addr = listener.local_addr()?;
270 info!("Listening on {:?}.", bound_addr);
271 listeners.push(listener);
272 listening_on.push(bound_addr);
273 }
274 #[cfg(unix)]
275 Err(ref e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => {
276 warn_report!(e, "Address family not supported {}", addr);
277 }
278 Err(ref e) => {
279 return Err(anyhow!("Can't listen on {}: {e}", addr));
280 }
281 }
282 }
283 }
285 }
286 Err(e) => warn_report!(e, "Invalid listen spec"),
287 }
288 if listeners.is_empty() {
290 error!("Couldn't open any DNS listeners.");
291 return Err(anyhow!("Couldn't open any DNS listeners"));
292 }
293
294 let ports = listening_on
295 .iter()
296 .map(|sockaddr| port_info::Port {
297 protocol: port_info::SupportedProtocol::DnsUdp,
298 address: (*sockaddr).into(),
299 })
300 .collect();
301
302 Ok((
303 run_dns_resolver_with_listeners(runtime, tor_client, listeners),
304 ports,
305 ))
306}
307
308async fn run_dns_resolver_with_listeners<R: Runtime>(
310 runtime: R,
311 tor_client: TorClient<R>,
312 listeners: Vec<<R as tor_rtcompat::UdpProvider>::UdpSocket>,
313) -> Result<()> {
314 let mut incoming = futures::stream::select_all(
315 listeners
316 .into_iter()
317 .map(|socket| {
318 futures::stream::unfold(Arc::new(socket), |socket| async {
319 let mut packet = [0; MAX_DATAGRAM_SIZE];
320 let packet = socket
321 .recv(&mut packet)
322 .await
323 .map(|(size, remote)| (packet, size, remote, socket.clone()));
324 Some((packet, socket))
325 })
326 })
327 .enumerate()
328 .map(|(listener_id, incoming_packet)| {
329 Box::pin(incoming_packet.map(move |packet| (packet, listener_id)))
330 }),
331 );
332
333 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
334 while let Some((packet, id)) = incoming.next().await {
335 let (packet, size, addr, socket) = match packet {
336 Ok(packet) => packet,
337 Err(err) => {
338 warn_report!(err, "Incoming datagram failed");
340 continue;
341 }
342 };
343
344 let client_ref = tor_client.clone();
345 runtime.spawn({
346 let pending_requests = pending_requests.clone();
347 async move {
348 let res = handle_dns_req(
349 client_ref,
350 id,
351 &packet[..size],
352 addr,
353 socket,
354 &pending_requests,
355 )
356 .await;
357 if let Err(e) = res {
358 warn!("connection exited with error: {}", tor_error::Report(e));
360 }
361 }
362 })?;
363 }
364
365 Ok(())
366}