You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

585 lines
14 KiB

4 years ago
#![deny(missing_docs)]
//! Various structs to represent DNS names and labels
use bytes::{Bytes,Buf,BytesMut};
use errors::*;
use std::fmt;
use std::io::Cursor;
use packet_data::{DnsPacketData,deserialize};
/// Validate the length of a single DNS label.
///
/// A label may contain arbitrary binary data, but must be between 1 and 63
/// bytes long (`0 < length < 64`).
#[inline]
fn check_label(label: &[u8]) -> Result<()> {
    if label.is_empty() {
        bail!("label must not be empty")
    }
    if label.len() > 63 {
        bail!("label must not be longer than 63 bytes")
    }
    Ok(())
}
/// An owned DNS label: any binary string with `0 < length < 64`, kept in
/// `Bytes` storage.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DnsLabel {
    // invariant: 0 < label.len() < 64
    label: Bytes,
}

impl DnsLabel {
    /// Wrap existing `Bytes` storage as a label.
    ///
    /// Fails when the length doesn't match the requirement `0 < length < 64`.
    pub fn new(label: Bytes) -> Result<Self> {
        check_label(&label)?;
        Ok(Self { label })
    }

    /// Borrow this label as a [`DnsLabelRef`] pointing at the same bytes.
    pub fn as_ref<'a>(&'a self) -> DnsLabelRef<'a> {
        let raw: &'a [u8] = &self.label[..];
        DnsLabelRef { label: raw }
    }

    /// Access the underlying `Bytes` storage handle.
    pub fn as_bytes(&self) -> &Bytes {
        &self.label
    }

    /// Access the label content as a plain byte slice.
    pub fn as_raw(&self) -> &[u8] {
        &self.label[..]
    }
}
impl<'a> From<DnsLabelRef<'a>> for DnsLabel {
    /// Copy the borrowed label bytes into owned `Bytes` storage.
    fn from(label_ref: DnsLabelRef<'a>) -> Self {
        let owned = Bytes::from(label_ref.as_raw());
        DnsLabel { label: owned }
    }
}
impl fmt::Debug for DnsLabel {
    /// Same output as `Debug` of the borrowed [`DnsLabelRef`].
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        let borrowed = self.as_ref();
        fmt::Debug::fmt(&borrowed, w)
    }
}

impl fmt::Display for DnsLabel {
    /// Same output as `Display` of the borrowed [`DnsLabelRef`].
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        let borrowed = self.as_ref();
        fmt::Display::fmt(&borrowed, w)
    }
}
/// A DNS label (any binary string with `0 < length < 64`)
///
/// Storage is provided through lifetime.
#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash)]
pub struct DnsLabelRef<'a> {
label: &'a [u8], // 0 < len < 64
}
impl<'a> DnsLabelRef<'a> {
/// Create new label from existing storage
///
/// Fails when the length doesn't match the requirement `0 < length < 64`.
pub fn new(label: &'a [u8]) -> Result<Self> {
check_label(label)?;
Ok(DnsLabelRef{label})
}
/// Access as raw bytes
pub fn as_raw(&self) -> &'a [u8] {
self.label
}
}
impl<'a> From<&'a DnsLabel> for DnsLabelRef<'a> {
fn from(label: &'a DnsLabel) -> Self {
label.as_ref()
}
}
impl<'a> fmt::Debug for DnsLabelRef<'a> {
    /// Renders the `Display` form into a `String`, then writes that string's
    /// `Debug` form (quoted and escaped like any `str`).
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self.to_string();
        fmt::Debug::fmt(&rendered, w)
    }
}
impl<'a> fmt::Display for DnsLabelRef<'a> {
    /// Render the label in escaped text form, byte by byte:
    ///
    /// * `.` is written as `\.` and `\` as `\\`;
    /// * every other byte in `0x22..=0x7d` is written unchanged;
    /// * all remaining bytes (control bytes, space, `!`, `~`, DEL and all
    ///   non-ASCII bytes) are written as `\` followed by three **octal**
    ///   digits, e.g. `!` (0x21) becomes `\041`.
    ///
    /// NOTE(review): RFC 1035 section 5.1 defines the `\DDD` zone-file escape
    /// as *decimal* (`!` would be `\033`). The octal form is pinned by this
    /// file's tests, so switching to decimal must update those tests at the
    /// same time. Confirm which format is intended.
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        use std::fmt::Write;

        for &c in self.label {
            match c {
                b'.' => w.write_str(r"\.")?,
                b'\\' => w.write_str(r"\\")?,
                // Printable ASCII: a byte < 0x80 maps to the same `char`, so it
                // can be written directly without an `unsafe` UTF-8 cast.
                _ if c > 0x21 && c < 0x7e => w.write_char(c as char)?,
                _ => write!(w, r"\{:03o}", c)?,
            }
        }
        Ok(())
    }
}
/// Options controlling how a sequence of DNS labels is rendered as text.
///
/// The default joins labels with `"."` and also writes a trailing `"."`.
///
/// The `Debug` formatters just format as `Display` to a String and then
/// format that string as `Debug`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct DisplayLabelsOptions {
    /// Separator written between two consecutive (escaped) labels.
    pub separator: &'static str,
    /// Whether the separator is also written after the last label.
    ///
    /// Without a trailing separator the root zone is represented as the
    /// empty string!
    pub trailing: bool,
}

impl Default for DisplayLabelsOptions {
    fn default() -> Self {
        Self {
            trailing: true,
            separator: ".",
        }
    }
}
/// Renders any collection of labels (`DnsLabelRef`) using the given
/// `options`.
///
/// `labels` can be any cloneable `(Into)Iterator` yielding `DnsLabelRef`
/// items; it is cloned every time the value is formatted. For example:
///
/// * `&DnsPlainName`
/// * `&DnsName`
/// * `DnsNameIterator` (also reversed via `.rev()`)
/// * `DnsPlainNameIterator`
#[derive(Clone)]
pub struct DisplayLabels<'a, I>
where
    I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{
    /// Labels to render, in the order they should appear
    pub labels: I,
    /// Separator and trailing-separator settings
    pub options: DisplayLabelsOptions,
}
impl<'a, I> fmt::Debug for DisplayLabels<'a, I>
where
    I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{
    /// Writes the `Debug` form (quoted and escaped) of the `Display` rendering.
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self.to_string();
        fmt::Debug::fmt(&rendered, w)
    }
}
impl<'a, I> fmt::Display for DisplayLabels<'a, I>
where
    I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{
    /// Writes each label (escaped by `DnsLabelRef`'s `Display`) with
    /// `options.separator` between consecutive labels. If `options.trailing`
    /// is set, one more separator follows the last label.
    ///
    /// An empty collection (the root name) renders as just the separator when
    /// `trailing` is set, and as the empty string otherwise.
    fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
        // `fmt` only has `&self`, so iterate over a clone of the labels.
        let mut labels = self.labels.clone().into_iter();
        if let Some(first) = labels.next() {
            fmt::Display::fmt(&first, w)?;
            for label in labels {
                w.write_str(self.options.separator)?;
                fmt::Display::fmt(&label, w)?;
            }
        }
        if self.options.trailing {
            w.write_str(self.options.separator)?;
        }
        Ok(())
    }
}
/// A DNS name
///
/// Stored in its "original" uncompressed wire representation, so the
/// storage can share memory with a parsed packet.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DnsPlainName {
    // length-prefixed labels followed by a zero byte; at most 255 bytes
    data: Bytes,
    // at most 127 labels; the final (empty) label is not counted
    label_count: u8,
}

impl DnsPlainName {
    /// Parse a name from raw bytes
    pub fn new(raw: Bytes) -> Result<Self> {
        deserialize(raw)
    }

    /// Number of labels, not counting the trailing empty label (at most 127)
    pub fn label_count(&self) -> u8 {
        self.label_count
    }

    /// Iterate over the labels in storage order (top-level label last).
    pub fn labels<'a>(&'a self) -> DnsPlainNameIterator<'a> {
        let name_data: &'a [u8] = &self.data[..];
        DnsPlainNameIterator {
            name_data,
            label_count: self.label_count,
            labels_done: 0,
            position: 0,
        }
    }
}
impl<'a> IntoIterator for &'a DnsPlainName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsPlainNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl fmt::Debug for DnsPlainName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl fmt::Display for DnsPlainName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl DnsPacketData for DnsPlainName {
/// Parse an uncompressed DNS name starting at the current cursor position.
///
/// Reads length-prefixed labels up to and including the terminating
/// zero-length label. The returned name's storage is a `slice` of the
/// cursor's underlying `Bytes`, covering exactly the bytes consumed. The
/// cursor is left just past the terminating zero byte.
///
/// Fails in these cases:
/// * a label length byte is greater than 63, which also rejects
///   compression pointers (0xC0 and above);
/// * the encoded name, including all length bytes, would exceed 255 bytes;
/// * presumably, the input ends before the name is complete (via
///   `check_enough_data!`; confirm against the macro).
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
// NOTE(review): the first loop iteration repeats this check; only the error
// context string differs.
check_enough_data!(data, 1, "DnsPlainName");
let start_pos = data.position() as usize;
// encoded size consumed so far, including length bytes (limit: 255)
let mut total_len : usize = 0;
let mut label_count: u8 = 0;
loop {
check_enough_data!(data, 1, "DnsPlainName label len");
let label_len = data.get_u8() as usize;
total_len += 1;
if total_len > 255 { bail!{"DNS name too long"} }
// the zero-length label terminates the name and is not counted
if 0 == label_len { break; }
label_count += 1; // can't overflow: total_len <= 255, and each label so far was not empty, i.e. used at least two bytes.
// values above 63 are not plain label lengths
if label_len > 63 { bail!("Invalid label length {}", label_len) }
total_len += label_len;
if total_len > 255 { bail!{"DNS name too long"} }
check_enough_data!(data, (label_len), "DnsPlainName label");
// skip the label content; it stays in place in the input buffer
data.advance(label_len);
}
let end_pos = data.position() as usize;
Ok(DnsPlainName{
// references the input buffer instead of copying the name
data: data.get_ref().slice(start_pos, end_pos),
label_count: label_count,
})
}
}
/// Iterator type for [`DnsPlainName::labels`]
///
/// [`DnsPlainName::labels`]: struct.DnsPlainName.html#method.labels
#[derive(Clone)]
pub struct DnsPlainNameIterator<'a> {
    // wire-format name being walked
    name_data: &'a [u8],
    // offset of the length byte of the next label to yield
    position: u8,
    // number of labels already yielded
    labels_done: u8,
    // total number of (non-empty) labels in the name
    label_count: u8,
}

impl<'a> Iterator for DnsPlainNameIterator<'a> {
    type Item = DnsLabelRef<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.labels_done < self.label_count {
            self.labels_done += 1;
            let start = self.position + 1;
            let end = start + self.name_data[self.position as usize];
            let raw = &self.name_data[start as usize..end as usize];
            self.position = end;
            Some(DnsLabelRef { label: raw })
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }

    fn count(self) -> usize {
        ExactSizeIterator::len(&self)
    }
}

impl<'a> ExactSizeIterator for DnsPlainNameIterator<'a> {
    /// Number of labels not yet yielded.
    fn len(&self) -> usize {
        let remaining = self.label_count - self.labels_done;
        remaining as usize
    }
}
/// A DNS name
///
/// Stored in a modified representation, so unlike `DnsPlainName` it can
/// NOT share memory with a parsed packet. In exchange its labels can be
/// iterated from both ends (`DoubleEndedIterator`).
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DnsName {
    // Like the wire encoding, except that each "length" octet holds the XOR
    // of the lengths of the labels on either side of it. The first octet is
    // therefore the first label's length, and the terminating octet is the
    // last label's length. At most 255 bytes.
    data: Bytes,
    // at most 127 labels; the final (empty) label is not counted
    label_count: u8,
}

impl DnsName {
    /// Parse a name from raw bytes
    pub fn new(raw: Bytes) -> Result<Self> {
        deserialize(raw)
    }

    /// Number of labels, not counting the trailing empty label (at most 127)
    pub fn label_count(&self) -> u8 {
        self.label_count
    }

    /// Iterate over the labels in storage order (top-level label last).
    /// The iterator can also be consumed from the back.
    pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> {
        // `data` always ends with the terminating octet, so it is never empty
        // and (being at most 255 bytes) its last index fits into a `u8`.
        let last = self.data.len() as u8 - 1;
        DnsNameIterator {
            name_data: &self.data[..],
            front_position: 0,
            front_prev_label_len: 0,
            back_position: last,
            back_next_label_len: 0,
            labels_done: 0,
            label_count: self.label_count,
        }
    }
}
impl From<DnsPlainName> for DnsName {
fn from(plain: DnsPlainName) -> Self {
let mut data = plain.data.try_mut().unwrap_or_else(BytesMut::from);
let mut pos : u8 = 0;
let mut prev_len : u8 = 0;
loop {
let label_len = data[pos as usize];
data[pos as usize] ^= prev_len;
if 0 == label_len { break; }
pos += label_len + 1;
prev_len = label_len;
}
DnsName{
data: data.freeze(),
label_count: plain.label_count,
}
}
}
impl From<DnsName> for DnsPlainName {
fn from(name: DnsName) -> DnsPlainName {
let mut data = name.data.try_mut().unwrap_or_else(BytesMut::from);
let mut pos : u8 = 0;
let mut prev_len : u8 = 0;
loop {
let label_len = data[pos as usize] ^ prev_len;
data[pos as usize] = label_len;
if 0 == label_len { break; }
pos += label_len + 1;
prev_len = label_len;
}
DnsPlainName{
data: data.freeze(),
label_count: name.label_count,
}
}
}
impl<'a> IntoIterator for &'a DnsName {
type Item = DnsLabelRef<'a>;
type IntoIter = DnsNameIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.labels()
}
}
impl fmt::Debug for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl fmt::Display for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{
labels: self,
options: Default::default(),
}.fmt(w)
}
}
impl DnsPacketData for DnsName {
    /// Parse as a [`DnsPlainName`], then convert to the XOR-encoded form.
    fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
        DnsPlainName::deserialize(data).map(DnsName::from)
    }
}
/// Iterator type for [`DnsName::labels`]
///
/// Labels can be taken from both ends. Front and back share `labels_done`,
/// so together they yield each label exactly once.
///
/// [`DnsName::labels`]: struct.DnsName.html#method.labels
#[derive(Clone)]
pub struct DnsNameIterator<'a> {
// XOR-encoded storage of the `DnsName`
name_data: &'a [u8],
// offset of the length octet in front of the next label from the front
front_position: u8,
// length of the label last yielded from the front (0 before the first)
front_prev_label_len: u8,
// offset of the length octet after the next label from the back
// (initially the terminating octet)
back_position: u8,
// length of the label last yielded from the back (0 before the first)
back_next_label_len: u8,
// labels yielded so far from either end
labels_done: u8,
label_count: u8,
}
impl<'a> Iterator for DnsNameIterator<'a> {
type Item = DnsLabelRef<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.labels_done >= self.label_count { return None }
self.labels_done += 1;
// octet = len(this label) ^ len(previous label), so XOR with the known
// previous length recovers this label's length
let label_len = self.name_data[self.front_position as usize] ^ self.front_prev_label_len;
self.front_prev_label_len = label_len;
let end = self.front_position+1+label_len;
let label = DnsLabelRef{label: &self.name_data[(self.front_position+1) as usize..end as usize]};
// `end` is the offset of the following label's length octet
self.front_position = end;
Some(label)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let count = self.len();
(count, Some(count))
}
fn count(self) -> usize {
self.len()
}
}
impl<'a> ExactSizeIterator for DnsNameIterator<'a> {
fn len(&self) -> usize {
(self.label_count - self.labels_done) as usize
}
}
impl<'a> DoubleEndedIterator for DnsNameIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.labels_done >= self.label_count { return None }
self.labels_done += 1;
// octet after this label = len(following label) ^ len(this label)
let label_len = self.name_data[self.back_position as usize] ^ self.back_next_label_len;
self.back_next_label_len = label_len;
let end = self.back_position;
// step back over the label content and its own length octet; that octet
// is the "following" octet for the previous label
self.back_position -= 1 + label_len;
let label = DnsLabelRef{label: &self.name_data[(self.back_position+1) as usize..end as usize]};
Some(label)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared parse/display checks for both name representations.
    fn do_parse_and_display_name<T>()
    where
        T: fmt::Display + fmt::Debug + DnsPacketData,
        for<'a> &'a T: IntoIterator<Item = DnsLabelRef<'a>>,
    {
        {
            let name = deserialize::<T>(Bytes::from_static(b"\x07example\x03com\x00")).unwrap();
            assert_eq!(format!("{}", name), "example.com.");
            assert_eq!(name.into_iter().count(), 2);
        }
        assert_eq!(
            format!(
                "{}",
                deserialize::<T>(Bytes::from_static(b"\x07e!am.l\\\x03com\x00")).unwrap()
            ),
            "e\\041am\\.l\\\\.com."
        );
        assert_eq!(
            format!(
                "{:?}",
                deserialize::<T>(Bytes::from_static(b"\x07e!am.l\\\x03com\x00")).unwrap()
            ),
            r#""e\\041am\\.l\\\\.com.""#
        );
    }

    #[test]
    fn parse_and_display_plain_name() {
        do_parse_and_display_name::<DnsPlainName>();
    }

    #[test]
    fn parse_and_display_name() {
        do_parse_and_display_name::<DnsName>();
    }

    #[test]
    fn parse_and_reverse_name() {
        let name = deserialize::<DnsName>(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap();
        assert_eq!(
            format!(
                "{}",
                DisplayLabels {
                    labels: name.labels().rev(),
                    options: DisplayLabelsOptions {
                        separator: " ",
                        trailing: false,
                    },
                }
            ),
            "com example www"
        );
    }

    #[test]
    fn parse_and_convert_names() {
        let name = deserialize::<DnsPlainName>(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap();
        assert_eq!(DnsPlainName::from(DnsName::from(name.clone())), name);
    }

    // The root name has no labels; `DnsName::labels` then starts with
    // `back_position == 0`.
    #[test]
    fn parse_and_display_root_name() {
        let plain = deserialize::<DnsPlainName>(Bytes::from_static(b"\x00")).unwrap();
        assert_eq!(plain.label_count(), 0);
        assert_eq!(plain.labels().count(), 0);
        assert_eq!(format!("{}", plain), ".");

        let name = DnsName::from(plain);
        assert_eq!(name.label_count(), 0);
        assert_eq!(name.labels().rev().count(), 0);
        assert_eq!(format!("{}", name), ".");
        assert_eq!(
            format!(
                "{}",
                DisplayLabels {
                    labels: &name,
                    options: DisplayLabelsOptions {
                        separator: ".",
                        trailing: false,
                    },
                }
            ),
            ""
        );
    }

    #[test]
    fn plain_name_labels_in_storage_order() {
        let name = deserialize::<DnsPlainName>(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap();
        let labels: Vec<&[u8]> = name.labels().map(|l| l.as_raw()).collect();
        assert_eq!(labels, vec![&b"www"[..], &b"example"[..], &b"com"[..]]);
        assert_eq!(name.labels().len(), 3);
    }

    // Front and back iteration share one counter and must not overlap.
    #[test]
    fn iterate_name_from_both_ends() {
        let name = deserialize::<DnsName>(Bytes::from_static(b"\x03www\x07example\x03com\x00")).unwrap();
        let mut labels = name.labels();
        assert_eq!(labels.len(), 3);
        assert_eq!(labels.next().map(|l| l.as_raw()), Some(&b"www"[..]));
        assert_eq!(labels.next_back().map(|l| l.as_raw()), Some(&b"com"[..]));
        assert_eq!(labels.len(), 1);
        assert_eq!(labels.next().map(|l| l.as_raw()), Some(&b"example"[..]));
        assert_eq!(labels.len(), 0);
        assert!(labels.next().is_none());
        assert!(labels.next_back().is_none());
    }

    #[test]
    fn label_length_limits() {
        assert!(DnsLabelRef::new(b"").is_err());
        assert!(DnsLabelRef::new(&[b'a'; 63]).is_ok());
        assert!(DnsLabelRef::new(&[b'a'; 64]).is_err());
        assert!(DnsLabel::new(Bytes::from_static(b"abc")).is_ok());
        assert!(DnsLabel::new(Bytes::new()).is_err());
        // bytes below 8 have the same octal and decimal escape
        assert_eq!(format!("{}", DnsLabelRef::new(b"a\x00b").unwrap()), "a\\000b");
    }

    #[test]
    fn name_length_limit() {
        // 127 one-byte labels + terminator = exactly 255 bytes: accepted
        let mut max = Vec::new();
        for _ in 0..127 {
            max.extend_from_slice(b"\x01a");
        }
        max.push(0);
        assert_eq!(max.len(), 255);
        let name = deserialize::<DnsName>(Bytes::from(max)).unwrap();
        assert_eq!(name.label_count(), 127);
        assert_eq!(name.labels().rev().count(), 127);

        // one more label exceeds 255 bytes: rejected
        let mut too_long = Vec::new();
        for _ in 0..128 {
            too_long.extend_from_slice(b"\x01a");
        }
        too_long.push(0);
        assert!(deserialize::<DnsPlainName>(Bytes::from(too_long)).is_err());
    }

    #[test]
    fn reject_invalid_names() {
        // label length byte above 63
        assert!(deserialize::<DnsPlainName>(Bytes::from_static(b"\x40")).is_err());
        // compression pointer
        assert!(deserialize::<DnsPlainName>(Bytes::from_static(b"\xc0\x0c")).is_err());
        // missing terminating empty label
        assert!(deserialize::<DnsPlainName>(Bytes::from_static(b"\x03com")).is_err());
    }
}