From 0261d27764588c8cd00d8fe671eb471872f343e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sat, 10 Feb 2018 11:32:25 +0100 Subject: [PATCH] step --- lib/dnsbox-base/src/common_types/classes.rs | 3 - .../src/common_types/name/canonical_name.rs | 127 +++++ .../src/common_types/name/compressed_name.rs | 124 ++++ .../src/common_types/name/label_offsets.rs | 44 ++ lib/dnsbox-base/src/common_types/name/mod.rs | 538 +----------------- lib/dnsbox-base/src/common_types/name/name.rs | 178 ++++++ .../src/common_types/name/name_iterator.rs | 46 ++ .../common_types/name/name_packet_parser.rs | 134 +++-- .../src/common_types/name/name_text_parser.rs | 162 ++---- .../src/common_types/name/tests.rs | 219 +++++++ lib/dnsbox-base/src/common_types/types.rs | 1 - lib/dnsbox-base/src/records/registry.rs | 1 - lib/dnsbox-base/src/ser/packet/write.rs | 80 ++- 13 files changed, 945 insertions(+), 712 deletions(-) create mode 100644 lib/dnsbox-base/src/common_types/name/canonical_name.rs create mode 100644 lib/dnsbox-base/src/common_types/name/compressed_name.rs create mode 100644 lib/dnsbox-base/src/common_types/name/label_offsets.rs create mode 100644 lib/dnsbox-base/src/common_types/name/name.rs create mode 100644 lib/dnsbox-base/src/common_types/name/name_iterator.rs create mode 100644 lib/dnsbox-base/src/common_types/name/tests.rs diff --git a/lib/dnsbox-base/src/common_types/classes.rs b/lib/dnsbox-base/src/common_types/classes.rs index bcfae07..6305e7a 100644 --- a/lib/dnsbox-base/src/common_types/classes.rs +++ b/lib/dnsbox-base/src/common_types/classes.rs @@ -133,7 +133,6 @@ impl Class { /// /// Avoids conflict with parsing RRTYPE mnemonics. pub fn from_known_name_without_any(name: &str) -> Option { - use std::ascii::AsciiExt; if name.eq_ignore_ascii_case("IN") { return Some(IN); } if name.eq_ignore_ascii_case("CH") { return Some(CH); } if name.eq_ignore_ascii_case("HS") { return Some(HS); } @@ -143,7 +142,6 @@ impl Class { /// parses known names (mnemonics) pub fn from_known_name(name: &str) -> Option { - use std::ascii::AsciiExt; Self::from_known_name_without_any(name).or_else(|| { if name.eq_ignore_ascii_case("ANY") { return Some(ANY); } None @@ -152,7 +150,6 @@ impl Class { /// parses generic names of the form "CLASS..." pub fn from_generic_name(name: &str) -> Option { - use std::ascii::AsciiExt; if name.len() > 5 && name.as_bytes()[0..5].eq_ignore_ascii_case(b"CLASS") { name[5..].parse::().ok().map(Class) } else { diff --git a/lib/dnsbox-base/src/common_types/name/canonical_name.rs b/lib/dnsbox-base/src/common_types/name/canonical_name.rs new file mode 100644 index 0000000..27791fc --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/canonical_name.rs @@ -0,0 +1,127 @@ +use bytes::Bytes; +use errors::*; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; +use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; +use std::fmt; +use std::io::Cursor; +use std::ops::{Deref, DerefMut}; +use std::str::FromStr; + +use super::{DnsName, DnsNameIterator, DnsLabelRef}; + +/// names that should be written in canonical form for DNSSEC according +/// to https://tools.ietf.org/html/rfc4034#section-6.2 +/// +/// DnsCompressedName always needs to be written in canonical form for +/// DNSSEC. +#[derive(Clone)] +pub struct DnsCanonicalName(pub DnsName); + +impl DnsCanonicalName { + /// Create new name representing the DNS root (".") + pub fn new_root() -> Self { + DnsCanonicalName(DnsName::new_root()) + } + + /// Parse text representation of a domain name + pub fn parse(context: &DnsTextContext, value: &str) -> Result + { + Ok(DnsCanonicalName(DnsName::parse(context, value)?)) + } +} + +impl Deref for DnsCanonicalName { + type Target = DnsName; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DnsCanonicalName { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsRef for DnsCanonicalName { + fn as_ref(&self) -> &DnsName { + &self.0 + } +} + +impl AsMut for DnsCanonicalName { + fn as_mut(&mut self) -> &mut DnsName { + &mut self.0 + } +} + +impl<'a> IntoIterator for &'a DnsCanonicalName { + type Item = DnsLabelRef<'a>; + type IntoIter = DnsNameIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.labels() + } +} + +impl PartialEq for DnsCanonicalName +{ + fn eq(&self, rhs: &DnsName) -> bool { + let this: &DnsName = self; + this == rhs + } +} + +impl PartialEq for DnsCanonicalName +where + T: AsRef +{ + fn eq(&self, rhs: &T) -> bool { + let this: &DnsName = self.as_ref(); + this == rhs + } +} + +impl Eq for DnsCanonicalName{} + +impl fmt::Debug for DnsCanonicalName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(w) + } +} + +impl fmt::Display for DnsCanonicalName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(w) + } +} + +impl FromStr for DnsCanonicalName { + type Err = ::failure::Error; + + fn from_str(s: &str) -> Result { + parse_with(s, |data| DnsCanonicalName::dns_parse(&DnsTextContext::new(), data)) + } +} + +impl DnsTextData for DnsCanonicalName { + fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { + let field = next_field(data)?; + DnsCanonicalName::parse(context, field) + } + + fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { + self.0.dns_format(f) + } +} + +impl DnsPacketData for DnsCanonicalName { + fn deserialize(data: &mut Cursor) -> Result { + Ok(DnsCanonicalName(super::name_packet_parser::deserialize_name(data, false)?)) + } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + context.write_canonical_name(packet, self) + } +} diff --git a/lib/dnsbox-base/src/common_types/name/compressed_name.rs b/lib/dnsbox-base/src/common_types/name/compressed_name.rs new file mode 100644 index 0000000..a5ede5b --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/compressed_name.rs @@ -0,0 +1,124 @@ +use bytes::Bytes; +use errors::*; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; +use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; +use std::fmt; +use std::io::Cursor; +use std::ops::{Deref, DerefMut}; +use std::str::FromStr; + +use super::{DnsName, DnsNameIterator, DnsLabelRef}; + +/// Similar to `DnsName`, but allows using compressed labels in the +/// serialized form +#[derive(Clone)] +pub struct DnsCompressedName(pub DnsName); + +impl DnsCompressedName { + /// Create new name representing the DNS root (".") + pub fn new_root() -> Self { + DnsCompressedName(DnsName::new_root()) + } + + /// Parse text representation of a domain name + pub fn parse(context: &DnsTextContext, value: &str) -> Result + { + Ok(DnsCompressedName(DnsName::parse(context, value)?)) + } +} + +impl Deref for DnsCompressedName { + type Target = DnsName; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DnsCompressedName { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl AsRef for DnsCompressedName { + fn as_ref(&self) -> &DnsName { + &self.0 + } +} + +impl AsMut for DnsCompressedName { + fn as_mut(&mut self) -> &mut DnsName { + &mut self.0 + } +} + +impl<'a> IntoIterator for &'a DnsCompressedName { + type Item = DnsLabelRef<'a>; + type IntoIter = DnsNameIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.labels() + } +} + +impl PartialEq for DnsCompressedName +{ + fn eq(&self, rhs: &DnsName) -> bool { + let this: &DnsName = self; + this == rhs + } +} + +impl PartialEq for DnsCompressedName +where + T: AsRef +{ + fn eq(&self, rhs: &T) -> bool { + let this: &DnsName = self.as_ref(); + this == rhs + } +} + +impl Eq for DnsCompressedName{} + +impl fmt::Debug for DnsCompressedName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(w) + } +} + +impl fmt::Display for DnsCompressedName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(w) + } +} + +impl FromStr for DnsCompressedName { + type Err = ::failure::Error; + + fn from_str(s: &str) -> Result { + parse_with(s, |data| DnsCompressedName::dns_parse(&DnsTextContext::new(), data)) + } +} + +impl DnsTextData for DnsCompressedName { + fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { + let field = next_field(data)?; + DnsCompressedName::parse(context, field) + } + + fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { + self.0.dns_format(f) + } +} + +impl DnsPacketData for DnsCompressedName { + fn deserialize(data: &mut Cursor) -> Result { + Ok(DnsCompressedName(super::name_packet_parser::deserialize_name(data, true)?)) + } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + context.write_compressed_name(packet, self) + } +} diff --git a/lib/dnsbox-base/src/common_types/name/label_offsets.rs b/lib/dnsbox-base/src/common_types/name/label_offsets.rs new file mode 100644 index 0000000..230a416 --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/label_offsets.rs @@ -0,0 +1,44 @@ +use smallvec::SmallVec; + +#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] +pub enum LabelOffset { + LabelStart(u8), + PacketStart(u16), +} + +// the heap meta data is usually at least 2*usize big; assuming 64-bit +// platforms it should be ok to use 16 bytes in the smallvec. +#[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] +pub enum LabelOffsets { + Uncompressed(SmallVec<[u8;16]>), + Compressed(usize, SmallVec<[LabelOffset;4]>), +} + +impl LabelOffsets { + pub fn len(&self) -> u8 { + let l = match *self { + LabelOffsets::Uncompressed(ref offs) => offs.len(), + LabelOffsets::Compressed(_, ref offs) => offs.len(), + }; + debug_assert!(l < 128); + l as u8 + } + + pub fn label_pos(&self, ndx: u8) -> usize { + debug_assert!(ndx < 127); + match *self { + LabelOffsets::Uncompressed(ref offs) => offs[ndx as usize] as usize, + LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] { + LabelOffset::LabelStart(o) => start + (o as usize), + LabelOffset::PacketStart(o) => o as usize, + } + } + } + + pub fn is_compressed(&self) -> bool { + match *self { + LabelOffsets::Uncompressed(_) => false, + LabelOffsets::Compressed(_, _) => true, + } + } +} diff --git a/lib/dnsbox-base/src/common_types/name/mod.rs b/lib/dnsbox-base/src/common_types/name/mod.rs index b4dcdad..d763ef2 100644 --- a/lib/dnsbox-base/src/common_types/name/mod.rs +++ b/lib/dnsbox-base/src/common_types/name/mod.rs @@ -3,543 +3,27 @@ use bytes::Bytes; use errors::*; -use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use smallvec::SmallVec; -use std::fmt; use std::io::Cursor; -use std::ops::{Deref, DerefMut}; +pub use self::canonical_name::*; +pub use self::compressed_name::*; pub use self::display::*; pub use self::label::*; +pub use self::name_iterator::*; +pub use self::name::*; +use self::label_offsets::*; +mod canonical_name; +mod compressed_name; mod display; mod label; +mod label_offsets; +mod name; +mod name_iterator; mod name_mutations; mod name_packet_parser; mod name_text_parser; -#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] -enum LabelOffset { - LabelStart(u8), - PacketStart(u16), -} - -// the heap meta data is usually at least 2*usize big; assuming 64-bit -// platforms it should be ok to use 16 bytes in the smallvec. -#[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] -enum LabelOffsets { - Uncompressed(SmallVec<[u8;16]>), - Compressed(usize, SmallVec<[LabelOffset;4]>), -} - -impl LabelOffsets { - fn len(&self) -> u8 { - let l = match *self { - LabelOffsets::Uncompressed(ref offs) => offs.len(), - LabelOffsets::Compressed(_, ref offs) => offs.len(), - }; - debug_assert!(l < 128); - l as u8 - } - - fn label_pos(&self, ndx: u8) -> usize { - debug_assert!(ndx < 127); - match *self { - LabelOffsets::Uncompressed(ref offs) => offs[ndx as usize] as usize, - LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] { - LabelOffset::LabelStart(o) => start + (o as usize), - LabelOffset::PacketStart(o) => o as usize, - } - } - } - - fn is_compressed(&self) -> bool { - match *self { - LabelOffsets::Uncompressed(_) => false, - LabelOffsets::Compressed(_, _) => true, - } - } -} - -/// A DNS name -/// -/// Uses the "original" raw representation for storage (i.e. can share -/// memory with a parsed packet) -#[derive(Clone)] -pub struct DnsName { - // in uncompressed form always includes terminating null octect; - // but even in uncompressed form can include unused bytes at the - // beginning - // - // may be empty for the root name (".", no labels) - data: Bytes, - // either uncompressed or compressed offsets - label_offsets: LabelOffsets, - // length of encoded form - total_len: u8, -} - -impl DnsName { - /// Create new name representing the DNS root (".") - pub fn new_root() -> Self { - DnsName{ - data: Bytes::new(), - label_offsets: LabelOffsets::Uncompressed(SmallVec::new()), - total_len: 1, - } - } - - /// Create new name representing the DNS root (".") and pre-allocate - /// storage - pub fn with_capacity(labels: u8, total_len: u8) -> Self { - DnsName{ - data: Bytes::with_capacity(total_len as usize), - label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)), - total_len: 1, - } - } - - /// Returns whether name represents the DNS root (".") - pub fn is_root(&self) -> bool { - 0 == self.label_count() - } - - /// How many labels the name has (without the trailing empty label, - /// at most 127) - pub fn label_count(&self) -> u8 { - self.label_offsets.len() - } - - /// Iterator over the labels (in the order they are stored in memory, - /// i.e. top-level name last). - pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> { - DnsNameIterator{ - name: &self, - front_label: 0, - back_label: self.label_offsets.len(), - } - } - - /// Return label at index `ndx` - /// - /// # Panics - /// - /// panics if `ndx >= self.label_count()`. - pub fn label_ref<'a>(&'a self, ndx: u8) -> DnsLabelRef<'a> { - let pos = self.label_offsets.label_pos(ndx); - let label_len = self.data[pos]; - debug_assert!(label_len < 64); - let end = pos + 1 + label_len as usize; - DnsLabelRef{label: &self.data[pos + 1..end]} - } - - /// Return label at index `ndx` - /// - /// # Panics - /// - /// panics if `ndx >= self.label_count()`. - pub fn label(&self, ndx: u8) -> DnsLabel { - let pos = self.label_offsets.label_pos(ndx); - let label_len = self.data[pos]; - debug_assert!(label_len < 64); - let end = pos + 1 + label_len as usize; - DnsLabel{label: self.data.slice(pos + 1, end) } - } -} - -impl<'a> IntoIterator for &'a DnsName { - type Item = DnsLabelRef<'a>; - type IntoIter = DnsNameIterator<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.labels() - } -} - -impl PartialEq for DnsName { - fn eq(&self, rhs: &DnsName) -> bool { - let a_labels = self.labels(); - let b_labels = rhs.labels(); - if a_labels.len() != b_labels.len() { return false; } - a_labels.zip(b_labels).all(|(a,b)| a == b) - } -} - -impl Eq for DnsName{} - -impl fmt::Debug for DnsName { - fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { - DisplayLabels{ - labels: self, - options: Default::default(), - }.fmt(w) - } -} - -impl fmt::Display for DnsName { - fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { - DisplayLabels{ - labels: self, - options: Default::default(), - }.fmt(w) - } -} - -impl DnsPacketData for DnsName { - fn deserialize(data: &mut Cursor) -> Result { - DnsName::parse_name(data, false) - } - - fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { - context.write_uncompressed_name(packet, self) - } -} - -/// Similar to `DnsName`, but allows using compressed labels in the -/// serialized form -#[derive(Clone)] -pub struct DnsCompressedName(pub DnsName); - -impl DnsCompressedName { - /// Create new name representing the DNS root (".") - pub fn new_root() -> Self { - DnsCompressedName(DnsName::new_root()) - } -} - -impl Deref for DnsCompressedName { - type Target = DnsName; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for DnsCompressedName { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl<'a> IntoIterator for &'a DnsCompressedName { - type Item = DnsLabelRef<'a>; - type IntoIter = DnsNameIterator<'a>; - - fn into_iter(self) -> Self::IntoIter { - self.labels() - } -} - -impl PartialEq for DnsCompressedName { - fn eq(&self, rhs: &DnsCompressedName) -> bool { - self.0 == rhs.0 - } -} - -impl PartialEq for DnsCompressedName { - fn eq(&self, rhs: &DnsName) -> bool { - &self.0 == rhs - } -} - -impl PartialEq for DnsName { - fn eq(&self, rhs: &DnsCompressedName) -> bool { - self == &rhs.0 - } -} - -impl Eq for DnsCompressedName{} - -impl fmt::Debug for DnsCompressedName { - fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(w) - } -} - -impl fmt::Display for DnsCompressedName { - fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(w) - } -} - -impl DnsPacketData for DnsCompressedName { - fn deserialize(data: &mut Cursor) -> Result { - Ok(DnsCompressedName(DnsName::parse_name(data, true)?)) - } - - fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { - context.write_compressed_name(packet, self) - } -} - -/// names that should be written in canonical form for DNSSEC according -/// to https://tools.ietf.org/html/rfc4034#section-6.2 -/// -/// TODO: make it a newtype. -/// -/// DnsCompressedName always needs to be written in canonical form for -/// DNSSEC. -pub type DnsCanonicalName = DnsName; - -/// Iterator type for [`DnsName::labels`] -/// -/// [`DnsName::labels`]: struct.DnsName.html#method.labels -#[derive(Clone)] -pub struct DnsNameIterator<'a> { - name: &'a DnsName, - front_label: u8, - back_label: u8, -} - -impl<'a> Iterator for DnsNameIterator<'a> { - type Item = DnsLabelRef<'a>; - - fn next(&mut self) -> Option { - if self.front_label >= self.back_label { return None } - let label = self.name.label_ref(self.front_label); - self.front_label += 1; - Some(label) - } - - fn size_hint(&self) -> (usize, Option) { - let count = self.len(); - (count, Some(count)) - } - - fn count(self) -> usize { - self.len() - } -} - -impl<'a> ExactSizeIterator for DnsNameIterator<'a> { - fn len(&self) -> usize { - (self.back_label - self.front_label) as usize - } -} - -impl<'a> DoubleEndedIterator for DnsNameIterator<'a> { - fn next_back(&mut self) -> Option { - if self.front_label >= self.back_label { return None } - self.back_label -= 1; - let label = self.name.label_ref(self.back_label); - Some(label) - } -} - #[cfg(test)] -mod tests { - use ser::packet; - use super::*; - -/* - fn deserialize(bytes: &'static [u8]) -> Result { - let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?; - { - let check_result = packet::deserialize_with(result.clone().encode(), DnsName::deserialize).unwrap(); - assert_eq!(check_result, result); - } - Ok(result) - } -*/ - - fn de_uncompressed(bytes: &'static [u8]) -> Result { - let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?; - assert_eq!(bytes, result.clone().encode()); - Ok(result) - } - - fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) { - let name = de_uncompressed(bytes).unwrap(); - assert_eq!( - name.labels().count(), - label_count as usize - ); - assert_eq!( - format!("{}", name), - txt - ); - } - - fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) { - let name = de_uncompressed(bytes).unwrap(); - assert_eq!( - format!("{:?}", name), - txt - ); - } - - #[test] - fn parse_and_display_name() { - check_uncompressed_display( - b"\x07example\x03com\x00", - "example.com.", - 2, - ); - check_uncompressed_display( - b"\x07e!am.l\\\x03com\x00", - "e\\033am\\.l\\\\.com.", - 2, - ); - check_uncompressed_debug( - b"\x07e!am.l\\\x03com\x00", - r#""e\\033am\\.l\\\\.com.""#, - ); - } - - #[test] - fn parse_and_reverse_name() { - let name = de_uncompressed(b"\x03www\x07example\x03com\x00").unwrap(); - assert_eq!( - format!( - "{}", - DisplayLabels{ - labels: name.labels().rev(), - options: DisplayLabelsOptions{ - separator: " ", - trailing: false, - }, - } - ), - "com example www" - ); - } - - #[test] - fn modifications() { - let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap(); - - name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "www.example.com." - ); - - name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "www.example.com.org." - ); - - name.pop_front(); - assert_eq!( - format!("{}", name), - "example.com.org." - ); - - name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "mx.example.com.org." - ); - // the "mx" label should fit into the place "www" used before, - // make sure the buffer was reused and the name not moved within - assert_eq!(1, name.label_offsets.label_pos(0)); - - name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "mx.example.com.org.com." - ); - } - - - - fn de_compressed(bytes: &'static [u8], offset: usize) -> Result { - use bytes::Buf; - - let mut c = Cursor::new(Bytes::from_static(bytes)); - c.set_position(offset as u64); - let result = DnsPacketData::deserialize(&mut c)?; - if c.remaining() != 0 { - bail!("data remaining: {}", c.remaining()) - } - Ok(result) - } - - fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) { - let name = de_compressed(bytes, offset).unwrap(); - assert_eq!( - name.labels().count(), - label_count as usize - ); - assert_eq!( - format!("{}", name), - txt - ); - } - - fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) { - let name = de_compressed(bytes, offset).unwrap(); - assert_eq!( - format!("{:?}", name), - txt - ); - } - - #[test] - fn parse_invalid_compressed_name() { - de_compressed(b"\x11com\x00\x07example\xc0\x00", 5).unwrap_err(); - de_compressed(b"\x10com\x00\x07example\xc0\x00", 5).unwrap_err(); - } - - #[test] - fn parse_and_display_compressed_name() { - check_compressed_display( - b"\x03com\x00\x07example\xc0\x00", 5, - "example.com.", - 2, - ); - check_compressed_display( - b"\x03com\x00\x07e!am.l\\\xc0\x00", 5, - "e\\033am\\.l\\\\.com.", - 2, - ); - check_compressed_debug( - b"\x03com\x00\x07e!am.l\\\xc0\x00", 5, - r#""e\\033am\\.l\\\\.com.""#, - ); - check_compressed_display( - b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05", 15, - "www.example.com.", - 3, - ); - } - - #[test] - fn modifications_compressed() { - let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap(); - - name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "www.example.com." - ); - - name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "www.example.com.org." - ); - - name.pop_front(); - assert_eq!( - format!("{}", name), - "example.com.org." - ); - - name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "mx.example.com.org." - ); - // the "mx" label should fit into the place "www" used before, - // make sure the buffer was reused and the name not moved within - assert_eq!(1, name.label_offsets.label_pos(0)); - - name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); - assert_eq!( - format!("{}", name), - "mx.example.com.org.com." - ); - } -} +mod tests; diff --git a/lib/dnsbox-base/src/common_types/name/name.rs b/lib/dnsbox-base/src/common_types/name/name.rs new file mode 100644 index 0000000..4dd06aa --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/name.rs @@ -0,0 +1,178 @@ +use bytes::Bytes; +use errors::*; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; +use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; +use smallvec::SmallVec; +use std::fmt; +use std::io::Cursor; +use std::str::FromStr; + +use super::{LabelOffsets, DnsNameIterator, DnsLabelRef, DnsLabel, DisplayLabels}; + +/// A DNS name +/// +/// Uses the "original" raw representation for storage (i.e. can share +/// memory with a parsed packet) +#[derive(Clone)] +pub struct DnsName { + // in uncompressed form always includes terminating null octect; + // but even in uncompressed form can include unused bytes at the + // beginning + // + // may be empty for the root name (".", no labels) + pub(super) data: Bytes, + // either uncompressed or compressed offsets + pub(super) label_offsets: LabelOffsets, + // length of encoded form + pub(super) total_len: u8, +} + +impl DnsName { + /// Create new name representing the DNS root (".") + pub fn new_root() -> Self { + DnsName{ + data: Bytes::new(), + label_offsets: LabelOffsets::Uncompressed(SmallVec::new()), + total_len: 1, + } + } + + /// Create new name representing the DNS root (".") and pre-allocate + /// storage + pub fn with_capacity(labels: u8, total_len: u8) -> Self { + DnsName{ + data: Bytes::with_capacity(total_len as usize), + label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)), + total_len: 1, + } + } + + /// Parse text representation of a domain name + pub fn parse(context: &DnsTextContext, value: &str) -> Result { + super::name_text_parser::parse_name(context, value) + } + + /// Returns whether name represents the DNS root (".") + pub fn is_root(&self) -> bool { + 0 == self.label_count() + } + + /// How many labels the name has (without the trailing empty label, + /// at most 127) + pub fn label_count(&self) -> u8 { + self.label_offsets.len() + } + + /// Iterator over the labels (in the order they are stored in memory, + /// i.e. top-level name last). + pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> { + DnsNameIterator{ + name: &self, + front_label: 0, + back_label: self.label_offsets.len(), + } + } + + /// Return label at index `ndx` + /// + /// # Panics + /// + /// panics if `ndx >= self.label_count()`. + pub fn label_ref<'a>(&'a self, ndx: u8) -> DnsLabelRef<'a> { + let pos = self.label_offsets.label_pos(ndx); + let label_len = self.data[pos]; + debug_assert!(label_len < 64); + let end = pos + 1 + label_len as usize; + DnsLabelRef{label: &self.data[pos + 1..end]} + } + + /// Return label at index `ndx` + /// + /// # Panics + /// + /// panics if `ndx >= self.label_count()`. + pub fn label(&self, ndx: u8) -> DnsLabel { + let pos = self.label_offsets.label_pos(ndx); + let label_len = self.data[pos]; + debug_assert!(label_len < 64); + let end = pos + 1 + label_len as usize; + DnsLabel{label: self.data.slice(pos + 1, end) } + } +} + +impl<'a> IntoIterator for &'a DnsName { + type Item = DnsLabelRef<'a>; + type IntoIter = DnsNameIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.labels() + } +} + +impl PartialEq for DnsName +{ + fn eq(&self, rhs: &DnsName) -> bool { + let a_labels = self.labels(); + let b_labels = rhs.labels(); + if a_labels.len() != b_labels.len() { return false; } + a_labels.zip(b_labels).all(|(a,b)| a == b) + } +} + +impl PartialEq for DnsName +where + T: AsRef +{ + fn eq(&self, rhs: &T) -> bool { + self == rhs.as_ref() + } +} + +impl Eq for DnsName{} + +impl fmt::Debug for DnsName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + DisplayLabels{ + labels: self, + options: Default::default(), + }.fmt(w) + } +} + +impl fmt::Display for DnsName { + fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { + DisplayLabels{ + labels: self, + options: Default::default(), + }.fmt(w) + } +} + +impl FromStr for DnsName { + type Err = ::failure::Error; + + fn from_str(s: &str) -> Result { + parse_with(s, |data| DnsName::dns_parse(&DnsTextContext::new(), data)) + } +} + +impl DnsTextData for DnsName { + fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { + let field = next_field(data)?; + DnsName::parse(context, field) + } + + fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { + write!(f, "{}", self) + } +} + +impl DnsPacketData for DnsName { + fn deserialize(data: &mut Cursor) -> Result { + super::name_packet_parser::deserialize_name(data, false) + } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + context.write_uncompressed_name(packet, self) + } +} diff --git a/lib/dnsbox-base/src/common_types/name/name_iterator.rs b/lib/dnsbox-base/src/common_types/name/name_iterator.rs new file mode 100644 index 0000000..a24bf17 --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/name_iterator.rs @@ -0,0 +1,46 @@ +use super::{DnsName, DnsLabelRef}; + +/// Iterator type for [`DnsName::labels`] +/// +/// [`DnsName::labels`]: struct.DnsName.html#method.labels +#[derive(Clone)] +pub struct DnsNameIterator<'a> { + pub(super) name: &'a DnsName, + pub(super) front_label: u8, + pub(super) back_label: u8, +} + +impl<'a> Iterator for DnsNameIterator<'a> { + type Item = DnsLabelRef<'a>; + + fn next(&mut self) -> Option { + if self.front_label >= self.back_label { return None } + let label = self.name.label_ref(self.front_label); + self.front_label += 1; + Some(label) + } + + fn size_hint(&self) -> (usize, Option) { + let count = self.len(); + (count, Some(count)) + } + + fn count(self) -> usize { + self.len() + } +} + +impl<'a> ExactSizeIterator for DnsNameIterator<'a> { + fn len(&self) -> usize { + (self.back_label - self.front_label) as usize + } +} + +impl<'a> DoubleEndedIterator for DnsNameIterator<'a> { + fn next_back(&mut self) -> Option { + if self.front_label >= self.back_label { return None } + self.back_label -= 1; + let label = self.name.label_ref(self.back_label); + Some(label) + } +} diff --git a/lib/dnsbox-base/src/common_types/name/name_packet_parser.rs b/lib/dnsbox-base/src/common_types/name/name_packet_parser.rs index d4669f7..63d7f08 100644 --- a/lib/dnsbox-base/src/common_types/name/name_packet_parser.rs +++ b/lib/dnsbox-base/src/common_types/name/name_packet_parser.rs @@ -1,86 +1,84 @@ use bytes::Buf; use super::*; -impl DnsName { - /// `data`: bytes of packet from beginning until at least the end of the name - /// `start_pos`: position of first byte of the name - /// `uncmpr_offsets`: offsets of uncompressed labels so far - /// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`) - /// `total_len`: length of (uncompressed) label encoding so far - fn parse_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offsets: SmallVec<[u8;16]>, mut total_len: usize, mut label_len: u8) -> Result { - let mut label_offsets = uncmpr_offsets.into_iter() - .map(LabelOffset::LabelStart) - .collect::>(); +/// `data`: bytes of packet from beginning until at least the end of the name +/// `start_pos`: position of first byte of the name +/// `uncmpr_offsets`: offsets of uncompressed labels so far +/// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`) +/// `total_len`: length of (uncompressed) label encoding so far +fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offsets: SmallVec<[u8;16]>, mut total_len: usize, mut label_len: u8) -> Result { + let mut label_offsets = uncmpr_offsets.into_iter() + .map(LabelOffset::LabelStart) + .collect::>(); - let mut pos = start_pos + total_len; - 'next_compressed: loop { - { - ensure!(pos + 1 < data.len(), "not enough data for compressed label"); - let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize); - ensure!(new_pos < pos, "Compressed label offset too big: {} >= {}", new_pos, pos); - pos = new_pos; - } - - loop { - ensure!(pos < data.len(), "not enough data for label"); - label_len = data[pos]; - - if 0 == label_len { - return Ok(DnsName{ - data: data, - label_offsets: LabelOffsets::Compressed(start_pos, label_offsets), - total_len: total_len as u8 + 1, - }) - } - - if label_len & 0xc0 == 0xc0 { continue 'next_compressed; } - ensure!(label_len < 64, "Invalid label length {}", label_len); - - total_len += 1 + label_len as usize; - // max len 255, but there also needs to be an empty label at the end - if total_len > 254 { bail!("DNS name too long") } - - label_offsets.push(LabelOffset::PacketStart(pos as u16)); - pos += 1 + label_len as usize; - } + let mut pos = start_pos + total_len; + 'next_compressed: loop { + { + ensure!(pos + 1 < data.len(), "not enough data for compressed label"); + let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize); + ensure!(new_pos < pos, "Compressed label offset too big: {} >= {}", new_pos, pos); + pos = new_pos; } - } - pub(super) fn parse_name(data: &mut Cursor, accept_compressed: bool) -> Result { - check_enough_data!(data, 1, "DnsName"); - let start_pos = data.position() as usize; - let mut total_len : usize = 0; - let mut label_offsets = SmallVec::new(); loop { - check_enough_data!(data, 1, "DnsName label len"); - let label_len = data.get_u8() as usize; + ensure!(pos < data.len(), "not enough data for label"); + label_len = data[pos]; + if 0 == label_len { - let end_pos = data.position() as usize; return Ok(DnsName{ - data: data.get_ref().slice(start_pos, end_pos), - label_offsets: LabelOffsets::Uncompressed(label_offsets), + data: data, + label_offsets: LabelOffsets::Compressed(start_pos, label_offsets), total_len: total_len as u8 + 1, }) } - if label_len & 0xc0 == 0xc0 { - // compressed label - if !accept_compressed { bail!("Invalid label compression {}", label_len) } - check_enough_data!(data, 1, "DnsName compressed label target"); - // eat second part of compressed label - data.get_u8(); - let end_pos = data.position() as usize; - let data = data.get_ref().slice(0, end_pos); + if label_len & 0xc0 == 0xc0 { continue 'next_compressed; } + ensure!(label_len < 64, "Invalid label length {}", label_len); - return Self::parse_name_compressed_cont(data, start_pos, label_offsets, total_len, label_len as u8); - } - label_offsets.push(total_len as u8); - if label_len > 63 { bail!("Invalid label length {}", label_len) } - total_len += 1 + label_len; + total_len += 1 + label_len as usize; // max len 255, but there also needs to be an empty label at the end - if total_len > 254 { bail!{"DNS name too long"} } - check_enough_data!(data, (label_len), "DnsName label"); - data.advance(label_len); + if total_len > 254 { bail!("DNS name too long") } + + label_offsets.push(LabelOffset::PacketStart(pos as u16)); + pos += 1 + label_len as usize; } } } + +pub fn deserialize_name(data: &mut Cursor, accept_compressed: bool) -> Result { + check_enough_data!(data, 1, "DnsName"); + let start_pos = data.position() as usize; + let mut total_len : usize = 0; + let mut label_offsets = SmallVec::new(); + loop { + check_enough_data!(data, 1, "DnsName label len"); + let label_len = data.get_u8() as usize; + if 0 == label_len { + let end_pos = data.position() as usize; + return Ok(DnsName{ + data: data.get_ref().slice(start_pos, end_pos), + label_offsets: LabelOffsets::Uncompressed(label_offsets), + total_len: total_len as u8 + 1, + }) + } + if label_len & 0xc0 == 0xc0 { + // compressed label + if !accept_compressed { bail!("Invalid label compression {}", label_len) } + check_enough_data!(data, 1, "DnsName compressed label target"); + // eat second part of compressed label + data.get_u8(); + + let end_pos = data.position() as usize; + let data = data.get_ref().slice(0, end_pos); + + return deserialize_name_compressed_cont(data, start_pos, label_offsets, total_len, label_len as u8); + } + label_offsets.push(total_len as u8); + if label_len > 63 { bail!("Invalid label length {}", label_len) } + total_len += 1 + label_len; + // max len 255, but there also needs to be an empty label at the end + if total_len > 254 { bail!{"DNS name too long"} } + check_enough_data!(data, (label_len), "DnsName label"); + data.advance(label_len); + } +} diff --git a/lib/dnsbox-base/src/common_types/name/name_text_parser.rs b/lib/dnsbox-base/src/common_types/name/name_text_parser.rs index 2061394..815ba85 100644 --- a/lib/dnsbox-base/src/common_types/name/name_text_parser.rs +++ b/lib/dnsbox-base/src/common_types/name/name_text_parser.rs @@ -1,111 +1,65 @@ -use super::*; -use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, quoted, parse_with}; +use errors::*; +use ser::text::{DnsTextContext, quoted}; -impl DnsName { - /// Parse text representation of a domain name - pub fn parse(context: &DnsTextContext, value: &str) -> Result - { - let raw = value.as_bytes(); - let mut name = DnsName::new_root(); - if raw == b"." { - return Ok(name); - } else if raw == b"@" { - match context.origin() { - Some(o) => return Ok(o.clone()), - None => bail!("@ invalid without $ORIGIN"), - } +use super::{DnsName, DnsLabelRef}; + +/// Parse text representation of a domain name +pub fn parse_name(context: &DnsTextContext, value: &str) -> Result +{ + let raw = value.as_bytes(); + let mut name = DnsName::new_root(); + if raw == b"." { + return Ok(name); + } else if raw == b"@" { + match context.origin() { + Some(o) => return Ok(o.clone()), + None => bail!("@ invalid without $ORIGIN"), } - ensure!(!raw.is_empty(), "invalid empty name"); - let mut label = Vec::new(); - let mut pos = 0; - while pos < raw.len() { - if raw[pos] == b'.' { - ensure!(!label.is_empty(), "empty label in name: {:?}", value); - name.push_back(DnsLabelRef::new(&label)?)?; - label.clear(); - } else if raw[pos] == b'\\' { - ensure!(pos + 1 < raw.len(), "unexpected end of name after backslash: {:?}", value); - if raw[pos+1] >= b'0' && raw[pos+1] <= b'9' { - // \ddd escape - ensure!(pos + 3 < raw.len(), "unexpected end of name after backslash with digit: {:?}", value); - ensure!(raw[pos+2] >= b'0' && raw[pos+2] <= b'9' && raw[pos+3] >= b'0' && raw[pos+3] <= b'9', "expected three digits after backslash in name: {:?}", name); - let d1 = (raw[pos+1] - b'0') as u32; - let d2 = (raw[pos+2] - b'0') as u32; - let d3 = (raw[pos+3] - b'0') as u32; - let v = d1 * 100 + d2 * 10 + d3; - ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name); - label.push(v as u8); - } else { - ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name); - label.push(raw[pos+1]); - } - } else { - ensure!(!quoted::is_ascii_whitespace(raw[pos]), "whitespace must be encoded as \\{:03} in: {:?}", raw[pos], name); - label.push(raw[pos]); - } - pos += 1; - } - - if !label.is_empty() { - // no trailing dot, relative name - - // push last label + } + ensure!(!raw.is_empty(), "invalid empty name"); + let mut label = Vec::new(); + let mut pos = 0; + while pos < raw.len() { + if raw[pos] == b'.' { + ensure!(!label.is_empty(), "empty label in name: {:?}", value); name.push_back(DnsLabelRef::new(&label)?)?; - - match context.origin() { - Some(o) => { - for l in o { name.push_back(l)?; } - }, - None => bail!("missing trailing dot without $ORIGIN"), + label.clear(); + } else if raw[pos] == b'\\' { + ensure!(pos + 1 < raw.len(), "unexpected end of name after backslash: {:?}", value); + if raw[pos+1] >= b'0' && raw[pos+1] <= b'9' { + // \ddd escape + ensure!(pos + 3 < raw.len(), "unexpected end of name after backslash with digit: {:?}", value); + ensure!(raw[pos+2] >= b'0' && raw[pos+2] <= b'9' && raw[pos+3] >= b'0' && raw[pos+3] <= b'9', "expected three digits after backslash in name: {:?}", name); + let d1 = (raw[pos+1] - b'0') as u32; + let d2 = (raw[pos+2] - b'0') as u32; + let d3 = (raw[pos+3] - b'0') as u32; + let v = d1 * 100 + d2 * 10 + d3; + ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name); + label.push(v as u8); + } else { + ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name); + label.push(raw[pos+1]); } + } else { + ensure!(!quoted::is_ascii_whitespace(raw[pos]), "whitespace must be encoded as \\{:03} in: {:?}", raw[pos], name); + label.push(raw[pos]); } - - Ok(name) - } -} - -impl DnsTextData for DnsName { - fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { - let field = next_field(data)?; - DnsName::parse(context, field) - } - - fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { - write!(f, "{}", self) - } -} - -impl ::std::str::FromStr for DnsName { - type Err = ::failure::Error; - - fn from_str(s: &str) -> Result { - parse_with(s, |data| DnsName::dns_parse(&DnsTextContext::new(), data)) - } -} - -impl DnsCompressedName { - /// Parse text representation of a domain name - pub fn parse(context: &DnsTextContext, value: &str) -> Result - { - Ok(DnsCompressedName(DnsName::parse(context, value)?)) - } -} - -impl DnsTextData for DnsCompressedName { - fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { - let field = next_field(data)?; - DnsCompressedName::parse(context, field) - } - - fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { - self.0.dns_format(f) - } -} - -impl ::std::str::FromStr for DnsCompressedName { - type Err = ::failure::Error; - - fn from_str(s: &str) -> Result { - parse_with(s, |data| DnsCompressedName::dns_parse(&DnsTextContext::new(), data)) + pos += 1; } + + if !label.is_empty() { + // no trailing dot, relative name + + // push last label + name.push_back(DnsLabelRef::new(&label)?)?; + + match context.origin() { + Some(o) => { + for l in o { name.push_back(l)?; } + }, + None => bail!("missing trailing dot without $ORIGIN"), + } + } + + Ok(name) } diff --git a/lib/dnsbox-base/src/common_types/name/tests.rs b/lib/dnsbox-base/src/common_types/name/tests.rs new file mode 100644 index 0000000..484ee8d --- /dev/null +++ b/lib/dnsbox-base/src/common_types/name/tests.rs @@ -0,0 +1,219 @@ +use bytes::Bytes; +use ser::packet; +use ser::packet::DnsPacketData; +use std::io::Cursor; +use errors::*; + +use super::{DnsName, DnsCompressedName, DnsLabelRef, DisplayLabels, DisplayLabelsOptions}; + +/* +fn deserialize(bytes: &'static [u8]) -> Result { + let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?; + { + let check_result = packet::deserialize_with(result.clone().encode(), DnsName::deserialize).unwrap(); + assert_eq!(check_result, result); + } + Ok(result) +} +*/ + +fn de_uncompressed(bytes: &'static [u8]) -> Result { + let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?; + assert_eq!(bytes, result.clone().encode()); + Ok(result) +} + +fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) { + let name = de_uncompressed(bytes).unwrap(); + assert_eq!( + name.labels().count(), + label_count as usize + ); + assert_eq!( + format!("{}", name), + txt + ); +} + +fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) { + let name = de_uncompressed(bytes).unwrap(); + assert_eq!( + format!("{:?}", name), + txt + ); +} + +#[test] +fn parse_and_display_name() { + check_uncompressed_display( + b"\x07example\x03com\x00", + "example.com.", + 2, + ); + check_uncompressed_display( + b"\x07e!am.l\\\x03com\x00", + "e\\033am\\.l\\\\.com.", + 2, + ); + check_uncompressed_debug( + b"\x07e!am.l\\\x03com\x00", + r#""e\\033am\\.l\\\\.com.""#, + ); +} + +#[test] +fn parse_and_reverse_name() { + let name = de_uncompressed(b"\x03www\x07example\x03com\x00").unwrap(); + assert_eq!( + format!( + "{}", + DisplayLabels{ + labels: name.labels().rev(), + options: DisplayLabelsOptions{ + separator: " ", + trailing: false, + }, + } + ), + "com example www" + ); +} + +#[test] +fn modifications() { + let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap(); + + name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "www.example.com." + ); + + name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "www.example.com.org." + ); + + name.pop_front(); + assert_eq!( + format!("{}", name), + "example.com.org." + ); + + name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "mx.example.com.org." + ); + // the "mx" label should fit into the place "www" used before, + // make sure the buffer was reused and the name not moved within + assert_eq!(1, name.label_offsets.label_pos(0)); + + name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "mx.example.com.org.com." + ); +} + + + +fn de_compressed(bytes: &'static [u8], offset: usize) -> Result { + use bytes::Buf; + + let mut c = Cursor::new(Bytes::from_static(bytes)); + c.set_position(offset as u64); + let result = DnsPacketData::deserialize(&mut c)?; + if c.remaining() != 0 { + bail!("data remaining: {}", c.remaining()) + } + Ok(result) +} + +fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) { + let name = de_compressed(bytes, offset).unwrap(); + assert_eq!( + name.labels().count(), + label_count as usize + ); + assert_eq!( + format!("{}", name), + txt + ); +} + +fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) { + let name = de_compressed(bytes, offset).unwrap(); + assert_eq!( + format!("{:?}", name), + txt + ); +} + +#[test] +fn parse_invalid_compressed_name() { + de_compressed(b"\x11com\x00\x07example\xc0\x00", 5).unwrap_err(); + de_compressed(b"\x10com\x00\x07example\xc0\x00", 5).unwrap_err(); +} + +#[test] +fn parse_and_display_compressed_name() { + check_compressed_display( + b"\x03com\x00\x07example\xc0\x00", 5, + "example.com.", + 2, + ); + check_compressed_display( + b"\x03com\x00\x07e!am.l\\\xc0\x00", 5, + "e\\033am\\.l\\\\.com.", + 2, + ); + check_compressed_debug( + b"\x03com\x00\x07e!am.l\\\xc0\x00", 5, + r#""e\\033am\\.l\\\\.com.""#, + ); + check_compressed_display( + b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05", 15, + "www.example.com.", + 3, + ); +} + +#[test] +fn modifications_compressed() { + let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap(); + + name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "www.example.com." + ); + + name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "www.example.com.org." + ); + + name.pop_front(); + assert_eq!( + format!("{}", name), + "example.com.org." + ); + + name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "mx.example.com.org." + ); + // the "mx" label should fit into the place "www" used before, + // make sure the buffer was reused and the name not moved within + assert_eq!(1, name.label_offsets.label_pos(0)); + + name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); + assert_eq!( + format!("{}", name), + "mx.example.com.org.com." + ); +} diff --git a/lib/dnsbox-base/src/common_types/types.rs b/lib/dnsbox-base/src/common_types/types.rs index d22827f..54a36ae 100644 --- a/lib/dnsbox-base/src/common_types/types.rs +++ b/lib/dnsbox-base/src/common_types/types.rs @@ -474,7 +474,6 @@ impl Type { /// parses generic names of the form "TYPE..." pub fn from_generic_name(name: &str) -> Option { - use std::ascii::AsciiExt; if name.len() > 4 && name.as_bytes()[0..4].eq_ignore_ascii_case(b"TYPE") { name[4..].parse::().ok().map(Type) } else { diff --git a/lib/dnsbox-base/src/records/registry.rs b/lib/dnsbox-base/src/records/registry.rs index 4152a18..786639d 100644 --- a/lib/dnsbox-base/src/records/registry.rs +++ b/lib/dnsbox-base/src/records/registry.rs @@ -1,6 +1,5 @@ use bytes::Bytes; use std::any::TypeId; -use std::ascii::AsciiExt; use std::collections::HashMap; use std::io::Cursor; use std::marker::PhantomData; diff --git a/lib/dnsbox-base/src/ser/packet/write.rs b/lib/dnsbox-base/src/ser/packet/write.rs index 638fbaf..2be59a4 100644 --- a/lib/dnsbox-base/src/ser/packet/write.rs +++ b/lib/dnsbox-base/src/ser/packet/write.rs @@ -68,6 +68,24 @@ fn write_name(packet: &mut Vec, name: &DnsName) { packet.put_u8(0); } +fn write_canonical_label(packet: &mut Vec, label: DnsLabelRef) { + let l = label.len(); + debug_assert!(l < 64); + packet.reserve(l as usize + 1); + packet.put_u8(l); + for c in label.as_raw() { + packet.put_u8(c.to_ascii_lowercase()); + } +} + +fn write_canonical_name(packet: &mut Vec, name: &DnsName) { + for label in name { + write_canonical_label(packet, label); + } + packet.reserve(1); + packet.put_u8(0); +} + fn write_label_remember(packet: &mut Vec, labels: &mut Vec, label: DnsLabelRef, next_entry: usize) { labels.push(LabelEntry { pos: packet.len(), @@ -76,9 +94,23 @@ fn write_label_remember(packet: &mut Vec, labels: &mut Vec, labe write_label(packet, label); } +#[derive(Clone, Debug)] +enum LabelWriteMethod { + Uncompressed, + Compressed(Vec), + Canonical, // DNSSEC, uncompressed + ASCII lower-case +} + +impl Default for LabelWriteMethod { + fn default() -> Self { + LabelWriteMethod::Uncompressed + } +} + #[derive(Clone, Debug, Default)] pub struct DnsPacketWriteContext { - labels: Option>, + labels: LabelWriteMethod, + } impl DnsPacketWriteContext { @@ -86,11 +118,26 @@ impl DnsPacketWriteContext { Default::default() } + /// Enables writing compressed names + /// + /// Only `DnsCompressedName` uses compression, `DnsName` and + /// `DnsCanonicalName` are never compressed. pub fn enable_compression(&mut self) { - self.labels = Some(Vec::new()); + self.labels = LabelWriteMethod::Compressed(Vec::new()); } - pub fn write_uncompressed_name(&mut self, packet: &mut Vec, name: &DnsName) -> Result<()> { + /// Enables writing canonical names + /// + /// Disables compression, and converts `DnsCompressedName` and + /// `DnsCanonicalName` to ASCII lowercase. + /// + /// `DnsName` is never compressed, but also never converted to + /// lowercase. + pub fn enable_canonical(&mut self) { + self.labels = LabelWriteMethod::Canonical; + } + + pub(crate) fn write_uncompressed_name(&mut self, packet: &mut Vec, name: &DnsName) -> Result<()> { // for now we don't remember labels of these names. // // if we did: would we want to check whether a suffix is already @@ -100,19 +147,36 @@ impl DnsPacketWriteContext { Ok(()) } - pub fn write_compressed_name(&mut self, packet: &mut Vec, name: &DnsCompressedName) -> Result<()> { + pub(crate) fn write_canonical_name(&mut self, packet: &mut Vec, name: &DnsName) -> Result<()> { + match self.labels { + LabelWriteMethod::Uncompressed | LabelWriteMethod::Compressed(_) => { + // uncompressed + write_name(packet, name); + }, + LabelWriteMethod::Canonical => { + write_canonical_name(packet, name); + }, + } + return Ok(()) + } + + pub(crate) fn write_compressed_name(&mut self, packet: &mut Vec, name: &DnsCompressedName) -> Result<()> { + // for DNSSEC we need to write it canonical if name.is_root() { write_name(packet, name); return Ok(()); } let labels = match self.labels { - Some(ref mut labels) => labels, - None => { - // compression disabled + LabelWriteMethod::Uncompressed => { write_name(packet, name); return Ok(()); - } + }, + LabelWriteMethod::Compressed(ref mut labels) => labels, + LabelWriteMethod::Canonical => { + write_canonical_name(packet, name); + return Ok(()) + }, }; let mut best_match = None;