1use std::path::Path;
15use std::sync::Arc;
16
17use crate::unix;
18use std::{io::Error as IoError, net};
19
20#[cfg(target_os = "android")]
21use std::os::android::net::SocketAddrExt as _;
22#[cfg(target_os = "linux")]
23use std::os::linux::net::SocketAddrExt as _;
24
25#[derive(Clone, Debug, derive_more::From, derive_more::TryInto)]
101#[non_exhaustive]
102pub enum SocketAddr {
103    Inet(net::SocketAddr),
105    Unix(unix::SocketAddr),
109}
110
111impl SocketAddr {
112    pub fn display_lossy(&self) -> DisplayLossy<'_> {
120        DisplayLossy(self)
121    }
122
123    pub fn try_to_string(&self) -> Option<String> {
127        use SocketAddr::*;
128        match self {
129            Inet(sa) => Some(format!("inet:{}", sa)),
130            Unix(sa) => {
131                if sa.is_unnamed() {
132                    Some("unix:".to_string())
133                } else {
134                    sa.as_pathname()
135                        .and_then(Path::to_str)
136                        .map(|p| format!("unix:{}", p))
137                }
138            }
139        }
140    }
141
142    pub fn as_pathname(&self) -> Option<&Path> {
145        match self {
146            SocketAddr::Inet(_) => None,
147            SocketAddr::Unix(socket_addr) => socket_addr.as_pathname(),
148        }
149    }
150}
151
152pub struct DisplayLossy<'a>(&'a SocketAddr);
154
155impl<'a> std::fmt::Display for DisplayLossy<'a> {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        use SocketAddr::*;
158        match self.0 {
159            Inet(sa) => write!(f, "inet:{}", sa),
160            Unix(sa) => {
161                if let Some(path) = sa.as_pathname() {
162                    if let Some(path_str) = path.to_str() {
163                        write!(f, "unix:{}", path_str)
164                    } else {
165                        write!(f, "unix:{} [lossy]", path.to_string_lossy())
166                    }
167                } else if sa.is_unnamed() {
168                    write!(f, "unix:")
169                } else {
170                    write!(f, "unix:{:?} [lossy]", sa)
171                }
172            }
173        }
174    }
175}
176
177impl std::str::FromStr for SocketAddr {
178    type Err = AddrParseError;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        if s.starts_with(|c: char| c.is_ascii_digit() || c == '[') {
182            Ok(s.parse::<net::SocketAddr>()?.into())
184        } else if let Some((schema, remainder)) = s.split_once(':') {
185            match schema {
186                "unix" => Ok(unix::SocketAddr::from_pathname(remainder)?.into()),
187                "inet" => Ok(remainder.parse::<net::SocketAddr>()?.into()),
188                _ => Err(AddrParseError::UnrecognizedSchema(schema.to_string())),
189            }
190        } else {
191            Err(AddrParseError::NoSchema)
192        }
193    }
194}
195
196#[derive(Clone, Debug, thiserror::Error)]
198#[non_exhaustive]
199pub enum AddrParseError {
200    #[error("Address schema {0:?} unrecognized")]
202    UnrecognizedSchema(String),
203    #[error("Address did not look like internet, but had no address schema.")]
205    NoSchema,
206    #[error("Invalid AF_UNIX address")]
208    InvalidAfUnixAddress(#[source] Arc<IoError>),
209    #[error("Invalid internet address")]
211    InvalidInetAddress(#[from] std::net::AddrParseError),
212}
213
214impl From<IoError> for AddrParseError {
215    fn from(e: IoError) -> Self {
216        Self::InvalidAfUnixAddress(Arc::new(e))
217    }
218}
219
220impl PartialEq for SocketAddr {
221    fn eq(&self, other: &Self) -> bool {
232        match (self, other) {
233            (Self::Inet(l0), Self::Inet(r0)) => l0 == r0,
234            #[cfg(unix)]
235            (Self::Unix(l0), Self::Unix(r0)) => {
236                if l0.is_unnamed() && r0.is_unnamed() {
240                    return true;
241                }
242                if let (Some(a), Some(b)) = (l0.as_pathname(), r0.as_pathname()) {
243                    return a == b;
244                }
245                #[cfg(any(target_os = "android", target_os = "linux"))]
246                if let (Some(a), Some(b)) = (l0.as_abstract_name(), r0.as_abstract_name()) {
247                    return a == b;
248                }
249                false
250            }
251            _ => false,
252        }
253    }
254}
255
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for SocketAddr {
    /// Build an arbitrary [`SocketAddr`] from fuzzer-supplied data.
    ///
    /// The address kind is chosen first; only kinds enabled for the current
    /// target can be generated. Data that the platform rejects as an AF_UNIX
    /// address yields [`arbitrary::Error::IncorrectFormat`].
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        /// The kind of address to generate.
        #[allow(clippy::missing_docs_in_private_items)]
        #[derive(arbitrary::Arbitrary)]
        enum Kind {
            V4,
            V6,
            #[cfg(unix)]
            Unix,
            #[cfg(any(target_os = "android", target_os = "linux"))]
            UnixAbstract,
        }
        match u.arbitrary()? {
            Kind::V4 => Ok(SocketAddr::Inet(
                net::SocketAddrV4::new(u.arbitrary()?, u.arbitrary()?).into(),
            )),
            Kind::V6 => Ok(SocketAddr::Inet(
                net::SocketAddrV6::new(
                    u.arbitrary()?, // address
                    u.arbitrary()?, // port
                    u.arbitrary()?, // flowinfo
                    u.arbitrary()?, // scope_id
                )
                .into(),
            )),
            #[cfg(unix)]
            Kind::Unix => {
                let pathname: std::ffi::OsString = u.arbitrary()?;
                Ok(SocketAddr::Unix(
                    unix::SocketAddr::from_pathname(pathname)
                        .map_err(|_| arbitrary::Error::IncorrectFormat)?,
                ))
            }
            // `from_abstract_name` is provided by the `SocketAddrExt` trait,
            // which this module already imports at file scope; a local
            // re-import here would be redundant.
            #[cfg(any(target_os = "android", target_os = "linux"))]
            Kind::UnixAbstract => {
                let name: &[u8] = u.arbitrary()?;
                Ok(SocketAddr::Unix(
                    unix::SocketAddr::from_abstract_name(name)
                        .map_err(|_| arbitrary::Error::IncorrectFormat)?,
                ))
            }
        }
    }
}
306
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::AddrParseError;
    use crate::general;
    use assert_matches::assert_matches;
    #[cfg(unix)]
    use std::os::unix::net as unix;
    use std::{net, str::FromStr as _};

    /// Build a [`general::SocketAddr`] from a literal internet address.
    fn from_inet(s: &str) -> general::SocketAddr {
        s.parse::<net::SocketAddr>().unwrap().into()
    }

    #[test]
    fn ok_inet() {
        for literal in ["127.0.0.1:9999", "[::1]:9999"] {
            let expected = from_inet(literal);
            let with_schema = format!("inet:{}", literal);

            // The schema prefix is optional when parsing...
            assert_eq!(expected, general::SocketAddr::from_str(literal).unwrap());
            assert_eq!(expected, general::SocketAddr::from_str(&with_schema).unwrap());

            // ...but always present when rendering.
            assert_eq!(expected.display_lossy().to_string(), with_schema);
            assert_eq!(expected.try_to_string().unwrap(), with_schema);
        }

        assert_ne!(
            general::SocketAddr::from_str("127.0.0.1:9999").unwrap(),
            general::SocketAddr::from_str("[::1]:9999").unwrap()
        );
    }

    /// Build a [`general::SocketAddr`] from an AF_UNIX filesystem path.
    #[cfg(unix)]
    fn from_pathname(s: impl AsRef<std::path::Path>) -> general::SocketAddr {
        unix::SocketAddr::from_pathname(s).unwrap().into()
    }

    #[test]
    #[cfg(unix)]
    fn ok_unix() {
        // Every `unix:<path>` string parses to the address built from `<path>`,
        // including the empty path.
        for path in ["/some/path", "/another/path", "/path/with spaces", ""] {
            let parsed = general::SocketAddr::from_str(&format!("unix:{}", path)).unwrap();
            assert_eq!(from_pathname(path), parsed);
        }

        assert_ne!(
            general::SocketAddr::from_str("unix:/some/path").unwrap(),
            general::SocketAddr::from_str("unix:/another/path").unwrap()
        );

        // Ordinary UTF-8 paths render back exactly, by either method.
        for path in ["/some/path", "/another/path"] {
            let rendered = format!("unix:{}", path);
            let addr = general::SocketAddr::from_str(&rendered).unwrap();
            assert_eq!(addr.display_lossy().to_string(), rendered);
            assert_eq!(addr.try_to_string().unwrap(), rendered);
        }
    }

    #[test]
    fn parse_err_inet() {
        // Bare-looking and `inet:`-prefixed texts that are not valid
        // `host:port` internet addresses.
        let rejected = [
            "1234567890:999",
            "1z",
            "[[77",
            "inet:fred:9999",
            "inet:127.0.0.1",
            "inet:[::1]",
        ];
        for text in rejected {
            assert_matches!(
                text.parse::<general::SocketAddr>(),
                Err(AddrParseError::InvalidInetAddress(_))
            );
        }
    }

    #[test]
    fn parse_err_schemata() {
        assert_matches!(
            "fred".parse::<general::SocketAddr>(),
            Err(AddrParseError::NoSchema)
        );
        for text in ["fred:", "fred:hello"] {
            assert_matches!(
                text.parse::<general::SocketAddr>(),
                Err(AddrParseError::UnrecognizedSchema(f)) if f == "fred"
            );
        }
    }

    #[test]
    #[cfg(unix)]
    fn display_unix_weird() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt as _;

        // A non-UTF-8 path has no exact string form, only a lossy one.
        let non_utf8 = from_pathname(OsStr::from_bytes(&[255, 255, 255, 255]));
        assert!(non_utf8.try_to_string().is_none());
        assert_eq!(non_utf8.display_lossy().to_string(), "unix:���� [lossy]");

        // The empty path renders as a bare `unix:` by either method.
        let empty = from_pathname("");
        assert_eq!(empty.try_to_string().unwrap(), "unix:");
        assert_eq!(empty.display_lossy().to_string(), "unix:");
    }

    #[test]
    #[cfg(not(unix))]
    fn parse_err_no_unix() {
        for text in ["unix:", "unix:/any/path"] {
            assert_matches!(
                text.parse::<general::SocketAddr>(),
                Err(AddrParseError::InvalidAfUnixAddress(_))
            );
        }
    }
}