#![deny(missing_docs)] //! Various structs to represents DNS names and labels use bytes::{Bytes,Buf,BytesMut}; use errors::*; use std::fmt; use std::io::Cursor; use packet_data::{DnsPacketData,deserialize}; #[inline] fn check_label(label: &[u8]) -> Result<()> { if label.len() == 0 { bail!("label must not be empty") } if label.len() > 63 { bail!("label must not be longer than 63 bytes") } Ok(()) } /// A DNS label (any binary string with `0 < length < 64`) #[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash)] pub struct DnsLabel { label: Bytes, // 0 < len < 64 } impl DnsLabel { /// Create new label from existing storage /// /// Fails when the length doesn't match the requirement `0 < length < 64`. pub fn new(label: Bytes) -> Result { check_label(&label)?; Ok(DnsLabel{label}) } /// Convert to a representation without storage pub fn as_ref<'a>(&'a self) -> DnsLabelRef<'a> { DnsLabelRef{label: self.label.as_ref()} } /// Access as raw bytes pub fn as_bytes(&self) -> &Bytes { &self.label } /// Access as raw bytes pub fn as_raw(&self) -> &[u8] { &self.label } } impl<'a> From> for DnsLabel { fn from(label_ref: DnsLabelRef<'a>) -> Self { DnsLabel{ label: Bytes::from(label_ref.label), } } } impl fmt::Debug for DnsLabel { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { self.as_ref().fmt(w) } } impl fmt::Display for DnsLabel { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { self.as_ref().fmt(w) } } /// A DNS label (any binary string with `0 < length < 64`) /// /// Storage is provided through lifetime. #[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash)] pub struct DnsLabelRef<'a> { label: &'a [u8], // 0 < len < 64 } impl<'a> DnsLabelRef<'a> { /// Create new label from existing storage /// /// Fails when the length doesn't match the requirement `0 < length < 64`. pub fn new(label: &'a [u8]) -> Result { check_label(label)?; Ok(DnsLabelRef{label}) } /// Access as raw bytes pub fn as_raw(&self) -> &'a [u8] { self.label } } impl<'a> From<&'a DnsLabel> for DnsLabelRef<'a> { fn from(label: &'a DnsLabel) -> Self { label.as_ref() } } impl<'a> fmt::Debug for DnsLabelRef<'a> { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { // just escape the display version format!("{}", self).fmt(w) } } impl<'a> fmt::Display for DnsLabelRef<'a> { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { use std::str; let mut done = 0; for pos in 0..self.label.len() { let c = self.label[pos]; if c <= 0x21 || c >= 0x7e || b'.' == c || b'\\' == c { // flush if done < pos { w.write_str(unsafe {str::from_utf8_unchecked(&self.label[done..pos])})?; } match c { b'.' => w.write_str(r#"\."#)?, b'\\' => w.write_str(r#"\\"#)?, _ => write!(w, r"\{:03o}", c)?, } done = pos + 1; } } // final flush if done < self.label.len() { w.write_str(unsafe {str::from_utf8_unchecked(&self.label[done..])})?; } Ok(()) } } /// Customize formatting of DNS labels /// /// The default uses "." as separator and adds a trailing separator. /// /// The `Debug` formatters just format as `Display` to a String and then /// format that string as `Debug`. #[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] pub struct DisplayLabelsOptions { /// separator to insert between (escaped) labels pub separator: &'static str, /// whether a trailing separator is added. /// /// without a trailing separator the root zone is represented as /// empty string! pub trailing: bool, } impl Default for DisplayLabelsOptions { fn default() -> Self { DisplayLabelsOptions{ separator: ".", trailing: true, } } } /// Wrap anything representing a collection of labels (`DnsLabelRef`) to /// format using the given `options`. /// /// As name you can pass any cloneable `(Into)Iterator` with /// `DnsLabelRef` items, e.g: /// /// * `&DnsPlainName` /// * `&DnsName` /// * `DnsNameIterator` /// * `DnsPlainNameIterator` #[derive(Clone)] pub struct DisplayLabels<'a, I> where I: IntoIterator>+Clone { /// Label collection to iterate over pub labels: I, /// Options pub options: DisplayLabelsOptions, } impl<'a, I> fmt::Debug for DisplayLabels<'a, I> where I: IntoIterator>+Clone { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { // just escape the display version1 format!("{}", self).fmt(w) } } impl<'a, I> fmt::Display for DisplayLabels<'a, I> where I: IntoIterator>+Clone { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { let mut i = self.labels.clone().into_iter(); if let Some(first_label) = i.next() { // first label fmt::Display::fmt(&first_label, w)?; // remaining labels while let Some(label) = i.next() { w.write_str(self.options.separator)?; fmt::Display::fmt(&label, w)?; } } if self.options.trailing { w.write_str(self.options.separator)?; } Ok(()) } } /// A DNS name /// /// Uses the "original" uncompressed raw representation for storage /// (i.e. can share memory with a parsed packet) #[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash)] pub struct DnsPlainName { // uncompressed raw representation data: Bytes, // at most 255 bytes label_count: u8, // at most 127 labels; doesn't count the final (empty) label } impl DnsPlainName { /// Parse a name from raw bytes pub fn new(raw: Bytes) -> Result { deserialize(raw) } /// How many labels the name has (without the trailing empty label, /// at most 127) pub fn label_count(&self) -> u8 { self.label_count } /// Iterator over the labels (in the order they are stored in memory, /// i.e. top-level name last). pub fn labels<'a>(&'a self) -> DnsPlainNameIterator<'a> { DnsPlainNameIterator{ name_data: self.data.as_ref(), position: 0, labels_done: 0, label_count: self.label_count, } } } impl<'a> IntoIterator for &'a DnsPlainName { type Item = DnsLabelRef<'a>; type IntoIter = DnsPlainNameIterator<'a>; fn into_iter(self) -> Self::IntoIter { self.labels() } } impl fmt::Debug for DnsPlainName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels{ labels: self, options: Default::default(), }.fmt(w) } } impl fmt::Display for DnsPlainName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels{ labels: self, options: Default::default(), }.fmt(w) } } impl DnsPacketData for DnsPlainName { fn deserialize(data: &mut Cursor) -> Result { check_enough_data!(data, 1, "DnsPlainName"); let start_pos = data.position() as usize; let mut total_len : usize = 0; let mut label_count: u8 = 0; loop { check_enough_data!(data, 1, "DnsPlainName label len"); let label_len = data.get_u8() as usize; total_len += 1; if total_len > 255 { bail!{"DNS name too long"} } if 0 == label_len { break; } label_count += 1; // can't overflow: total_len <= 255, and each label so far was not empty, i.e. used at least two bytes. if label_len > 63 { bail!("Invalid label length {}", label_len) } total_len += label_len; if total_len > 255 { bail!{"DNS name too long"} } check_enough_data!(data, (label_len), "DnsPlainName label"); data.advance(label_len); } let end_pos = data.position() as usize; Ok(DnsPlainName{ data: data.get_ref().slice(start_pos, end_pos), label_count: label_count, }) } } /// Iterator type for [`DnsPlainName::labels`] /// /// [`DnsPlainName::labels`]: struct.DnsPlainName.html#method.labels #[derive(Clone)] pub struct DnsPlainNameIterator<'a> { name_data: &'a [u8], position: u8, labels_done: u8, label_count: u8, } impl<'a> Iterator for DnsPlainNameIterator<'a> { type Item = DnsLabelRef<'a>; fn next(&mut self) -> Option { if self.labels_done >= self.label_count { return None } self.labels_done += 1; let label_len = self.name_data[self.position as usize]; let end = self.position+1+label_len; let label = DnsLabelRef{label: &self.name_data[(self.position+1) as usize..end as usize]}; self.position = end; Some(label) } fn size_hint(&self) -> (usize, Option) { let count = self.len(); (count, Some(count)) } fn count(self) -> usize { self.len() } } impl<'a> ExactSizeIterator for DnsPlainNameIterator<'a> { fn len(&self) -> usize { (self.label_count - self.labels_done) as usize } } /// A DNS name /// /// Uses a modified representation for storage (i.e. can NOT share /// memory with a parsed packet), but supports a `DoubleEndedIterator` /// to iterate over the labels. #[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash)] pub struct DnsName { // similar to the DNS encoding; but each "length" octet is the "XOR" // of the length of the surrounding labels. data: Bytes, // at most 255 bytes label_count: u8, // at most 127 labels; doesn't count the final (empty) label } impl DnsName { /// Parse a name from raw bytes pub fn new(raw: Bytes) -> Result { deserialize(raw) } /// How many labels the name has (without the trailing empty label, /// at most 127) pub fn label_count(&self) -> u8 { self.label_count } /// Iterator over the labels (in the order they are stored in memory, /// i.e. top-level name last). pub fn labels(&self) -> DnsNameIterator { DnsNameIterator{ name_data: self.data.as_ref(), front_position: 0, front_prev_label_len: 0, back_position: self.data.len() as u8 - 1, back_next_label_len: 0, labels_done: 0, label_count: self.label_count, } } } impl From for DnsName { fn from(plain: DnsPlainName) -> Self { let mut data = plain.data.try_mut().unwrap_or_else(BytesMut::from); let mut pos : u8 = 0; let mut prev_len : u8 = 0; loop { let label_len = data[pos as usize]; data[pos as usize] ^= prev_len; if 0 == label_len { break; } pos += label_len + 1; prev_len = label_len; } DnsName{ data: data.freeze(), label_count: plain.label_count, } } } impl From for DnsPlainName { fn from(name: DnsName) -> DnsPlainName { let mut data = name.data.try_mut().unwrap_or_else(BytesMut::from); let mut pos : u8 = 0; let mut prev_len : u8 = 0; loop { let label_len = data[pos as usize] ^ prev_len; data[pos as usize] = label_len; if 0 == label_len { break; } pos += label_len + 1; prev_len = label_len; } DnsPlainName{ data: data.freeze(), label_count: name.label_count, } } } impl<'a> IntoIterator for &'a DnsName { type Item = DnsLabelRef<'a>; type IntoIter = DnsNameIterator<'a>; fn into_iter(self) -> Self::IntoIter { self.labels() } } impl fmt::Debug for DnsName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels{ labels: self, options: Default::default(), }.fmt(w) } } impl fmt::Display for DnsName { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { DisplayLabels{ labels: self, options: Default::default(), }.fmt(w) } } impl DnsPacketData for DnsName { fn deserialize(data: &mut Cursor) -> Result { Ok(DnsName::from(DnsPlainName::deserialize(data)?)) } } /// Iterator type for [`DnsName::labels`] /// /// [`DnsName::labels`]: struct.DnsName.html#method.labels #[derive(Clone)] pub struct DnsNameIterator<'a> { name_data: &'a [u8], front_position: u8, front_prev_label_len: u8, back_position: u8, back_next_label_len: u8, labels_done: u8, label_count: u8, } impl<'a> Iterator for DnsNameIterator<'a> { type Item = DnsLabelRef<'a>; fn next(&mut self) -> Option { if self.labels_done >= self.label_count { return None } self.labels_done += 1; let label_len = self.name_data[self.front_position as usize] ^ self.front_prev_label_len; self.front_prev_label_len = label_len; let end = self.front_position+1+label_len; let label = DnsLabelRef{label: &self.name_data[(self.front_position+1) as usize..end as usize]}; self.front_position = end; Some(label) } fn size_hint(&self) -> (usize, Option) { let count = self.len(); (count, Some(count)) } fn count(self) -> usize { self.len() } } impl<'a> ExactSizeIterator for DnsNameIterator<'a> { fn len(&self) -> usize { (self.label_count - self.labels_done) as usize } } impl<'a> DoubleEndedIterator for DnsNameIterator<'a> { fn next_back(&mut self) -> Option { if self.labels_done >= self.label_count { return None } self.labels_done += 1; let label_len = self.name_data[self.back_position as usize] ^ self.back_next_label_len; self.back_next_label_len = label_len; let end = self.back_position; self.back_position -= 1 + label_len; let label = DnsLabelRef{label: &self.name_data[(self.back_position+1) as usize..end as usize]}; Some(label) } } #[cfg(test)] mod tests { use super::*; fn do_parse_and_display_name() where T: fmt::Display+fmt::Debug+DnsPacketData, for<'a> &'a T: IntoIterator>, { { let name = deserialize::(Bytes::from_static(b"\x07example\x03com\x00")).unwrap(); assert_eq!( format!("{}", name ), "example.com." ); assert_eq!( name.into_iter().count(), 2 ); } assert_eq!( format!( "{}", deserialize::(Bytes::from_static(b"\x07e!am.l\\\x03com\x00")).unwrap() ), "e\\041am\\.l\\\\.com." ); assert_eq!( format!( "{:?}", deserialize::(Bytes::from_static(b"\x07e!am.l\\\x03com\x00")).unwrap() ), r#""e\\041am\\.l\\\\.com.""# ); } #[test] fn parse_and_display_plain_name() { do_parse_and_display_name::(); } #[test] fn parse_and_display_name() { do_parse_and_display_name::(); } #[test] fn parse_and_reverse_name() { let name = deserialize::(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap(); assert_eq!( format!( "{}", DisplayLabels{ labels: name.labels().rev(), options: DisplayLabelsOptions{ separator: " ", trailing: false, }, } ), "com example www" ); } #[test] fn parse_and_convert_names() { let name = deserialize::(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap(); assert_eq!( DnsPlainName::from(DnsName::from(name.clone())), name ); } }