This commit is contained in:
Stefan Bühler 2018-02-10 11:32:25 +01:00
parent 534272fbc9
commit 0261d27764
13 changed files with 945 additions and 712 deletions

View File

@ -133,7 +133,6 @@ impl Class {
/// ///
/// Avoids conflict with parsing RRTYPE mnemonics. /// Avoids conflict with parsing RRTYPE mnemonics.
pub fn from_known_name_without_any(name: &str) -> Option<Self> { pub fn from_known_name_without_any(name: &str) -> Option<Self> {
use std::ascii::AsciiExt;
if name.eq_ignore_ascii_case("IN") { return Some(IN); } if name.eq_ignore_ascii_case("IN") { return Some(IN); }
if name.eq_ignore_ascii_case("CH") { return Some(CH); } if name.eq_ignore_ascii_case("CH") { return Some(CH); }
if name.eq_ignore_ascii_case("HS") { return Some(HS); } if name.eq_ignore_ascii_case("HS") { return Some(HS); }
@ -143,7 +142,6 @@ impl Class {
/// parses known names (mnemonics) /// parses known names (mnemonics)
pub fn from_known_name(name: &str) -> Option<Self> { pub fn from_known_name(name: &str) -> Option<Self> {
use std::ascii::AsciiExt;
Self::from_known_name_without_any(name).or_else(|| { Self::from_known_name_without_any(name).or_else(|| {
if name.eq_ignore_ascii_case("ANY") { return Some(ANY); } if name.eq_ignore_ascii_case("ANY") { return Some(ANY); }
None None
@ -152,7 +150,6 @@ impl Class {
/// parses generic names of the form "CLASS..." /// parses generic names of the form "CLASS..."
pub fn from_generic_name(name: &str) -> Option<Self> { pub fn from_generic_name(name: &str) -> Option<Self> {
use std::ascii::AsciiExt;
if name.len() > 5 && name.as_bytes()[0..5].eq_ignore_ascii_case(b"CLASS") { if name.len() > 5 && name.as_bytes()[0..5].eq_ignore_ascii_case(b"CLASS") {
name[5..].parse::<u16>().ok().map(Class) name[5..].parse::<u16>().ok().map(Class)
} else { } else {

View File

@ -0,0 +1,127 @@
use bytes::Bytes;
use errors::*;
use ser::packet::{DnsPacketData, DnsPacketWriteContext};
use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with};
use std::fmt;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
use super::{DnsName, DnsNameIterator, DnsLabelRef};
/// names that should be written in canonical form for DNSSEC according
/// to https://tools.ietf.org/html/rfc4034#section-6.2
///
/// DnsCompressedName always needs to be written in canonical form for
/// DNSSEC.
#[derive(Clone)]
pub struct DnsCanonicalName(pub DnsName);
impl DnsCanonicalName {
/// Create new name representing the DNS root (".")
pub fn new_root() -> Self {
DnsCanonicalName(DnsName::new_root())
}
/// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self>
{
Ok(DnsCanonicalName(DnsName::parse(context, value)?))
}
}
impl Deref for DnsCanonicalName {
type Target = DnsName;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsCanonicalName {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl AsRef<DnsName> for DnsCanonicalName {
fn as_ref(&self) -> &DnsName {
&self.0
}
}
impl AsMut<DnsName> for DnsCanonicalName {
fn as_mut(&mut self) -> &mut DnsName {
&mut self.0
}
}
impl<'a> IntoIterator for &'a DnsCanonicalName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl PartialEq<DnsName> for DnsCanonicalName
{
fn eq(&self, rhs: &DnsName) -> bool {
let this: &DnsName = self;
this == rhs
}
}
impl<T> PartialEq<T> for DnsCanonicalName
where
T: AsRef<DnsName>
{
fn eq(&self, rhs: &T) -> bool {
let this: &DnsName = self.as_ref();
this == rhs
}
}
impl Eq for DnsCanonicalName{}
impl fmt::Debug for DnsCanonicalName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl fmt::Display for DnsCanonicalName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl FromStr for DnsCanonicalName {
type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsCanonicalName::dns_parse(&DnsTextContext::new(), data))
}
}
impl DnsTextData for DnsCanonicalName {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let field = next_field(data)?;
DnsCanonicalName::parse(context, field)
}
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
self.0.dns_format(f)
}
}
impl DnsPacketData for DnsCanonicalName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
Ok(DnsCanonicalName(super::name_packet_parser::deserialize_name(data, false)?))
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
context.write_canonical_name(packet, self)
}
}

View File

@ -0,0 +1,124 @@
use bytes::Bytes;
use errors::*;
use ser::packet::{DnsPacketData, DnsPacketWriteContext};
use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with};
use std::fmt;
use std::io::Cursor;
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
use super::{DnsName, DnsNameIterator, DnsLabelRef};
/// Similar to `DnsName`, but allows using compressed labels in the
/// serialized form
#[derive(Clone)]
pub struct DnsCompressedName(pub DnsName);
impl DnsCompressedName {
/// Create new name representing the DNS root (".")
pub fn new_root() -> Self {
DnsCompressedName(DnsName::new_root())
}
/// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self>
{
Ok(DnsCompressedName(DnsName::parse(context, value)?))
}
}
impl Deref for DnsCompressedName {
type Target = DnsName;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsCompressedName {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl AsRef<DnsName> for DnsCompressedName {
fn as_ref(&self) -> &DnsName {
&self.0
}
}
impl AsMut<DnsName> for DnsCompressedName {
fn as_mut(&mut self) -> &mut DnsName {
&mut self.0
}
}
impl<'a> IntoIterator for &'a DnsCompressedName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl PartialEq<DnsName> for DnsCompressedName
{
fn eq(&self, rhs: &DnsName) -> bool {
let this: &DnsName = self;
this == rhs
}
}
impl<T> PartialEq<T> for DnsCompressedName
where
T: AsRef<DnsName>
{
fn eq(&self, rhs: &T) -> bool {
let this: &DnsName = self.as_ref();
this == rhs
}
}
impl Eq for DnsCompressedName{}
impl fmt::Debug for DnsCompressedName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl fmt::Display for DnsCompressedName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl FromStr for DnsCompressedName {
type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsCompressedName::dns_parse(&DnsTextContext::new(), data))
}
}
impl DnsTextData for DnsCompressedName {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let field = next_field(data)?;
DnsCompressedName::parse(context, field)
}
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
self.0.dns_format(f)
}
}
impl DnsPacketData for DnsCompressedName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
Ok(DnsCompressedName(super::name_packet_parser::deserialize_name(data, true)?))
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
context.write_compressed_name(packet, self)
}
}

View File

@ -0,0 +1,44 @@
use smallvec::SmallVec;
#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)]
pub enum LabelOffset {
LabelStart(u8),
PacketStart(u16),
}
// the heap meta data is usually at least 2*usize big; assuming 64-bit
// platforms it should be ok to use 16 bytes in the smallvec.
#[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)]
pub enum LabelOffsets {
Uncompressed(SmallVec<[u8;16]>),
Compressed(usize, SmallVec<[LabelOffset;4]>),
}
impl LabelOffsets {
pub fn len(&self) -> u8 {
let l = match *self {
LabelOffsets::Uncompressed(ref offs) => offs.len(),
LabelOffsets::Compressed(_, ref offs) => offs.len(),
};
debug_assert!(l < 128);
l as u8
}
pub fn label_pos(&self, ndx: u8) -> usize {
debug_assert!(ndx < 127);
match *self {
LabelOffsets::Uncompressed(ref offs) => offs[ndx as usize] as usize,
LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] {
LabelOffset::LabelStart(o) => start + (o as usize),
LabelOffset::PacketStart(o) => o as usize,
}
}
}
pub fn is_compressed(&self) -> bool {
match *self {
LabelOffsets::Uncompressed(_) => false,
LabelOffsets::Compressed(_, _) => true,
}
}
}

View File

@ -3,543 +3,27 @@
use bytes::Bytes; use bytes::Bytes;
use errors::*; use errors::*;
use ser::packet::{DnsPacketData, DnsPacketWriteContext};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::fmt;
use std::io::Cursor; use std::io::Cursor;
use std::ops::{Deref, DerefMut};
pub use self::canonical_name::*;
pub use self::compressed_name::*;
pub use self::display::*; pub use self::display::*;
pub use self::label::*; pub use self::label::*;
pub use self::name_iterator::*;
pub use self::name::*;
use self::label_offsets::*;
mod canonical_name;
mod compressed_name;
mod display; mod display;
mod label; mod label;
mod label_offsets;
mod name;
mod name_iterator;
mod name_mutations; mod name_mutations;
mod name_packet_parser; mod name_packet_parser;
mod name_text_parser; mod name_text_parser;
#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)]
enum LabelOffset {
LabelStart(u8),
PacketStart(u16),
}
// the heap meta data is usually at least 2*usize big; assuming 64-bit
// platforms it should be ok to use 16 bytes in the smallvec.
#[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)]
enum LabelOffsets {
Uncompressed(SmallVec<[u8;16]>),
Compressed(usize, SmallVec<[LabelOffset;4]>),
}
impl LabelOffsets {
fn len(&self) -> u8 {
let l = match *self {
LabelOffsets::Uncompressed(ref offs) => offs.len(),
LabelOffsets::Compressed(_, ref offs) => offs.len(),
};
debug_assert!(l < 128);
l as u8
}
fn label_pos(&self, ndx: u8) -> usize {
debug_assert!(ndx < 127);
match *self {
LabelOffsets::Uncompressed(ref offs) => offs[ndx as usize] as usize,
LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] {
LabelOffset::LabelStart(o) => start + (o as usize),
LabelOffset::PacketStart(o) => o as usize,
}
}
}
fn is_compressed(&self) -> bool {
match *self {
LabelOffsets::Uncompressed(_) => false,
LabelOffsets::Compressed(_, _) => true,
}
}
}
/// A DNS name
///
/// Uses the "original" raw representation for storage (i.e. can share
/// memory with a parsed packet)
#[derive(Clone)]
pub struct DnsName {
// in uncompressed form always includes terminating null octect;
// but even in uncompressed form can include unused bytes at the
// beginning
//
// may be empty for the root name (".", no labels)
data: Bytes,
// either uncompressed or compressed offsets
label_offsets: LabelOffsets,
// length of encoded form
total_len: u8,
}
impl DnsName {
/// Create new name representing the DNS root (".")
pub fn new_root() -> Self {
DnsName{
data: Bytes::new(),
label_offsets: LabelOffsets::Uncompressed(SmallVec::new()),
total_len: 1,
}
}
/// Create new name representing the DNS root (".") and pre-allocate
/// storage
pub fn with_capacity(labels: u8, total_len: u8) -> Self {
DnsName{
data: Bytes::with_capacity(total_len as usize),
label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)),
total_len: 1,
}
}
/// Returns whether name represents the DNS root (".")
pub fn is_root(&self) -> bool {
0 == self.label_count()
}
/// How many labels the name has (without the trailing empty label,
/// at most 127)
pub fn label_count(&self) -> u8 {
self.label_offsets.len()
}
/// Iterator over the labels (in the order they are stored in memory,
/// i.e. top-level name last).
pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> {
DnsNameIterator{
name: &self,
front_label: 0,
back_label: self.label_offsets.len(),
}
}
/// Return label at index `ndx`
///
/// # Panics
///
/// panics if `ndx >= self.label_count()`.
pub fn label_ref<'a>(&'a self, ndx: u8) -> DnsLabelRef<'a> {
let pos = self.label_offsets.label_pos(ndx);
let label_len = self.data[pos];
debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize;
DnsLabelRef{label: &self.data[pos + 1..end]}
}
/// Return label at index `ndx`
///
/// # Panics
///
/// panics if `ndx >= self.label_count()`.
pub fn label(&self, ndx: u8) -> DnsLabel {
let pos = self.label_offsets.label_pos(ndx);
let label_len = self.data[pos];
debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize;
DnsLabel{label: self.data.slice(pos + 1, end) }
}
}
impl<'a> IntoIterator for &'a DnsName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl PartialEq<DnsName> for DnsName {
fn eq(&self, rhs: &DnsName) -> bool {
let a_labels = self.labels();
let b_labels = rhs.labels();
if a_labels.len() != b_labels.len() { return false; }
a_labels.zip(b_labels).all(|(a,b)| a == b)
}
}
impl Eq for DnsName{}
impl fmt::Debug for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl fmt::Display for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl DnsPacketData for DnsName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
DnsName::parse_name(data, false)
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
context.write_uncompressed_name(packet, self)
}
}
/// Similar to `DnsName`, but allows using compressed labels in the
/// serialized form
#[derive(Clone)]
pub struct DnsCompressedName(pub DnsName);
impl DnsCompressedName {
/// Create new name representing the DNS root (".")
pub fn new_root() -> Self {
DnsCompressedName(DnsName::new_root())
}
}
impl Deref for DnsCompressedName {
type Target = DnsName;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DnsCompressedName {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<'a> IntoIterator for &'a DnsCompressedName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl PartialEq<DnsCompressedName> for DnsCompressedName {
fn eq(&self, rhs: &DnsCompressedName) -> bool {
self.0 == rhs.0
}
}
impl PartialEq<DnsName> for DnsCompressedName {
fn eq(&self, rhs: &DnsName) -> bool {
&self.0 == rhs
}
}
impl PartialEq<DnsCompressedName> for DnsName {
fn eq(&self, rhs: &DnsCompressedName) -> bool {
self == &rhs.0
}
}
impl Eq for DnsCompressedName{}
impl fmt::Debug for DnsCompressedName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl fmt::Display for DnsCompressedName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(w)
}
}
impl DnsPacketData for DnsCompressedName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
Ok(DnsCompressedName(DnsName::parse_name(data, true)?))
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
context.write_compressed_name(packet, self)
}
}
/// names that should be written in canonical form for DNSSEC according
/// to https://tools.ietf.org/html/rfc4034#section-6.2
///
/// TODO: make it a newtype.
///
/// DnsCompressedName always needs to be written in canonical form for
/// DNSSEC.
pub type DnsCanonicalName = DnsName;
/// Iterator type for [`DnsName::labels`]
///
/// [`DnsName::labels`]: struct.DnsName.html#method.labels
#[derive(Clone)]
pub struct DnsNameIterator<'a> {
name: &'a DnsName,
front_label: u8,
back_label: u8,
}
impl<'a> Iterator for DnsNameIterator<'a> {
type Item = DnsLabelRef<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None }
let label = self.name.label_ref(self.front_label);
self.front_label += 1;
Some(label)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let count = self.len();
(count, Some(count))
}
fn count(self) -> usize {
self.len()
}
}
impl<'a> ExactSizeIterator for DnsNameIterator<'a> {
fn len(&self) -> usize {
(self.back_label - self.front_label) as usize
}
}
impl<'a> DoubleEndedIterator for DnsNameIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None }
self.back_label -= 1;
let label = self.name.label_ref(self.back_label);
Some(label)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests;
use ser::packet;
use super::*;
/*
fn deserialize(bytes: &'static [u8]) -> Result<DnsName> {
let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?;
{
let check_result = packet::deserialize_with(result.clone().encode(), DnsName::deserialize).unwrap();
assert_eq!(check_result, result);
}
Ok(result)
}
*/
fn de_uncompressed(bytes: &'static [u8]) -> Result<DnsName> {
let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?;
assert_eq!(bytes, result.clone().encode());
Ok(result)
}
fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) {
let name = de_uncompressed(bytes).unwrap();
assert_eq!(
name.labels().count(),
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
}
fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) {
let name = de_uncompressed(bytes).unwrap();
assert_eq!(
format!("{:?}", name),
txt
);
}
#[test]
fn parse_and_display_name() {
check_uncompressed_display(
b"\x07example\x03com\x00",
"example.com.",
2,
);
check_uncompressed_display(
b"\x07e!am.l\\\x03com\x00",
"e\\033am\\.l\\\\.com.",
2,
);
check_uncompressed_debug(
b"\x07e!am.l\\\x03com\x00",
r#""e\\033am\\.l\\\\.com.""#,
);
}
#[test]
fn parse_and_reverse_name() {
let name = de_uncompressed(b"\x03www\x07example\x03com\x00").unwrap();
assert_eq!(
format!(
"{}",
DisplayLabels{
labels: name.labels().rev(),
options: DisplayLabelsOptions{
separator: " ",
trailing: false,
},
}
),
"com example www"
);
}
#[test]
fn modifications() {
let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com.org."
);
name.pop_front();
assert_eq!(
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org.com."
);
}
fn de_compressed(bytes: &'static [u8], offset: usize) -> Result<DnsCompressedName> {
use bytes::Buf;
let mut c = Cursor::new(Bytes::from_static(bytes));
c.set_position(offset as u64);
let result = DnsPacketData::deserialize(&mut c)?;
if c.remaining() != 0 {
bail!("data remaining: {}", c.remaining())
}
Ok(result)
}
fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) {
let name = de_compressed(bytes, offset).unwrap();
assert_eq!(
name.labels().count(),
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
}
fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) {
let name = de_compressed(bytes, offset).unwrap();
assert_eq!(
format!("{:?}", name),
txt
);
}
#[test]
fn parse_invalid_compressed_name() {
de_compressed(b"\x11com\x00\x07example\xc0\x00", 5).unwrap_err();
de_compressed(b"\x10com\x00\x07example\xc0\x00", 5).unwrap_err();
}
#[test]
fn parse_and_display_compressed_name() {
check_compressed_display(
b"\x03com\x00\x07example\xc0\x00", 5,
"example.com.",
2,
);
check_compressed_display(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5,
"e\\033am\\.l\\\\.com.",
2,
);
check_compressed_debug(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5,
r#""e\\033am\\.l\\\\.com.""#,
);
check_compressed_display(
b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05", 15,
"www.example.com.",
3,
);
}
#[test]
fn modifications_compressed() {
let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com.org."
);
name.pop_front();
assert_eq!(
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org.com."
);
}
}

View File

@ -0,0 +1,178 @@
use bytes::Bytes;
use errors::*;
use ser::packet::{DnsPacketData, DnsPacketWriteContext};
use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with};
use smallvec::SmallVec;
use std::fmt;
use std::io::Cursor;
use std::str::FromStr;
use super::{LabelOffsets, DnsNameIterator, DnsLabelRef, DnsLabel, DisplayLabels};
/// A DNS name
///
/// Uses the "original" raw representation for storage (i.e. can share
/// memory with a parsed packet)
#[derive(Clone)]
pub struct DnsName {
// in uncompressed form always includes terminating null octect;
// but even in uncompressed form can include unused bytes at the
// beginning
//
// may be empty for the root name (".", no labels)
pub(super) data: Bytes,
// either uncompressed or compressed offsets
pub(super) label_offsets: LabelOffsets,
// length of encoded form
pub(super) total_len: u8,
}
impl DnsName {
/// Create new name representing the DNS root (".")
pub fn new_root() -> Self {
DnsName{
data: Bytes::new(),
label_offsets: LabelOffsets::Uncompressed(SmallVec::new()),
total_len: 1,
}
}
/// Create new name representing the DNS root (".") and pre-allocate
/// storage
pub fn with_capacity(labels: u8, total_len: u8) -> Self {
DnsName{
data: Bytes::with_capacity(total_len as usize),
label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)),
total_len: 1,
}
}
/// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> {
super::name_text_parser::parse_name(context, value)
}
/// Returns whether name represents the DNS root (".")
pub fn is_root(&self) -> bool {
0 == self.label_count()
}
/// How many labels the name has (without the trailing empty label,
/// at most 127)
pub fn label_count(&self) -> u8 {
self.label_offsets.len()
}
/// Iterator over the labels (in the order they are stored in memory,
/// i.e. top-level name last).
pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> {
DnsNameIterator{
name: &self,
front_label: 0,
back_label: self.label_offsets.len(),
}
}
/// Return label at index `ndx`
///
/// # Panics
///
/// panics if `ndx >= self.label_count()`.
pub fn label_ref<'a>(&'a self, ndx: u8) -> DnsLabelRef<'a> {
let pos = self.label_offsets.label_pos(ndx);
let label_len = self.data[pos];
debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize;
DnsLabelRef{label: &self.data[pos + 1..end]}
}
/// Return label at index `ndx`
///
/// # Panics
///
/// panics if `ndx >= self.label_count()`.
pub fn label(&self, ndx: u8) -> DnsLabel {
let pos = self.label_offsets.label_pos(ndx);
let label_len = self.data[pos];
debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize;
DnsLabel{label: self.data.slice(pos + 1, end) }
}
}
impl<'a> IntoIterator for &'a DnsName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl PartialEq<DnsName> for DnsName
{
fn eq(&self, rhs: &DnsName) -> bool {
let a_labels = self.labels();
let b_labels = rhs.labels();
if a_labels.len() != b_labels.len() { return false; }
a_labels.zip(b_labels).all(|(a,b)| a == b)
}
}
impl<T> PartialEq<T> for DnsName
where
T: AsRef<DnsName>
{
fn eq(&self, rhs: &T) -> bool {
self == rhs.as_ref()
}
}
impl Eq for DnsName{}
impl fmt::Debug for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl fmt::Display for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl FromStr for DnsName {
type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsName::dns_parse(&DnsTextContext::new(), data))
}
}
impl DnsTextData for DnsName {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let field = next_field(data)?;
DnsName::parse(context, field)
}
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl DnsPacketData for DnsName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
super::name_packet_parser::deserialize_name(data, false)
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
context.write_uncompressed_name(packet, self)
}
}

View File

@ -0,0 +1,46 @@
use super::{DnsName, DnsLabelRef};
/// Iterator type for [`DnsName::labels`]
///
/// [`DnsName::labels`]: struct.DnsName.html#method.labels
#[derive(Clone)]
pub struct DnsNameIterator<'a> {
pub(super) name: &'a DnsName,
pub(super) front_label: u8,
pub(super) back_label: u8,
}
impl<'a> Iterator for DnsNameIterator<'a> {
type Item = DnsLabelRef<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None }
let label = self.name.label_ref(self.front_label);
self.front_label += 1;
Some(label)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let count = self.len();
(count, Some(count))
}
fn count(self) -> usize {
self.len()
}
}
impl<'a> ExactSizeIterator for DnsNameIterator<'a> {
fn len(&self) -> usize {
(self.back_label - self.front_label) as usize
}
}
impl<'a> DoubleEndedIterator for DnsNameIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None }
self.back_label -= 1;
let label = self.name.label_ref(self.back_label);
Some(label)
}
}

View File

@ -1,86 +1,84 @@
use bytes::Buf; use bytes::Buf;
use super::*; use super::*;
impl DnsName { /// `data`: bytes of packet from beginning until at least the end of the name
/// `data`: bytes of packet from beginning until at least the end of the name /// `start_pos`: position of first byte of the name
/// `start_pos`: position of first byte of the name /// `uncmpr_offsets`: offsets of uncompressed labels so far
/// `uncmpr_offsets`: offsets of uncompressed labels so far /// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`)
/// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`) /// `total_len`: length of (uncompressed) label encoding so far
/// `total_len`: length of (uncompressed) label encoding so far fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offsets: SmallVec<[u8;16]>, mut total_len: usize, mut label_len: u8) -> Result<DnsName> {
fn parse_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offsets: SmallVec<[u8;16]>, mut total_len: usize, mut label_len: u8) -> Result<Self> { let mut label_offsets = uncmpr_offsets.into_iter()
let mut label_offsets = uncmpr_offsets.into_iter() .map(LabelOffset::LabelStart)
.map(LabelOffset::LabelStart) .collect::<SmallVec<_>>();
.collect::<SmallVec<_>>();
let mut pos = start_pos + total_len; let mut pos = start_pos + total_len;
'next_compressed: loop { 'next_compressed: loop {
{ {
ensure!(pos + 1 < data.len(), "not enough data for compressed label"); ensure!(pos + 1 < data.len(), "not enough data for compressed label");
let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize); let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize);
ensure!(new_pos < pos, "Compressed label offset too big: {} >= {}", new_pos, pos); ensure!(new_pos < pos, "Compressed label offset too big: {} >= {}", new_pos, pos);
pos = new_pos; pos = new_pos;
}
loop {
ensure!(pos < data.len(), "not enough data for label");
label_len = data[pos];
if 0 == label_len {
return Ok(DnsName{
data: data,
label_offsets: LabelOffsets::Compressed(start_pos, label_offsets),
total_len: total_len as u8 + 1,
})
}
if label_len & 0xc0 == 0xc0 { continue 'next_compressed; }
ensure!(label_len < 64, "Invalid label length {}", label_len);
total_len += 1 + label_len as usize;
// max len 255, but there also needs to be an empty label at the end
if total_len > 254 { bail!("DNS name too long") }
label_offsets.push(LabelOffset::PacketStart(pos as u16));
pos += 1 + label_len as usize;
}
} }
}
pub(super) fn parse_name(data: &mut Cursor<Bytes>, accept_compressed: bool) -> Result<Self> {
check_enough_data!(data, 1, "DnsName");
let start_pos = data.position() as usize;
let mut total_len : usize = 0;
let mut label_offsets = SmallVec::new();
loop { loop {
check_enough_data!(data, 1, "DnsName label len"); ensure!(pos < data.len(), "not enough data for label");
let label_len = data.get_u8() as usize; label_len = data[pos];
if 0 == label_len { if 0 == label_len {
let end_pos = data.position() as usize;
return Ok(DnsName{ return Ok(DnsName{
data: data.get_ref().slice(start_pos, end_pos), data: data,
label_offsets: LabelOffsets::Uncompressed(label_offsets), label_offsets: LabelOffsets::Compressed(start_pos, label_offsets),
total_len: total_len as u8 + 1, total_len: total_len as u8 + 1,
}) })
} }
if label_len & 0xc0 == 0xc0 {
// compressed label
if !accept_compressed { bail!("Invalid label compression {}", label_len) }
check_enough_data!(data, 1, "DnsName compressed label target");
// eat second part of compressed label
data.get_u8();
let end_pos = data.position() as usize; if label_len & 0xc0 == 0xc0 { continue 'next_compressed; }
let data = data.get_ref().slice(0, end_pos); ensure!(label_len < 64, "Invalid label length {}", label_len);
return Self::parse_name_compressed_cont(data, start_pos, label_offsets, total_len, label_len as u8); total_len += 1 + label_len as usize;
}
label_offsets.push(total_len as u8);
if label_len > 63 { bail!("Invalid label length {}", label_len) }
total_len += 1 + label_len;
// max len 255, but there also needs to be an empty label at the end // max len 255, but there also needs to be an empty label at the end
if total_len > 254 { bail!{"DNS name too long"} } if total_len > 254 { bail!("DNS name too long") }
check_enough_data!(data, (label_len), "DnsName label");
data.advance(label_len); label_offsets.push(LabelOffset::PacketStart(pos as u16));
pos += 1 + label_len as usize;
} }
} }
} }
pub fn deserialize_name(data: &mut Cursor<Bytes>, accept_compressed: bool) -> Result<DnsName> {
check_enough_data!(data, 1, "DnsName");
let start_pos = data.position() as usize;
let mut total_len : usize = 0;
let mut label_offsets = SmallVec::new();
loop {
check_enough_data!(data, 1, "DnsName label len");
let label_len = data.get_u8() as usize;
if 0 == label_len {
let end_pos = data.position() as usize;
return Ok(DnsName{
data: data.get_ref().slice(start_pos, end_pos),
label_offsets: LabelOffsets::Uncompressed(label_offsets),
total_len: total_len as u8 + 1,
})
}
if label_len & 0xc0 == 0xc0 {
// compressed label
if !accept_compressed { bail!("Invalid label compression {}", label_len) }
check_enough_data!(data, 1, "DnsName compressed label target");
// eat second part of compressed label
data.get_u8();
let end_pos = data.position() as usize;
let data = data.get_ref().slice(0, end_pos);
return deserialize_name_compressed_cont(data, start_pos, label_offsets, total_len, label_len as u8);
}
label_offsets.push(total_len as u8);
if label_len > 63 { bail!("Invalid label length {}", label_len) }
total_len += 1 + label_len;
// max len 255, but there also needs to be an empty label at the end
if total_len > 254 { bail!{"DNS name too long"} }
check_enough_data!(data, (label_len), "DnsName label");
data.advance(label_len);
}
}

View File

@ -1,111 +1,65 @@
use super::*; use errors::*;
use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, quoted, parse_with}; use ser::text::{DnsTextContext, quoted};
impl DnsName { use super::{DnsName, DnsLabelRef};
/// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> /// Parse text representation of a domain name
{ pub fn parse_name(context: &DnsTextContext, value: &str) -> Result<DnsName>
let raw = value.as_bytes(); {
let mut name = DnsName::new_root(); let raw = value.as_bytes();
if raw == b"." { let mut name = DnsName::new_root();
return Ok(name); if raw == b"." {
} else if raw == b"@" { return Ok(name);
match context.origin() { } else if raw == b"@" {
Some(o) => return Ok(o.clone()), match context.origin() {
None => bail!("@ invalid without $ORIGIN"), Some(o) => return Ok(o.clone()),
} None => bail!("@ invalid without $ORIGIN"),
} }
ensure!(!raw.is_empty(), "invalid empty name"); }
let mut label = Vec::new(); ensure!(!raw.is_empty(), "invalid empty name");
let mut pos = 0; let mut label = Vec::new();
while pos < raw.len() { let mut pos = 0;
if raw[pos] == b'.' { while pos < raw.len() {
ensure!(!label.is_empty(), "empty label in name: {:?}", value); if raw[pos] == b'.' {
name.push_back(DnsLabelRef::new(&label)?)?; ensure!(!label.is_empty(), "empty label in name: {:?}", value);
label.clear();
} else if raw[pos] == b'\\' {
ensure!(pos + 1 < raw.len(), "unexpected end of name after backslash: {:?}", value);
if raw[pos+1] >= b'0' && raw[pos+1] <= b'9' {
// \ddd escape
ensure!(pos + 3 < raw.len(), "unexpected end of name after backslash with digit: {:?}", value);
ensure!(raw[pos+2] >= b'0' && raw[pos+2] <= b'9' && raw[pos+3] >= b'0' && raw[pos+3] <= b'9', "expected three digits after backslash in name: {:?}", name);
let d1 = (raw[pos+1] - b'0') as u32;
let d2 = (raw[pos+2] - b'0') as u32;
let d3 = (raw[pos+3] - b'0') as u32;
let v = d1 * 100 + d2 * 10 + d3;
ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name);
label.push(v as u8);
} else {
ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name);
label.push(raw[pos+1]);
}
} else {
ensure!(!quoted::is_ascii_whitespace(raw[pos]), "whitespace must be encoded as \\{:03} in: {:?}", raw[pos], name);
label.push(raw[pos]);
}
pos += 1;
}
if !label.is_empty() {
// no trailing dot, relative name
// push last label
name.push_back(DnsLabelRef::new(&label)?)?; name.push_back(DnsLabelRef::new(&label)?)?;
label.clear();
match context.origin() { } else if raw[pos] == b'\\' {
Some(o) => { ensure!(pos + 1 < raw.len(), "unexpected end of name after backslash: {:?}", value);
for l in o { name.push_back(l)?; } if raw[pos+1] >= b'0' && raw[pos+1] <= b'9' {
}, // \ddd escape
None => bail!("missing trailing dot without $ORIGIN"), ensure!(pos + 3 < raw.len(), "unexpected end of name after backslash with digit: {:?}", value);
ensure!(raw[pos+2] >= b'0' && raw[pos+2] <= b'9' && raw[pos+3] >= b'0' && raw[pos+3] <= b'9', "expected three digits after backslash in name: {:?}", name);
let d1 = (raw[pos+1] - b'0') as u32;
let d2 = (raw[pos+2] - b'0') as u32;
let d3 = (raw[pos+3] - b'0') as u32;
let v = d1 * 100 + d2 * 10 + d3;
ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name);
label.push(v as u8);
} else {
ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name);
label.push(raw[pos+1]);
} }
} else {
ensure!(!quoted::is_ascii_whitespace(raw[pos]), "whitespace must be encoded as \\{:03} in: {:?}", raw[pos], name);
label.push(raw[pos]);
} }
pos += 1;
Ok(name)
}
}
impl DnsTextData for DnsName {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let field = next_field(data)?;
DnsName::parse(context, field)
}
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
write!(f, "{}", self)
}
}
impl ::std::str::FromStr for DnsName {
type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsName::dns_parse(&DnsTextContext::new(), data))
}
}
impl DnsCompressedName {
/// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self>
{
Ok(DnsCompressedName(DnsName::parse(context, value)?))
}
}
impl DnsTextData for DnsCompressedName {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let field = next_field(data)?;
DnsCompressedName::parse(context, field)
}
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
self.0.dns_format(f)
}
}
impl ::std::str::FromStr for DnsCompressedName {
type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsCompressedName::dns_parse(&DnsTextContext::new(), data))
} }
if !label.is_empty() {
// no trailing dot, relative name
// push last label
name.push_back(DnsLabelRef::new(&label)?)?;
match context.origin() {
Some(o) => {
for l in o { name.push_back(l)?; }
},
None => bail!("missing trailing dot without $ORIGIN"),
}
}
Ok(name)
} }

View File

@ -0,0 +1,219 @@
use bytes::Bytes;
use ser::packet;
use ser::packet::DnsPacketData;
use std::io::Cursor;
use errors::*;
use super::{DnsName, DnsCompressedName, DnsLabelRef, DisplayLabels, DisplayLabelsOptions};
/*
fn deserialize(bytes: &'static [u8]) -> Result<DnsName> {
let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?;
{
let check_result = packet::deserialize_with(result.clone().encode(), DnsName::deserialize).unwrap();
assert_eq!(check_result, result);
}
Ok(result)
}
*/
fn de_uncompressed(bytes: &'static [u8]) -> Result<DnsName> {
let result = packet::deserialize_with(Bytes::from_static(bytes), DnsName::deserialize)?;
assert_eq!(bytes, result.clone().encode());
Ok(result)
}
fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) {
let name = de_uncompressed(bytes).unwrap();
assert_eq!(
name.labels().count(),
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
}
fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) {
let name = de_uncompressed(bytes).unwrap();
assert_eq!(
format!("{:?}", name),
txt
);
}
#[test]
fn parse_and_display_name() {
check_uncompressed_display(
b"\x07example\x03com\x00",
"example.com.",
2,
);
check_uncompressed_display(
b"\x07e!am.l\\\x03com\x00",
"e\\033am\\.l\\\\.com.",
2,
);
check_uncompressed_debug(
b"\x07e!am.l\\\x03com\x00",
r#""e\\033am\\.l\\\\.com.""#,
);
}
#[test]
fn parse_and_reverse_name() {
let name = de_uncompressed(b"\x03www\x07example\x03com\x00").unwrap();
assert_eq!(
format!(
"{}",
DisplayLabels{
labels: name.labels().rev(),
options: DisplayLabelsOptions{
separator: " ",
trailing: false,
},
}
),
"com example www"
);
}
#[test]
fn modifications() {
let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com.org."
);
name.pop_front();
assert_eq!(
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org.com."
);
}
fn de_compressed(bytes: &'static [u8], offset: usize) -> Result<DnsCompressedName> {
use bytes::Buf;
let mut c = Cursor::new(Bytes::from_static(bytes));
c.set_position(offset as u64);
let result = DnsPacketData::deserialize(&mut c)?;
if c.remaining() != 0 {
bail!("data remaining: {}", c.remaining())
}
Ok(result)
}
fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) {
let name = de_compressed(bytes, offset).unwrap();
assert_eq!(
name.labels().count(),
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
}
fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) {
let name = de_compressed(bytes, offset).unwrap();
assert_eq!(
format!("{:?}", name),
txt
);
}
#[test]
fn parse_invalid_compressed_name() {
de_compressed(b"\x11com\x00\x07example\xc0\x00", 5).unwrap_err();
de_compressed(b"\x10com\x00\x07example\xc0\x00", 5).unwrap_err();
}
#[test]
fn parse_and_display_compressed_name() {
check_compressed_display(
b"\x03com\x00\x07example\xc0\x00", 5,
"example.com.",
2,
);
check_compressed_display(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5,
"e\\033am\\.l\\\\.com.",
2,
);
check_compressed_debug(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5,
r#""e\\033am\\.l\\\\.com.""#,
);
check_compressed_display(
b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05", 15,
"www.example.com.",
3,
);
}
#[test]
fn modifications_compressed() {
let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"www.example.com.org."
);
name.pop_front();
assert_eq!(
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!(
format!("{}", name),
"mx.example.com.org.com."
);
}

View File

@ -474,7 +474,6 @@ impl Type {
/// parses generic names of the form "TYPE..." /// parses generic names of the form "TYPE..."
pub fn from_generic_name(name: &str) -> Option<Self> { pub fn from_generic_name(name: &str) -> Option<Self> {
use std::ascii::AsciiExt;
if name.len() > 4 && name.as_bytes()[0..4].eq_ignore_ascii_case(b"TYPE") { if name.len() > 4 && name.as_bytes()[0..4].eq_ignore_ascii_case(b"TYPE") {
name[4..].parse::<u16>().ok().map(Type) name[4..].parse::<u16>().ok().map(Type)
} else { } else {

View File

@ -1,6 +1,5 @@
use bytes::Bytes; use bytes::Bytes;
use std::any::TypeId; use std::any::TypeId;
use std::ascii::AsciiExt;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Cursor; use std::io::Cursor;
use std::marker::PhantomData; use std::marker::PhantomData;

View File

@ -68,6 +68,24 @@ fn write_name(packet: &mut Vec<u8>, name: &DnsName) {
packet.put_u8(0); packet.put_u8(0);
} }
fn write_canonical_label(packet: &mut Vec<u8>, label: DnsLabelRef) {
let l = label.len();
debug_assert!(l < 64);
packet.reserve(l as usize + 1);
packet.put_u8(l);
for c in label.as_raw() {
packet.put_u8(c.to_ascii_lowercase());
}
}
fn write_canonical_name(packet: &mut Vec<u8>, name: &DnsName) {
for label in name {
write_canonical_label(packet, label);
}
packet.reserve(1);
packet.put_u8(0);
}
fn write_label_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, label: DnsLabelRef, next_entry: usize) { fn write_label_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, label: DnsLabelRef, next_entry: usize) {
labels.push(LabelEntry { labels.push(LabelEntry {
pos: packet.len(), pos: packet.len(),
@ -76,9 +94,23 @@ fn write_label_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, labe
write_label(packet, label); write_label(packet, label);
} }
#[derive(Clone, Debug)]
enum LabelWriteMethod {
Uncompressed,
Compressed(Vec<LabelEntry>),
Canonical, // DNSSEC, uncompressed + ASCII lower-case
}
impl Default for LabelWriteMethod {
fn default() -> Self {
LabelWriteMethod::Uncompressed
}
}
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct DnsPacketWriteContext { pub struct DnsPacketWriteContext {
labels: Option<Vec<LabelEntry>>, labels: LabelWriteMethod,
} }
impl DnsPacketWriteContext { impl DnsPacketWriteContext {
@ -86,11 +118,26 @@ impl DnsPacketWriteContext {
Default::default() Default::default()
} }
/// Enables writing compressed names
///
/// Only `DnsCompressedName` uses compression, `DnsName` and
/// `DnsCanonicalName` are never compressed.
pub fn enable_compression(&mut self) { pub fn enable_compression(&mut self) {
self.labels = Some(Vec::new()); self.labels = LabelWriteMethod::Compressed(Vec::new());
} }
pub fn write_uncompressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> { /// Enables writing canonical names
///
/// Disables compression, and converts `DnsCompressedName` and
/// `DnsCanonicalName` to ASCII lowercase.
///
/// `DnsName` is never compressed, but also never converted to
/// lowercase.
pub fn enable_canonical(&mut self) {
self.labels = LabelWriteMethod::Canonical;
}
pub(crate) fn write_uncompressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> {
// for now we don't remember labels of these names. // for now we don't remember labels of these names.
// //
// if we did: would we want to check whether a suffix is already // if we did: would we want to check whether a suffix is already
@ -100,19 +147,36 @@ impl DnsPacketWriteContext {
Ok(()) Ok(())
} }
pub fn write_compressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsCompressedName) -> Result<()> { pub(crate) fn write_canonical_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> {
match self.labels {
LabelWriteMethod::Uncompressed | LabelWriteMethod::Compressed(_) => {
// uncompressed
write_name(packet, name);
},
LabelWriteMethod::Canonical => {
write_canonical_name(packet, name);
},
}
return Ok(())
}
pub(crate) fn write_compressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsCompressedName) -> Result<()> {
// for DNSSEC we need to write it canonical
if name.is_root() { if name.is_root() {
write_name(packet, name); write_name(packet, name);
return Ok(()); return Ok(());
} }
let labels = match self.labels { let labels = match self.labels {
Some(ref mut labels) => labels, LabelWriteMethod::Uncompressed => {
None => {
// compression disabled
write_name(packet, name); write_name(packet, name);
return Ok(()); return Ok(());
} },
LabelWriteMethod::Compressed(ref mut labels) => labels,
LabelWriteMethod::Canonical => {
write_canonical_name(packet, name);
return Ok(())
},
}; };
let mut best_match = None; let mut best_match = None;