1use crate::{CfgPath, CfgPathError};
4use serde::{Deserialize, Serialize};
5use std::{io, net, path::PathBuf, str::FromStr, sync::Arc};
6use tor_general_addr::{general, unix};
7
8#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
20#[serde(into = "CfgAddrSerde", try_from = "CfgAddrSerde")]
21pub struct CfgAddr(AddrInner);
22
23#[derive(Clone, Debug, Eq, PartialEq)]
27enum AddrInner {
28 Inet(net::SocketAddr),
30 Unix(CfgPath),
32}
33
34impl CfgAddr {
35 pub fn new_unix(path: CfgPath) -> Self {
41 CfgAddr(AddrInner::Unix(path))
42 }
43
44 pub fn address(
46 &self,
47 path_resolver: &crate::CfgPathResolver,
48 ) -> Result<general::SocketAddr, CfgAddrError> {
49 match &self.0 {
50 AddrInner::Inet(socket_addr) => {
51 Ok((*socket_addr).into())
53 }
54 AddrInner::Unix(cfg_path) => {
55 #[cfg(not(unix))]
56 {
57 return Err(unix::NoAfUnixSocketSupport::default().into());
59 }
60 #[cfg(unix)]
61 {
62 let addr = unix::SocketAddr::from_pathname(cfg_path.path(path_resolver)?)
63 .map_err(|e| CfgAddrError::ConstructAfUnixAddress(Arc::new(e)))?;
64 Ok(addr.into())
65 }
66 }
67 }
68 }
69
70 pub fn substitutions_will_apply(&self) -> bool {
74 match &self.0 {
75 AddrInner::Inet(_) => false,
76 AddrInner::Unix(_) => true,
77 }
78 }
79
80 fn try_to_string(&self) -> Result<String, &PathBuf> {
88 use crate::PathInner as PI;
89 use AddrInner as AI;
90 match &self.0 {
91 AI::Inet(socket_addr) => Ok(format!("inet:{}", socket_addr)),
92 AI::Unix(cfg_path) => match &cfg_path.0 {
93 PI::Shell(s) => Ok(format!("unix:{}", s)),
94 PI::Literal(path) => match path.literal.to_str() {
95 Some(literal_as_str) => Ok(format!("unix-literal:{}", literal_as_str)),
96 None => Err(&path.literal),
97 },
98 },
99 }
100 }
101}
102
103#[derive(Clone, Debug, thiserror::Error)]
105#[non_exhaustive]
106pub enum CfgAddrError {
107 #[error("No support for AF_UNIX addresses on this platform")]
109 NoAfUnixSocketSupport(#[from] unix::NoAfUnixSocketSupport),
110 #[error("Could not expand path")]
112 Path(#[from] CfgPathError),
113 #[error("Could not construct AF_UNIX address")]
117 ConstructAfUnixAddress(#[source] Arc<io::Error>),
118}
119
120impl FromStr for CfgAddr {
121 type Err = general::AddrParseError;
122
123 fn from_str(s: &str) -> Result<Self, Self::Err> {
124 if s.starts_with(|c: char| (c.is_ascii_digit() || c == '[')) {
127 Ok(s.parse::<net::SocketAddr>()?.into())
129 } else if let Some((schema, remainder)) = s.split_once(':') {
130 match schema {
131 "unix" => {
132 let path = CfgPath::new(remainder.to_string());
133 Ok(CfgAddr::new_unix(path))
134 }
135 "unix-literal" => {
136 let path = CfgPath::new_literal(remainder.to_string());
137 Ok(CfgAddr::new_unix(path))
138 }
139 "inet" => Ok(remainder.parse::<net::SocketAddr>()?.into()),
140 _ => Err(general::AddrParseError::UnrecognizedSchema(
141 schema.to_string(),
142 )),
143 }
144 } else {
145 Err(general::AddrParseError::NoSchema)
146 }
147 }
148}
149
150impl From<net::SocketAddr> for CfgAddr {
151 fn from(value: net::SocketAddr) -> Self {
152 CfgAddr(AddrInner::Inet(value))
153 }
154}
155impl TryFrom<unix::SocketAddr> for CfgAddr {
156 type Error = UnixAddrNotAPath;
157
158 fn try_from(value: unix::SocketAddr) -> Result<Self, Self::Error> {
159 Ok(Self::new_unix(CfgPath::new_literal(
162 value.as_pathname().ok_or(UnixAddrNotAPath)?,
163 )))
164 }
165}
166#[derive(Clone, Debug, Default, thiserror::Error)]
172#[non_exhaustive]
173#[error("Unix domain socket address was not a path.")]
174pub struct UnixAddrNotAPath;
175
/// Serde representation of a [`CfgAddr`].
///
/// Because this enum is `untagged`, deserialization tries the variants in the
/// order they are declared.  It first tries a plain string, which is parsed
/// with `CfgAddr`'s `FromStr`.  It then tries a map of the form
/// `{ "unix_literal": PATH }`.  Do not reorder the variants.
///
/// When serializing, `Str` is used whenever the address can be rendered as a
/// string.  `UnixLiteral` is used only for a literal unix path that is not
/// valid UTF-8.
///
/// NOTE(review): serde's `Serialize` impl for `Path`/`PathBuf` appears to
/// reject non-UTF-8 paths.  If so, the `UnixLiteral` fallback would fail to
/// serialize in exactly the case it exists for.  Confirm before relying on it.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum CfgAddrSerde {
    /// A string such as `inet:127.0.0.1:9999`, `unix:PATH`, or `unix-literal:PATH`.
    Str(String),
    /// A literal unix-socket path, given as a map with a `unix_literal` key.
    UnixLiteral {
        /// The socket path, used without expansion.
        unix_literal: PathBuf,
    },
}
189
190impl TryFrom<CfgAddrSerde> for CfgAddr {
191 type Error = general::AddrParseError;
192
193 fn try_from(value: CfgAddrSerde) -> Result<Self, Self::Error> {
194 use CfgAddrSerde as S;
195 match value {
196 S::Str(s) => s.parse(),
197 S::UnixLiteral { unix_literal } => {
198 Ok(CfgAddr::new_unix(CfgPath::new_literal(unix_literal)))
199 }
200 }
201 }
202}
203impl From<CfgAddr> for CfgAddrSerde {
204 fn from(value: CfgAddr) -> Self {
205 match value.try_to_string() {
206 Ok(s) => CfgAddrSerde::Str(s),
207 Err(unix_literal) => CfgAddrSerde::UnixLiteral {
208 unix_literal: unix_literal.clone(),
209 },
210 }
211 }
212}
213
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use assert_matches::assert_matches;
    use std::path::PathBuf;

    use crate::{home, CfgPathResolver};

    #[test]
    fn parse_inet_ok() {
        fn check(s: &str) {
            let resolv = CfgPathResolver::from_pairs([("FOO", "foo")]);
            let a: general::SocketAddr = CfgAddr::from_str(s).unwrap().address(&resolv).unwrap();
            assert_eq!(a, general::SocketAddr::from_str(s).unwrap());
        }

        check("127.0.0.1:9999");
        check("inet:127.0.0.1:9999");
        check("[2001:db8::413]:443");
        check("inet:[2001:db8::413]:443");
    }

    #[test]
    fn parse_inet_bad() {
        assert_matches!(
            CfgAddr::from_str("612"),
            Err(general::AddrParseError::InvalidInetAddress(_))
        );
        assert_matches!(
            CfgAddr::from_str("612unix:/home"),
            Err(general::AddrParseError::InvalidInetAddress(_))
        );
        assert_matches!(
            CfgAddr::from_str("127.0.0.1.1:99"),
            Err(general::AddrParseError::InvalidInetAddress(_))
        );
        assert_matches!(
            CfgAddr::from_str("inet:6"),
            Err(general::AddrParseError::InvalidInetAddress(_))
        );
        assert_matches!(
            CfgAddr::from_str("[[[[[]]]]]"),
            Err(general::AddrParseError::InvalidInetAddress(_))
        );
    }

    #[test]
    fn parse_bad_schemas() {
        assert_matches!(
            CfgAddr::from_str("uranian:umbra"),
            Err(general::AddrParseError::UnrecognizedSchema(_))
        );
    }

    /// Input that is neither a bare inet address nor contains a `:` has no schema.
    #[test]
    fn parse_no_schema() {
        assert_matches!(
            CfgAddr::from_str("/home/mayor/.socket"),
            Err(general::AddrParseError::NoSchema)
        );
        assert_matches!(
            CfgAddr::from_str(""),
            Err(general::AddrParseError::NoSchema)
        );
    }

    /// Substitutions are reported as applying to unix addresses only.
    #[test]
    fn substitutions_apply_only_to_unix() {
        let inet = CfgAddr::from_str("inet:127.0.0.1:9999").unwrap();
        assert_eq!(inet.substitutions_will_apply(), false);
        let unix = CfgAddr::from_str("unix:/home/mayor/.socket").unwrap();
        assert_eq!(unix.substitutions_will_apply(), true);
    }

    #[test]
    fn unix_literal() {
        let resolv = CfgPathResolver::from_pairs([("USER_HOME", home().unwrap())]);
        let pb = PathBuf::from("${USER_HOME}/.local/socket");
        let a1 = CfgAddr::new_unix(CfgPath::new_literal(&pb));
        let a2 = CfgAddr::from_str("unix-literal:${USER_HOME}/.local/socket").unwrap();
        #[cfg(unix)]
        {
            assert_eq!(a1.address(&resolv).unwrap(), a2.address(&resolv).unwrap(),);
            match a1.address(&resolv).unwrap() {
                general::SocketAddr::Unix(socket_addr) => {
                    assert!(socket_addr.as_pathname() == Some(pb.as_ref()));
                }
                _ => panic!("Expected a unix domain socket address"),
            }
        }
        #[cfg(not(unix))]
        assert_matches!(
            a1.address(&resolv),
            Err(CfgAddrError::NoAfUnixSocketSupport(_))
        );
    }

    /// A pathname unix socket address converts to a literal unix `CfgAddr`.
    #[test]
    #[cfg(unix)]
    fn from_unix_socket_addr() {
        let sa = unix::SocketAddr::from_pathname(PathBuf::from("/home/mayor/.socket")).unwrap();
        let a = CfgAddr::try_from(sa).unwrap();
        assert_eq!(
            a,
            CfgAddr::from_str("unix-literal:/home/mayor/.socket").unwrap()
        );
    }

    fn try_unix(addr: &str, want: &str, path_resolver: &CfgPathResolver) {
        let p = CfgPath::new(want.to_string());
        let expansion = p.path(path_resolver).unwrap();
        let cfg_addr = CfgAddr::from_str(addr).unwrap();
        assert_matches!(&cfg_addr.0, AddrInner::Unix(_));
        #[cfg(unix)]
        {
            let gen_addr = cfg_addr.address(path_resolver).unwrap();
            let expected_addr = unix::SocketAddr::from_pathname(expansion).unwrap();
            assert_eq!(gen_addr, expected_addr.into());
        }
        #[cfg(not(unix))]
        {
            assert_matches!(
                cfg_addr.address(path_resolver),
                Err(CfgAddrError::NoAfUnixSocketSupport(_))
            );
        }
    }

    #[test]
    fn unix_no_substitution() {
        let resolver = CfgPathResolver::from_pairs([("FOO", "foo")]);
        try_unix("unix:/home/mayor/.socket", "/home/mayor/.socket", &resolver);
    }

    #[test]
    #[cfg(feature = "expand-paths")]
    fn unix_substitution() {
        let resolver = CfgPathResolver::from_pairs([("FOO", "foo")]);
        try_unix("unix:${FOO}/socket", "${FOO}/socket", &resolver);
    }

    #[test]
    fn serde() {
        fn testcase_with_provided_addr(json: &str, addr: &CfgAddr) {
            let a1: CfgAddr = serde_json::from_str(json).unwrap();
            assert_eq!(&a1, addr);
            let encoded = serde_json::to_string(&a1).unwrap();
            let a2: CfgAddr = serde_json::from_str(&encoded).unwrap();
            assert_eq!(&a2, addr);
        }
        fn testcase(json: &str, addr: &str) {
            let addr = CfgAddr::from_str(addr).unwrap();
            testcase_with_provided_addr(json, &addr);
        }

        testcase(r#" "inet:127.0.0.1:443" "#, "inet:127.0.0.1:443");
        testcase(r#" "unix:${HOME}/socket" "#, "unix:${HOME}/socket");
        testcase(
            r#" "unix-literal:${HOME}/socket" "#,
            "unix-literal:${HOME}/socket",
        );
        // The map form is accepted on input.  It re-serializes as the string
        // form, which must still round-trip to the same address.
        testcase_with_provided_addr(
            r#" { "unix_literal": "${HOME}/socket" } "#,
            &CfgAddr::new_unix(CfgPath::new_literal(PathBuf::from("${HOME}/socket"))),
        );
    }
}