use crate::errors::*; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::text::{next_field, parse_with, DnsTextContext, DnsTextData, DnsTextFormatter}; use bytes::Bytes; use smallvec::SmallVec; use std::fmt; use std::io::Cursor; use std::str::FromStr; use super::{DisplayLabels, DnsLabel, DnsLabelRef, DnsNameIterator, LabelOffsets}; /// A DNS name /// /// Uses the "original" raw representation for storage (i.e. can share /// memory with a parsed packet) #[derive(Clone)] pub struct DnsName { // in uncompressed form always includes terminating null octect; // but even in uncompressed form can include unused bytes at the // beginning // // may be empty for the root name (".", no labels) pub(super) data: Bytes, // either uncompressed or compressed offsets pub(super) label_offsets: LabelOffsets, // length of encoded form pub(super) total_len: u8, } impl DnsName { /// Create new name representing the DNS root (".") pub fn new_root() -> Self { DnsName { data: Bytes::new(), label_offsets: LabelOffsets::Uncompressed(SmallVec::new()), total_len: 1, } } /// Create new name representing the DNS root (".") and pre-allocate /// storage pub fn with_capacity(labels: u8, total_len: u8) -> Self { DnsName { data: Bytes::with_capacity(total_len as usize), label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)), total_len: 1, } } /// Parse text representation of a domain name pub fn parse(context: &DnsTextContext, value: &str) -> Result { super::name_text_parser::parse_name(context, value) } /// Returns whether name represents the DNS root (".") pub fn is_root(&self) -> bool { 0 == self.label_count() } /// How many labels the name has (without the trailing empty label, /// at most 127) pub fn label_count(&self) -> u8 { self.label_offsets.len() } /// Iterator over the labels (in the order they are stored in memory, /// i.e. top-level name last). pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> { DnsNameIterator { name: &self, front_label: 0, back_label: self.label_offsets.len(), } } /// Return label at index `ndx` /// /// # Panics /// /// panics if `ndx >= self.label_count()`. pub fn label_ref<'a>(&'a self, ndx: u8) -> DnsLabelRef<'a> { let pos = self.label_offsets.label_pos(ndx); let label_len = self.data[pos]; debug_assert!(label_len < 64); let end = pos + 1 + label_len as usize; DnsLabelRef { label: &self.data[pos + 1..end], } } /// Return label at index `ndx` /// /// # Panics /// /// panics if `ndx >= self.label_count()`. pub fn label(&self, ndx: u8) -> DnsLabel { let pos = self.label_offsets.label_pos(ndx); let label_len = self.data[pos]; debug_assert!(label_len < 64); let end = pos + 1 + label_len as usize; DnsLabel { label: self.data.slice(pos + 1, end), } } } impl<'a> IntoIterator for &'a DnsName { type Item = DnsLabelRef<'a>; type IntoIter = DnsNameIterator<'a>; fn into_iter(self) -> Self::IntoIter { self.labels() } } impl PartialEq for DnsName { fn eq(&self, rhs: &DnsName) -> bool { let a_labels = self.labels(); let b_labels = rhs.labels(); if a_labels.len() != b_labels.len() { return false; } a_labels.zip(b_labels).all(|(a, b)| a == b) } } impl PartialEq for DnsName where T: AsRef, { fn eq(&self, rhs: &T) -> bool { self == rhs.as_ref() } } impl Eq for DnsName {} impl fmt::Debug for DnsName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels { labels: self, options: Default::default(), } .fmt(w) } } impl fmt::Display for DnsName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels { labels: self, options: Default::default(), } .fmt(w) } } impl FromStr for DnsName { type Err = ::failure::Error; fn from_str(s: &str) -> Result { parse_with(s, |data| DnsName::dns_parse(&DnsTextContext::new(), data)) } } impl DnsTextData for DnsName { fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { let field = next_field(data)?; DnsName::parse(context, field) } fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { write!(f, "{}", self) } } impl DnsPacketData for DnsName { fn deserialize(data: &mut Cursor) -> Result { super::name_packet_parser::deserialize_name(data, false) } fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { context.write_uncompressed_name(packet, self) } }