cargo fmt

This commit is contained in:
Stefan Bühler 2020-03-07 16:57:47 +01:00
parent 460a8d1755
commit 9d5314d127
52 changed files with 1474 additions and 961 deletions

View File

@ -1,8 +1,8 @@
use dnsbox_base::common_types::{Type, DnsName}; use dnsbox_base::common_types::{DnsName, Type};
use failure::Error;
use dnsbox_base::ser::RRData; use dnsbox_base::ser::RRData;
use futures::{Future, Poll, Async}; use failure::Error;
use futures::unsync::oneshot; use futures::unsync::oneshot;
use futures::{Async, Future, Poll};
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::mem::replace; use std::mem::replace;
@ -82,13 +82,13 @@ impl InnerEntry {
replace(self, InnerEntry::Refreshing(e)); replace(self, InnerEntry::Refreshing(e));
(false, res) (false, res)
} }
} },
InnerEntry::Pending(mut queue) => { InnerEntry::Pending(mut queue) => {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
queue.push(tx); queue.push(tx);
replace(self, InnerEntry::Pending(queue)); replace(self, InnerEntry::Pending(queue));
(false, CacheResult(InnerResult::Waiting(rx))) (false, CacheResult(InnerResult::Waiting(rx)))
} },
} }
} }
} }
@ -168,7 +168,7 @@ impl Future for CacheResult {
// keep waiting // keep waiting
replace(&mut self.0, InnerResult::Waiting(rc)); replace(&mut self.0, InnerResult::Waiting(rc));
Ok(Async::NotReady) Ok(Async::NotReady)
} },
} }
}, },
InnerResult::Finished => Ok(Async::NotReady), InnerResult::Finished => Ok(Async::NotReady),

View File

@ -1,6 +1,6 @@
use dnsbox_base::common_types::{types}; use dnsbox_base::common_types::types;
use dnsbox_base::common_types::{DnsName, DnsCompressedName}; use dnsbox_base::common_types::{DnsCompressedName, DnsName};
use dnsbox_base::records::{NS, A, AAAA}; use dnsbox_base::records::{A, AAAA, NS};
use dnsbox_base::ser::RRData; use dnsbox_base::ser::RRData;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
@ -26,13 +26,21 @@ pub fn load_hints(cache: &super::Cache) {
for &(name, ipv4, ipv6) in &DATA { for &(name, ipv4, ipv6) in &DATA {
let name = name.parse::<DnsName>().expect("invalid root hint name"); let name = name.parse::<DnsName>().expect("invalid root hint name");
let ipv4 = ipv4.parse::<Ipv4Addr>().expect("invalid root hint ipv4 addr"); let ipv4 = ipv4
let ipv6 = ipv6.parse::<Ipv6Addr>().expect("invalid root hint ipv6 addr"); .parse::<Ipv4Addr>()
.expect("invalid root hint ipv4 addr");
let ipv6 = ipv6
.parse::<Ipv6Addr>()
.expect("invalid root hint ipv6 addr");
root_ns_set.push(Box::new(NS { root_ns_set.push(Box::new(NS {
nsdname: DnsCompressedName(name.clone()), nsdname: DnsCompressedName(name.clone()),
})); }));
cache.insert_hint(name.clone(), types::A, vec![Box::new(A { addr: ipv4 })]); cache.insert_hint(name.clone(), types::A, vec![Box::new(A { addr: ipv4 })]);
cache.insert_hint(name.clone(), types::AAAA, vec![Box::new(AAAA { addr: ipv6 })]); cache.insert_hint(
name.clone(),
types::AAAA,
vec![Box::new(AAAA { addr: ipv6 })],
);
} }
cache.insert_hint(DnsName::new_root(), types::NS, root_ns_set); cache.insert_hint(DnsName::new_root(), types::NS, root_ns_set);

View File

@ -1,15 +1,17 @@
use bytes::{Bytes, BufMut};
use data_encoding::{self, HEXLOWER_PERMISSIVE};
use crate::errors::*; use crate::errors::*;
use failure::{Fail, ResultExt}; use crate::ser::packet::{
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes, short_blob, write_short_blob, get_blob}; get_blob, remaining_bytes, short_blob, write_short_blob, DnsPacketData, DnsPacketWriteContext,
};
use crate::ser::text::*; use crate::ser::text::*;
use bytes::{BufMut, Bytes};
use data_encoding::{self, HEXLOWER_PERMISSIVE};
use failure::{Fail, ResultExt};
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
static WHITESPACE: &str = "\t\n\x0c\r "; // \f == \x0c formfeed static WHITESPACE: &str = "\t\n\x0c\r "; // \f == \x0c formfeed
lazy_static::lazy_static!{ lazy_static::lazy_static! {
pub(crate) static ref HEXLOWER_PERMISSIVE_ALLOW_WS: data_encoding::Encoding = { pub(crate) static ref HEXLOWER_PERMISSIVE_ALLOW_WS: data_encoding::Encoding = {
let mut spec = data_encoding::Specification::new(); let mut spec = data_encoding::Specification::new();
spec.symbols.push_str("0123456789abcdef"); spec.symbols.push_str("0123456789abcdef");
@ -35,7 +37,10 @@ pub struct HexShortBlob(Bytes);
impl HexShortBlob { impl HexShortBlob {
pub fn new(data: Vec<u8>) -> Result<Self> { pub fn new(data: Vec<u8>) -> Result<Self> {
failure::ensure!(data.len() < 256, "short hex blob must be at most 255 bytes long"); failure::ensure!(
data.len() < 256,
"short hex blob must be at most 255 bytes long"
);
Ok(Self(data.into())) Ok(Self(data.into()))
} }
} }
@ -56,9 +61,13 @@ impl DnsTextData for HexShortBlob {
if s == "-" { if s == "-" {
Ok(HexShortBlob(Bytes::new())) Ok(HexShortBlob(Bytes::new()))
} else { } else {
let raw = HEXLOWER_PERMISSIVE.decode(s.as_bytes()) let raw = HEXLOWER_PERMISSIVE
.decode(s.as_bytes())
.with_context(|e| e.context(format!("invalid hex: {:?}", s)))?; .with_context(|e| e.context(format!("invalid hex: {:?}", s)))?;
failure::ensure!(raw.len() < 256, "short hex field must be at most 255 bytes long"); failure::ensure!(
raw.len() < 256,
"short hex field must be at most 255 bytes long"
);
Ok(HexShortBlob(raw.into())) Ok(HexShortBlob(raw.into()))
} }
} }
@ -89,7 +98,10 @@ pub struct Base64LongBlob(Bytes);
impl Base64LongBlob { impl Base64LongBlob {
pub fn new(data: Vec<u8>) -> Result<Self> { pub fn new(data: Vec<u8>) -> Result<Self> {
failure::ensure!(data.len() < 0x1_0000, "long base64 blob must be at most 65535 bytes long"); failure::ensure!(
data.len() < 0x1_0000,
"long base64 blob must be at most 65535 bytes long"
);
Ok(Self(data.into())) Ok(Self(data.into()))
} }
} }
@ -115,12 +127,14 @@ impl DnsPacketData for Base64LongBlob {
impl DnsTextData for Base64LongBlob { impl DnsTextData for Base64LongBlob {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> {
let length_field = next_field(data)?; let length_field = next_field(data)?;
let length = length_field.parse::<u16>() let length = length_field
.parse::<u16>()
.with_context(|_| format!("invalid length for blob: {:?}", length_field))?; .with_context(|_| format!("invalid length for blob: {:?}", length_field))?;
if length > 0 { if length > 0 {
let blob_field = next_field(data)?; let blob_field = next_field(data)?;
let result = BASE64_ALLOW_WS.decode(blob_field.as_bytes()) let result = BASE64_ALLOW_WS
.decode(blob_field.as_bytes())
.with_context(|e| e.context(format!("invalid base64: {:?}", blob_field)))?; .with_context(|e| e.context(format!("invalid base64: {:?}", blob_field)))?;
Ok(Base64LongBlob(result.into())) Ok(Base64LongBlob(result.into()))
} else { } else {
@ -129,7 +143,9 @@ impl DnsTextData for Base64LongBlob {
} }
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
if self.0.len() >= 0x1_0000 { return Err(fmt::Error); } if self.0.len() >= 0x1_0000 {
return Err(fmt::Error);
}
if self.0.is_empty() { if self.0.is_empty() {
write!(f, "0") write!(f, "0")
} else { } else {
@ -175,7 +191,8 @@ impl DnsPacketData for Base64RemainingBlob {
impl DnsTextData for Base64RemainingBlob { impl DnsTextData for Base64RemainingBlob {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> {
skip_whitespace(data); skip_whitespace(data);
let result = BASE64_ALLOW_WS.decode(data.as_bytes()) let result = BASE64_ALLOW_WS
.decode(data.as_bytes())
.with_context(|e| e.context(format!("invalid base64: {:?}", data)))?; .with_context(|e| e.context(format!("invalid base64: {:?}", data)))?;
*data = ""; *data = "";
Ok(Base64RemainingBlob(result.into())) Ok(Base64RemainingBlob(result.into()))
@ -226,7 +243,8 @@ impl DnsPacketData for HexRemainingBlob {
impl DnsTextData for HexRemainingBlob { impl DnsTextData for HexRemainingBlob {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> {
skip_whitespace(data); skip_whitespace(data);
let result = HEXLOWER_PERMISSIVE_ALLOW_WS.decode(data.as_bytes()) let result = HEXLOWER_PERMISSIVE_ALLOW_WS
.decode(data.as_bytes())
.with_context(|e| e.context(format!("invalid hex: {:?}", data)))?; .with_context(|e| e.context(format!("invalid hex: {:?}", data)))?;
*data = ""; *data = "";
Ok(HexRemainingBlob(result.into())) Ok(HexRemainingBlob(result.into()))
@ -279,7 +297,8 @@ impl DnsPacketData for HexRemainingBlobNotEmpty {
impl DnsTextData for HexRemainingBlobNotEmpty { impl DnsTextData for HexRemainingBlobNotEmpty {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> {
skip_whitespace(data); skip_whitespace(data);
let result = HEXLOWER_PERMISSIVE_ALLOW_WS.decode(data.as_bytes()) let result = HEXLOWER_PERMISSIVE_ALLOW_WS
.decode(data.as_bytes())
.with_context(|e| e.context(format!("invalid hex: {:?}", data)))?; .with_context(|e| e.context(format!("invalid hex: {:?}", data)))?;
*data = ""; *data = "";
failure::ensure!(!result.is_empty(), "must not be empty"); failure::ensure!(!result.is_empty(), "must not be empty");
@ -287,7 +306,9 @@ impl DnsTextData for HexRemainingBlobNotEmpty {
} }
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
if self.0.is_empty() { return Err(fmt::Error); } if self.0.is_empty() {
return Err(fmt::Error);
}
write!(f, "{}", HEXLOWER_PERMISSIVE_ALLOW_WS.encode(&self.0)) write!(f, "{}", HEXLOWER_PERMISSIVE_ALLOW_WS.encode(&self.0))
} }
} }

View File

@ -18,7 +18,6 @@ pub enum DnsSecAlgorithm {
DSA = 3, DSA = 3,
// Reserved: 4 [RFC6725] // Reserved: 4 [RFC6725]
/// RSA/SHA-1 /// RSA/SHA-1
// [RFC3110][RFC4034] // [RFC3110][RFC4034]
RSASHA1 = 5, RSASHA1 = 5,
@ -33,13 +32,11 @@ pub enum DnsSecAlgorithm {
RSASHA256 = 8, RSASHA256 = 8,
// Reserved: 9 [RFC6725] // Reserved: 9 [RFC6725]
/// RSA/SHA-512 /// RSA/SHA-512
// [RFC5702][proposed standard] // [RFC5702][proposed standard]
RSASHA512 = 10, RSASHA512 = 10,
// Reserved: 11 [RFC6725] // Reserved: 11 [RFC6725]
/// GOST R 34.10-2001 /// GOST R 34.10-2001
// [RFC5933][standards track] // [RFC5933][standards track]
ECC_GOST = 12, ECC_GOST = 12,
@ -64,7 +61,6 @@ pub enum DnsSecAlgorithm {
/// private algorithm OID /// private algorithm OID
// [RFC4034] // [RFC4034]
PRIVATEOID = 254, PRIVATEOID = 254,
// Reserved: 255 [RFC4034][proposed standard] // Reserved: 255 [RFC4034][proposed standard]
} }

View File

@ -1,7 +1,7 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use crate::ser::text::{next_field, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use std::fmt; use std::fmt;
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
@ -16,7 +16,11 @@ fn fmt_eui_hyphens(data: &[u8], f: &mut fmt::Formatter) -> fmt::Result {
fn parse_eui_hyphens(dest: &mut [u8], source: &str) -> Result<()> { fn parse_eui_hyphens(dest: &mut [u8], source: &str) -> Result<()> {
let mut pos = 0; let mut pos = 0;
for octet in source.split('-') { for octet in source.split('-') {
failure::ensure!(pos < dest.len(), "too many octets for EUI{}", dest.len() * 8); failure::ensure!(
pos < dest.len(),
"too many octets for EUI{}",
dest.len() * 8
);
failure::ensure!(octet.len() == 2, "invalid octet {:?}", octet); failure::ensure!(octet.len() == 2, "invalid octet {:?}", octet);
match u8::from_str_radix(octet, 16) { match u8::from_str_radix(octet, 16) {
Ok(o) => { Ok(o) => {
@ -28,7 +32,11 @@ fn parse_eui_hyphens(dest: &mut [u8], source: &str) -> Result<()> {
}, },
} }
} }
failure::ensure!(pos == dest.len(), "not enough octets for EUI{}", dest.len() * 8); failure::ensure!(
pos == dest.len(),
"not enough octets for EUI{}",
dest.len() * 8
);
Ok(()) Ok(())
} }

View File

@ -1,71 +1,35 @@
pub mod binary; pub mod binary;
pub mod classes;
pub mod name;
pub mod text;
pub mod types;
mod caa; mod caa;
pub mod classes;
mod dnssec; mod dnssec;
mod eui; mod eui;
pub mod name;
mod nsec; mod nsec;
mod nxt; mod nxt;
mod sig; mod sig;
mod sshfp; mod sshfp;
pub mod text;
mod time; mod time;
pub mod types;
mod uri; mod uri;
pub use self::binary::{ pub use self::binary::{
Base64LongBlob, Base64LongBlob, Base64RemainingBlob, HexRemainingBlob, HexRemainingBlobNotEmpty, HexShortBlob,
Base64RemainingBlob,
HexRemainingBlob,
HexRemainingBlobNotEmpty,
HexShortBlob,
}; };
pub use self::classes::Class;
pub use self::caa::CaaFlags; pub use self::caa::CaaFlags;
pub use self::classes::Class;
pub use self::dnssec::{ pub use self::dnssec::{
DnskeyFlags, DnsSecAlgorithm, DnsSecAlgorithmKnown, DnsSecDigestAlgorithm, DnsSecDigestAlgorithmKnown,
DnskeyProtocol, DnskeyFlags, DnskeyProtocol, DnskeyProtocolKnown, Nsec3Algorithm, Nsec3AlgorithmKnown,
DnskeyProtocolKnown, Nsec3Flags, Nsec3ParamFlags,
DnsSecAlgorithm,
DnsSecAlgorithmKnown,
DnsSecDigestAlgorithm,
DnsSecDigestAlgorithmKnown,
Nsec3Algorithm,
Nsec3AlgorithmKnown,
Nsec3Flags,
Nsec3ParamFlags,
};
pub use self::eui::{
EUI48Addr,
EUI64Addr,
};
pub use self::name::{
DnsCanonicalName,
DnsCompressedName,
DnsName,
};
pub use self::nsec::{
NextHashedOwnerName,
NsecTypeBitmap,
}; };
pub use self::eui::{EUI48Addr, EUI64Addr};
pub use self::name::{DnsCanonicalName, DnsCompressedName, DnsName};
pub use self::nsec::{NextHashedOwnerName, NsecTypeBitmap};
pub use self::nxt::NxtTypeBitmap; pub use self::nxt::NxtTypeBitmap;
pub use self::sig::OptionalTTL; pub use self::sig::OptionalTTL;
pub use self::sshfp::{ pub use self::sshfp::{SshFpAlgorithm, SshFpAlgorithmKnown, SshFpType, SshFpTypeKnown};
SshFpAlgorithm, pub use self::text::{LongText, RemainingText, ShortText, UnquotedShortText};
SshFpAlgorithmKnown, pub use self::time::{Time, Time48, TimeStrict};
SshFpType,
SshFpTypeKnown,
};
pub use self::text::{
LongText,
RemainingText,
ShortText,
UnquotedShortText,
};
pub use self::time::{
Time,
Time48,
TimeStrict,
};
pub use self::types::Type; pub use self::types::Type;
pub use self::uri::UriText; pub use self::uri::UriText;

View File

@ -1,13 +1,13 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; use crate::ser::text::{next_field, parse_with, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::str::FromStr; use std::str::FromStr;
use super::{DnsName, DnsNameIterator, DnsLabelRef}; use super::{DnsLabelRef, DnsName, DnsNameIterator};
/// names that should be written in canonical form for DNSSEC according /// names that should be written in canonical form for DNSSEC according
/// to https://tools.ietf.org/html/rfc4034#section-6.2 /// to https://tools.ietf.org/html/rfc4034#section-6.2
@ -24,8 +24,7 @@ impl DnsCanonicalName {
} }
/// Parse text representation of a domain name /// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> {
{
Ok(DnsCanonicalName(DnsName::parse(context, value)?)) Ok(DnsCanonicalName(DnsName::parse(context, value)?))
} }
} }
@ -65,8 +64,7 @@ impl<'a> IntoIterator for &'a DnsCanonicalName {
} }
} }
impl PartialEq<DnsName> for DnsCanonicalName impl PartialEq<DnsName> for DnsCanonicalName {
{
fn eq(&self, rhs: &DnsName) -> bool { fn eq(&self, rhs: &DnsName) -> bool {
let this: &DnsName = self; let this: &DnsName = self;
this == rhs this == rhs
@ -75,7 +73,7 @@ impl PartialEq<DnsName> for DnsCanonicalName
impl<T> PartialEq<T> for DnsCanonicalName impl<T> PartialEq<T> for DnsCanonicalName
where where
T: AsRef<DnsName> T: AsRef<DnsName>,
{ {
fn eq(&self, rhs: &T) -> bool { fn eq(&self, rhs: &T) -> bool {
let this: &DnsName = self.as_ref(); let this: &DnsName = self.as_ref();
@ -83,7 +81,7 @@ where
} }
} }
impl Eq for DnsCanonicalName{} impl Eq for DnsCanonicalName {}
impl fmt::Debug for DnsCanonicalName { impl fmt::Debug for DnsCanonicalName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
@ -101,7 +99,9 @@ impl FromStr for DnsCanonicalName {
type Err = ::failure::Error; type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> { fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsCanonicalName::dns_parse(&DnsTextContext::new(), data)) parse_with(s, |data| {
DnsCanonicalName::dns_parse(&DnsTextContext::new(), data)
})
} }
} }
@ -118,7 +118,9 @@ impl DnsTextData for DnsCanonicalName {
impl DnsPacketData for DnsCanonicalName { impl DnsPacketData for DnsCanonicalName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> { fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
Ok(DnsCanonicalName(super::name_packet_parser::deserialize_name(data, false)?)) Ok(DnsCanonicalName(
super::name_packet_parser::deserialize_name(data, false)?,
))
} }
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {

View File

@ -1,13 +1,13 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; use crate::ser::text::{next_field, parse_with, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::str::FromStr; use std::str::FromStr;
use super::{DnsName, DnsNameIterator, DnsLabelRef}; use super::{DnsLabelRef, DnsName, DnsNameIterator};
/// Similar to `DnsName`, but allows using compressed labels in the /// Similar to `DnsName`, but allows using compressed labels in the
/// serialized form /// serialized form
@ -21,8 +21,7 @@ impl DnsCompressedName {
} }
/// Parse text representation of a domain name /// Parse text representation of a domain name
pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> pub fn parse(context: &DnsTextContext, value: &str) -> Result<Self> {
{
Ok(DnsCompressedName(DnsName::parse(context, value)?)) Ok(DnsCompressedName(DnsName::parse(context, value)?))
} }
} }
@ -62,8 +61,7 @@ impl<'a> IntoIterator for &'a DnsCompressedName {
} }
} }
impl PartialEq<DnsName> for DnsCompressedName impl PartialEq<DnsName> for DnsCompressedName {
{
fn eq(&self, rhs: &DnsName) -> bool { fn eq(&self, rhs: &DnsName) -> bool {
let this: &DnsName = self; let this: &DnsName = self;
this == rhs this == rhs
@ -72,7 +70,7 @@ impl PartialEq<DnsName> for DnsCompressedName
impl<T> PartialEq<T> for DnsCompressedName impl<T> PartialEq<T> for DnsCompressedName
where where
T: AsRef<DnsName> T: AsRef<DnsName>,
{ {
fn eq(&self, rhs: &T) -> bool { fn eq(&self, rhs: &T) -> bool {
let this: &DnsName = self.as_ref(); let this: &DnsName = self.as_ref();
@ -80,7 +78,7 @@ where
} }
} }
impl Eq for DnsCompressedName{} impl Eq for DnsCompressedName {}
impl fmt::Debug for DnsCompressedName { impl fmt::Debug for DnsCompressedName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
@ -98,7 +96,9 @@ impl FromStr for DnsCompressedName {
type Err = ::failure::Error; type Err = ::failure::Error;
fn from_str(s: &str) -> Result<Self> { fn from_str(s: &str) -> Result<Self> {
parse_with(s, |data| DnsCompressedName::dns_parse(&DnsTextContext::new(), data)) parse_with(s, |data| {
DnsCompressedName::dns_parse(&DnsTextContext::new(), data)
})
} }
} }
@ -115,7 +115,9 @@ impl DnsTextData for DnsCompressedName {
impl DnsPacketData for DnsCompressedName { impl DnsPacketData for DnsCompressedName {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> { fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
Ok(DnsCompressedName(super::name_packet_parser::deserialize_name(data, true)?)) Ok(DnsCompressedName(
super::name_packet_parser::deserialize_name(data, true)?,
))
} }
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {

View File

@ -1,5 +1,5 @@
use std::fmt;
use super::DnsLabelRef; use super::DnsLabelRef;
use std::fmt;
/// Customize formatting of DNS labels /// Customize formatting of DNS labels
/// ///
@ -7,7 +7,7 @@ use super::DnsLabelRef;
/// ///
/// The `Debug` formatters just format as `Display` to a String and then /// The `Debug` formatters just format as `Display` to a String and then
/// format that string as `Debug`. /// format that string as `Debug`.
#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct DisplayLabelsOptions { pub struct DisplayLabelsOptions {
/// separator to insert between (escaped) labels /// separator to insert between (escaped) labels
pub separator: &'static str, pub separator: &'static str,
@ -20,7 +20,7 @@ pub struct DisplayLabelsOptions {
impl Default for DisplayLabelsOptions { impl Default for DisplayLabelsOptions {
fn default() -> Self { fn default() -> Self {
DisplayLabelsOptions{ DisplayLabelsOptions {
separator: ".", separator: ".",
trailing: true, trailing: true,
} }
@ -38,7 +38,7 @@ impl Default for DisplayLabelsOptions {
#[derive(Clone)] #[derive(Clone)]
pub struct DisplayLabels<'a, I> pub struct DisplayLabels<'a, I>
where where
I: IntoIterator<Item=DnsLabelRef<'a>>+Clone I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{ {
/// Label collection to iterate over /// Label collection to iterate over
pub labels: I, pub labels: I,
@ -48,7 +48,7 @@ where
impl<'a, I> fmt::Debug for DisplayLabels<'a, I> impl<'a, I> fmt::Debug for DisplayLabels<'a, I>
where where
I: IntoIterator<Item=DnsLabelRef<'a>>+Clone I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{ {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
// just escape the display version1 // just escape the display version1
@ -58,7 +58,7 @@ where
impl<'a, I> fmt::Display for DisplayLabels<'a, I> impl<'a, I> fmt::Display for DisplayLabels<'a, I>
where where
I: IntoIterator<Item=DnsLabelRef<'a>>+Clone I: IntoIterator<Item = DnsLabelRef<'a>> + Clone,
{ {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
let mut i = self.labels.clone().into_iter(); let mut i = self.labels.clone().into_iter();

View File

@ -1,7 +1,7 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use std::fmt; use bytes::Bytes;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::fmt;
#[inline] #[inline]
fn check_label(label: &[u8]) -> Result<()> { fn check_label(label: &[u8]) -> Result<()> {
@ -26,12 +26,14 @@ impl DnsLabel {
/// Fails when the length doesn't match the requirement `0 < length < 64`. /// Fails when the length doesn't match the requirement `0 < length < 64`.
pub fn new(label: Bytes) -> Result<Self> { pub fn new(label: Bytes) -> Result<Self> {
check_label(&label)?; check_label(&label)?;
Ok(DnsLabel{label}) Ok(DnsLabel { label })
} }
/// Convert to a representation without storage /// Convert to a representation without storage
pub fn as_ref<'a>(&'a self) -> DnsLabelRef<'a> { pub fn as_ref<'a>(&'a self) -> DnsLabelRef<'a> {
DnsLabelRef{label: self.label.as_ref()} DnsLabelRef {
label: self.label.as_ref(),
}
} }
/// Access as raw bytes /// Access as raw bytes
@ -58,7 +60,7 @@ impl DnsLabel {
impl<'a> From<DnsLabelRef<'a>> for DnsLabel { impl<'a> From<DnsLabelRef<'a>> for DnsLabel {
fn from(label_ref: DnsLabelRef<'a>) -> Self { fn from(label_ref: DnsLabelRef<'a>) -> Self {
DnsLabel{ DnsLabel {
label: Bytes::from(label_ref.label), label: Bytes::from(label_ref.label),
} }
} }
@ -78,7 +80,7 @@ impl<'a> PartialEq<DnsLabelRef<'a>> for DnsLabel {
} }
} }
impl Eq for DnsLabel{} impl Eq for DnsLabel {}
impl PartialOrd<DnsLabel> for DnsLabel { impl PartialOrd<DnsLabel> for DnsLabel {
#[inline] #[inline]
@ -116,7 +118,7 @@ impl fmt::Display for DnsLabel {
/// A DNS label (any binary string with `0 < length < 64`) /// A DNS label (any binary string with `0 < length < 64`)
/// ///
/// Storage is provided through lifetime. /// Storage is provided through lifetime.
#[derive(Clone,Copy)] #[derive(Clone, Copy)]
pub struct DnsLabelRef<'a> { pub struct DnsLabelRef<'a> {
pub(super) label: &'a [u8], // 0 < len < 64 pub(super) label: &'a [u8], // 0 < len < 64
} }
@ -127,7 +129,7 @@ impl<'a> DnsLabelRef<'a> {
/// Fails when the length doesn't match the requirement `0 < length < 64`. /// Fails when the length doesn't match the requirement `0 < length < 64`.
pub fn new(label: &'a [u8]) -> Result<Self> { pub fn new(label: &'a [u8]) -> Result<Self> {
check_label(label)?; check_label(label)?;
Ok(DnsLabelRef{label}) Ok(DnsLabelRef { label })
} }
/// Access as raw bytes /// Access as raw bytes
@ -180,7 +182,7 @@ impl<'a> PartialEq<DnsLabel> for DnsLabelRef<'a> {
} }
} }
impl<'a> Eq for DnsLabelRef<'a>{} impl<'a> Eq for DnsLabelRef<'a> {}
impl<'a, 'b> PartialOrd<DnsLabelRef<'a>> for DnsLabelRef<'b> { impl<'a, 'b> PartialOrd<DnsLabelRef<'a>> for DnsLabelRef<'b> {
#[inline] #[inline]
@ -218,7 +220,9 @@ impl<'a> fmt::Display for DnsLabelRef<'a> {
if c <= 0x21 || c >= 0x7e || b'.' == c || b'\\' == c { if c <= 0x21 || c >= 0x7e || b'.' == c || b'\\' == c {
// flush // flush
if done < pos { if done < pos {
w.write_str(crate::unsafe_ops::from_utf8_unchecked(&self.label[done..pos]))?; w.write_str(crate::unsafe_ops::from_utf8_unchecked(
&self.label[done..pos],
))?;
} }
match c { match c {
b'.' => w.write_str(r#"\."#)?, b'.' => w.write_str(r#"\."#)?,

View File

@ -1,6 +1,6 @@
use smallvec::SmallVec; use smallvec::SmallVec;
#[derive(Clone,Copy,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum LabelOffset { pub enum LabelOffset {
LabelStart(u8), LabelStart(u8),
PacketStart(u16), PacketStart(u16),
@ -8,10 +8,10 @@ pub enum LabelOffset {
// the heap meta data is usually at least 2*usize big; assuming 64-bit // the heap meta data is usually at least 2*usize big; assuming 64-bit
// platforms it should be ok to use 16 bytes in the smallvec. // platforms it should be ok to use 16 bytes in the smallvec.
#[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum LabelOffsets { pub enum LabelOffsets {
Uncompressed(SmallVec<[u8;16]>), Uncompressed(SmallVec<[u8; 16]>),
Compressed(usize, SmallVec<[LabelOffset;4]>), Compressed(usize, SmallVec<[LabelOffset; 4]>),
} }
impl LabelOffsets { impl LabelOffsets {
@ -31,7 +31,7 @@ impl LabelOffsets {
LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] { LabelOffsets::Compressed(start, ref offs) => match offs[ndx as usize] {
LabelOffset::LabelStart(o) => start + (o as usize), LabelOffset::LabelStart(o) => start + (o as usize),
LabelOffset::PacketStart(o) => o as usize, LabelOffset::PacketStart(o) => o as usize,
} },
} }
} }

View File

@ -1,8 +1,8 @@
#![deny(missing_docs)] #![deny(missing_docs)]
//! Various structs to represents DNS names and labels //! Various structs to represents DNS names and labels
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
use std::io::Cursor; use std::io::Cursor;
@ -10,9 +10,9 @@ pub use self::canonical_name::*;
pub use self::compressed_name::*; pub use self::compressed_name::*;
pub use self::display::*; pub use self::display::*;
pub use self::label::*; pub use self::label::*;
pub use self::name_iterator::*;
pub use self::name::*;
use self::label_offsets::*; use self::label_offsets::*;
pub use self::name::*;
pub use self::name_iterator::*;
mod canonical_name; mod canonical_name;
mod compressed_name; mod compressed_name;

View File

@ -1,13 +1,13 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field, parse_with}; use crate::ser::text::{next_field, parse_with, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
use std::str::FromStr; use std::str::FromStr;
use super::{LabelOffsets, DnsNameIterator, DnsLabelRef, DnsLabel, DisplayLabels}; use super::{DisplayLabels, DnsLabel, DnsLabelRef, DnsNameIterator, LabelOffsets};
/// A DNS name /// A DNS name
/// ///
@ -30,7 +30,7 @@ pub struct DnsName {
impl DnsName { impl DnsName {
/// Create new name representing the DNS root (".") /// Create new name representing the DNS root (".")
pub fn new_root() -> Self { pub fn new_root() -> Self {
DnsName{ DnsName {
data: Bytes::new(), data: Bytes::new(),
label_offsets: LabelOffsets::Uncompressed(SmallVec::new()), label_offsets: LabelOffsets::Uncompressed(SmallVec::new()),
total_len: 1, total_len: 1,
@ -40,7 +40,7 @@ impl DnsName {
/// Create new name representing the DNS root (".") and pre-allocate /// Create new name representing the DNS root (".") and pre-allocate
/// storage /// storage
pub fn with_capacity(labels: u8, total_len: u8) -> Self { pub fn with_capacity(labels: u8, total_len: u8) -> Self {
DnsName{ DnsName {
data: Bytes::with_capacity(total_len as usize), data: Bytes::with_capacity(total_len as usize),
label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)), label_offsets: LabelOffsets::Uncompressed(SmallVec::with_capacity(labels as usize)),
total_len: 1, total_len: 1,
@ -66,7 +66,7 @@ impl DnsName {
/// Iterator over the labels (in the order they are stored in memory, /// Iterator over the labels (in the order they are stored in memory,
/// i.e. top-level name last). /// i.e. top-level name last).
pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> { pub fn labels<'a>(&'a self) -> DnsNameIterator<'a> {
DnsNameIterator{ DnsNameIterator {
name: &self, name: &self,
front_label: 0, front_label: 0,
back_label: self.label_offsets.len(), back_label: self.label_offsets.len(),
@ -83,7 +83,9 @@ impl DnsName {
let label_len = self.data[pos]; let label_len = self.data[pos];
debug_assert!(label_len < 64); debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize; let end = pos + 1 + label_len as usize;
DnsLabelRef{label: &self.data[pos + 1..end]} DnsLabelRef {
label: &self.data[pos + 1..end],
}
} }
/// Return label at index `ndx` /// Return label at index `ndx`
@ -96,7 +98,9 @@ impl DnsName {
let label_len = self.data[pos]; let label_len = self.data[pos];
debug_assert!(label_len < 64); debug_assert!(label_len < 64);
let end = pos + 1 + label_len as usize; let end = pos + 1 + label_len as usize;
DnsLabel{label: self.data.slice(pos + 1, end) } DnsLabel {
label: self.data.slice(pos + 1, end),
}
} }
} }
@ -109,42 +113,45 @@ impl<'a> IntoIterator for &'a DnsName {
} }
} }
impl PartialEq<DnsName> for DnsName impl PartialEq<DnsName> for DnsName {
{
fn eq(&self, rhs: &DnsName) -> bool { fn eq(&self, rhs: &DnsName) -> bool {
let a_labels = self.labels(); let a_labels = self.labels();
let b_labels = rhs.labels(); let b_labels = rhs.labels();
if a_labels.len() != b_labels.len() { return false; } if a_labels.len() != b_labels.len() {
a_labels.zip(b_labels).all(|(a,b)| a == b) return false;
}
a_labels.zip(b_labels).all(|(a, b)| a == b)
} }
} }
impl<T> PartialEq<T> for DnsName impl<T> PartialEq<T> for DnsName
where where
T: AsRef<DnsName> T: AsRef<DnsName>,
{ {
fn eq(&self, rhs: &T) -> bool { fn eq(&self, rhs: &T) -> bool {
self == rhs.as_ref() self == rhs.as_ref()
} }
} }
impl Eq for DnsName{} impl Eq for DnsName {}
impl fmt::Debug for DnsName { impl fmt::Debug for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{ DisplayLabels {
labels: self, labels: self,
options: Default::default(), options: Default::default(),
}.fmt(w) }
.fmt(w)
} }
} }
impl fmt::Display for DnsName { impl fmt::Display for DnsName {
fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
DisplayLabels{ DisplayLabels {
labels: self, labels: self,
options: Default::default(), options: Default::default(),
}.fmt(w) }
.fmt(w)
} }
} }

View File

@ -1,4 +1,4 @@
use super::{DnsName, DnsLabelRef}; use super::{DnsLabelRef, DnsName};
/// Iterator type for [`DnsName::labels`] /// Iterator type for [`DnsName::labels`]
/// ///
@ -14,7 +14,9 @@ impl<'a> Iterator for DnsNameIterator<'a> {
type Item = DnsLabelRef<'a>; type Item = DnsLabelRef<'a>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None } if self.front_label >= self.back_label {
return None;
}
let label = self.name.label_ref(self.front_label); let label = self.name.label_ref(self.front_label);
self.front_label += 1; self.front_label += 1;
Some(label) Some(label)
@ -38,7 +40,9 @@ impl<'a> ExactSizeIterator for DnsNameIterator<'a> {
impl<'a> DoubleEndedIterator for DnsNameIterator<'a> { impl<'a> DoubleEndedIterator for DnsNameIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> { fn next_back(&mut self) -> Option<Self::Item> {
if self.front_label >= self.back_label { return None } if self.front_label >= self.back_label {
return None;
}
self.back_label -= 1; self.back_label -= 1;
let label = self.name.label_ref(self.back_label); let label = self.name.label_ref(self.back_label);
Some(label) Some(label)

View File

@ -1,5 +1,5 @@
use bytes::{BytesMut,BufMut};
use super::*; use super::*;
use bytes::{BufMut, BytesMut};
impl DnsName { impl DnsName {
/// Remove the front label /// Remove the front label
@ -8,12 +8,16 @@ impl DnsName {
pub fn try_pop_front(&mut self) -> bool { pub fn try_pop_front(&mut self) -> bool {
match self.label_offsets { match self.label_offsets {
LabelOffsets::Uncompressed(ref mut offs) => { LabelOffsets::Uncompressed(ref mut offs) => {
if offs.is_empty() { return false; } if offs.is_empty() {
return false;
}
self.total_len -= self.data[offs[0] as usize] + 1; self.total_len -= self.data[offs[0] as usize] + 1;
offs.remove(0); offs.remove(0);
}, },
LabelOffsets::Compressed(ref mut start_pos, ref mut offs) => { LabelOffsets::Compressed(ref mut start_pos, ref mut offs) => {
if offs.is_empty() { return false; } if offs.is_empty() {
return false;
}
match offs[0] { match offs[0] {
LabelOffset::LabelStart(o) => { LabelOffset::LabelStart(o) => {
let label_space = self.data[*start_pos + o as usize] + 1; let label_space = self.data[*start_pos + o as usize] + 1;
@ -36,7 +40,9 @@ impl DnsName {
/// ///
/// Panics if the name was the root (".") /// Panics if the name was the root (".")
pub fn pop_front(&mut self) { pub fn pop_front(&mut self) {
if !self.try_pop_front() { panic!("Cannot pop label from root name") } if !self.try_pop_front() {
panic!("Cannot pop label from root name")
}
} }
/// Insert a new label at the front /// Insert a new label at the front
@ -44,12 +50,15 @@ impl DnsName {
/// Returns an error if the resulting name would be too long /// Returns an error if the resulting name would be too long
pub fn push_front<'a, L: Into<DnsLabelRef<'a>>>(&mut self, label: L) -> Result<()> { pub fn push_front<'a, L: Into<DnsLabelRef<'a>>>(&mut self, label: L) -> Result<()> {
let label = label.into(); let label = label.into();
if label.len() > 254 - self.total_len { failure::bail!("Cannot append label, resulting name too long") } if label.len() > 254 - self.total_len {
failure::bail!("Cannot append label, resulting name too long")
}
let (mut data, start) = self.reserve(label.len() as usize + 1, 0); let (mut data, start) = self.reserve(label.len() as usize + 1, 0);
let new_label_pos = start - (label.len() + 1) as usize; let new_label_pos = start - (label.len() + 1) as usize;
data[new_label_pos] = label.len(); data[new_label_pos] = label.len();
data[new_label_pos+1..new_label_pos+1+label.len() as usize].copy_from_slice(label.as_raw()); data[new_label_pos + 1..new_label_pos + 1 + label.len() as usize]
.copy_from_slice(label.as_raw());
self.data = data.freeze(); self.data = data.freeze();
self.total_len += label.len() + 1; self.total_len += label.len() + 1;
match self.label_offsets { match self.label_offsets {
@ -68,13 +77,17 @@ impl DnsName {
pub fn try_pop_back(&mut self) -> bool { pub fn try_pop_back(&mut self) -> bool {
match self.label_offsets { match self.label_offsets {
LabelOffsets::Uncompressed(ref mut offs) => { LabelOffsets::Uncompressed(ref mut offs) => {
if offs.is_empty() { return false; } if offs.is_empty() {
self.total_len -= self.data[offs[offs.len()-1] as usize] + 1; return false;
}
self.total_len -= self.data[offs[offs.len() - 1] as usize] + 1;
offs.pop(); offs.pop();
}, },
LabelOffsets::Compressed(ref mut start_pos, ref mut offs) => { LabelOffsets::Compressed(ref mut start_pos, ref mut offs) => {
if offs.is_empty() { return false; } if offs.is_empty() {
match offs[offs.len()-1] { return false;
}
match offs[offs.len() - 1] {
LabelOffset::LabelStart(o) => { LabelOffset::LabelStart(o) => {
self.total_len -= self.data[*start_pos + o as usize] + 1; self.total_len -= self.data[*start_pos + o as usize] + 1;
}, },
@ -94,7 +107,9 @@ impl DnsName {
/// ///
/// Panics if the name was the root (".") /// Panics if the name was the root (".")
pub fn pop_back(&mut self) { pub fn pop_back(&mut self) {
if !self.try_pop_back() { panic!("Cannot pop label from root name") } if !self.try_pop_back() {
panic!("Cannot pop label from root name")
}
} }
/// Insert a new label at the back /// Insert a new label at the back
@ -102,13 +117,16 @@ impl DnsName {
/// Returns an error if the resulting name would be too long /// Returns an error if the resulting name would be too long
pub fn push_back<'a, L: Into<DnsLabelRef<'a>>>(&mut self, label: L) -> Result<()> { pub fn push_back<'a, L: Into<DnsLabelRef<'a>>>(&mut self, label: L) -> Result<()> {
let label = label.into(); let label = label.into();
if label.len() > 254 - self.total_len { failure::bail!("Cannot append label, resulting name too long") } if label.len() > 254 - self.total_len {
failure::bail!("Cannot append label, resulting name too long")
}
let (mut data, start) = self.reserve(0, label.len() as usize + 1); let (mut data, start) = self.reserve(0, label.len() as usize + 1);
let new_label_pos = start + self.total_len as usize - 1; let new_label_pos = start + self.total_len as usize - 1;
data[new_label_pos] = label.len(); data[new_label_pos] = label.len();
data[new_label_pos+1..new_label_pos+1+label.len() as usize].copy_from_slice(label.as_raw()); data[new_label_pos + 1..new_label_pos + 1 + label.len() as usize]
data[new_label_pos+1+label.len() as usize] = 0; .copy_from_slice(label.as_raw());
data[new_label_pos + 1 + label.len() as usize] = 0;
self.data = data.freeze(); self.data = data.freeze();
self.total_len += label.len() + 1; self.total_len += label.len() + 1;
match self.label_offsets { match self.label_offsets {
@ -145,15 +163,18 @@ impl DnsName {
let add = new_len - data.len(); let add = new_len - data.len();
data.reserve(add); data.reserve(add);
} }
unsafe { data.set_len(new_len); } unsafe {
data.set_len(new_len);
}
data[0] = 0; data[0] = 0;
return (data, 0) return (data, 0);
} }
let old_start = label_offsets[0] as usize; let old_start = label_offsets[0] as usize;
// if current "prefix" space (old_start) is bigger than // if current "prefix" space (old_start) is bigger than
// requested but fits, just increase the prefix // requested but fits, just increase the prefix
let (prefix, new_len) = if old_start > prefix && self.total_len as usize + old_start + suffix < 256 { let (prefix, new_len) =
if old_start > prefix && self.total_len as usize + old_start + suffix < 256 {
(old_start, self.total_len as usize + old_start + suffix) (old_start, self.total_len as usize + old_start + suffix)
} else { } else {
(prefix, new_len) (prefix, new_len)
@ -173,13 +194,14 @@ impl DnsName {
Err(data) => Err(data), Err(data) => Err(data),
}; };
match data { match data {
Ok(mut data) => { Ok(mut data) => {
if data.len() < new_len { if data.len() < new_len {
let add = new_len - data.len(); let add = new_len - data.len();
data.reserve(add); data.reserve(add);
unsafe { data.set_len(new_len); } unsafe {
data.set_len(new_len);
}
} }
if old_start < prefix { if old_start < prefix {
// need more space in front, move back // need more space in front, move back
@ -213,9 +235,12 @@ impl DnsName {
}, },
Err(data) => { Err(data) => {
let mut new_data = BytesMut::with_capacity(new_len); let mut new_data = BytesMut::with_capacity(new_len);
unsafe { new_data.set_len(new_len); } unsafe {
new_data.set_len(new_len);
}
// copy old data // copy old data
new_data[prefix..prefix + self.total_len as usize].copy_from_slice(&data[old_start..old_start+self.total_len as usize]); new_data[prefix..prefix + self.total_len as usize]
.copy_from_slice(&data[old_start..old_start + self.total_len as usize]);
// adjust labels // adjust labels
for o in label_offsets.iter_mut() { for o in label_offsets.iter_mut() {
*o = (*o - old_start as u8) + prefix as u8; *o = (*o - old_start as u8) + prefix as u8;
@ -246,7 +271,7 @@ impl DnsName {
let new_len = self.total_len as usize + prefix_capacity + suffix_capacity; let new_len = self.total_len as usize + prefix_capacity + suffix_capacity;
assert!(new_len < 256); assert!(new_len < 256);
let mut data = BytesMut::with_capacity(new_len); let mut data = BytesMut::with_capacity(new_len);
let mut offsets = SmallVec::<[u8;16]>::with_capacity(label_count); let mut offsets = SmallVec::<[u8; 16]>::with_capacity(label_count);
unsafe { data.set_len(prefix_capacity) } unsafe { data.set_len(prefix_capacity) }
let mut pos = prefix_capacity as u8; let mut pos = prefix_capacity as u8;
@ -258,7 +283,7 @@ impl DnsName {
} }
data.put_u8(0); data.put_u8(0);
DnsName{ DnsName {
data: data.freeze(), data: data.freeze(),
label_offsets: LabelOffsets::Uncompressed(offsets), label_offsets: LabelOffsets::Uncompressed(offsets),
total_len: self.total_len, total_len: self.total_len,

View File

@ -1,13 +1,20 @@
use bytes::Buf;
use super::*; use super::*;
use bytes::Buf;
/// `data`: bytes of packet from beginning until at least the end of the name /// `data`: bytes of packet from beginning until at least the end of the name
/// `start_pos`: position of first byte of the name /// `start_pos`: position of first byte of the name
/// `uncmpr_offsets`: offsets of uncompressed labels so far /// `uncmpr_offsets`: offsets of uncompressed labels so far
/// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`) /// `label_len`: first compressed label length (`0xc0 | offset-high, offset-low`)
/// `total_len`: length of (uncompressed) label encoding so far /// `total_len`: length of (uncompressed) label encoding so far
fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offsets: SmallVec<[u8;16]>, mut total_len: usize, mut label_len: u8) -> Result<DnsName> { fn deserialize_name_compressed_cont(
let mut label_offsets = uncmpr_offsets.into_iter() data: Bytes,
start_pos: usize,
uncmpr_offsets: SmallVec<[u8; 16]>,
mut total_len: usize,
mut label_len: u8,
) -> Result<DnsName> {
let mut label_offsets = uncmpr_offsets
.into_iter()
.map(LabelOffset::LabelStart) .map(LabelOffset::LabelStart)
.collect::<SmallVec<_>>(); .collect::<SmallVec<_>>();
@ -16,7 +23,12 @@ fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offset
{ {
failure::ensure!(pos + 1 < data.len(), "not enough data for compressed label"); failure::ensure!(pos + 1 < data.len(), "not enough data for compressed label");
let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize); let new_pos = ((label_len as usize & 0x3f) << 8) | (data[pos + 1] as usize);
failure::ensure!(new_pos < pos, "Compressed label offset too big: {} >= {}", new_pos, pos); failure::ensure!(
new_pos < pos,
"Compressed label offset too big: {} >= {}",
new_pos,
pos
);
pos = new_pos; pos = new_pos;
} }
@ -25,19 +37,23 @@ fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offset
label_len = data[pos]; label_len = data[pos];
if 0 == label_len { if 0 == label_len {
return Ok(DnsName{ return Ok(DnsName {
data: data, data: data,
label_offsets: LabelOffsets::Compressed(start_pos, label_offsets), label_offsets: LabelOffsets::Compressed(start_pos, label_offsets),
total_len: total_len as u8 + 1, total_len: total_len as u8 + 1,
}) });
} }
if label_len & 0xc0 == 0xc0 { continue 'next_compressed; } if label_len & 0xc0 == 0xc0 {
continue 'next_compressed;
}
failure::ensure!(label_len < 64, "Invalid label length {}", label_len); failure::ensure!(label_len < 64, "Invalid label length {}", label_len);
total_len += 1 + label_len as usize; total_len += 1 + label_len as usize;
// max len 255, but there also needs to be an empty label at the end // max len 255, but there also needs to be an empty label at the end
if total_len > 254 { failure::bail!("DNS name too long") } if total_len > 254 {
failure::bail!("DNS name too long")
}
label_offsets.push(LabelOffset::PacketStart(pos as u16)); label_offsets.push(LabelOffset::PacketStart(pos as u16));
pos += 1 + label_len as usize; pos += 1 + label_len as usize;
@ -48,22 +64,24 @@ fn deserialize_name_compressed_cont(data: Bytes, start_pos: usize, uncmpr_offset
pub fn deserialize_name(data: &mut Cursor<Bytes>, accept_compressed: bool) -> Result<DnsName> { pub fn deserialize_name(data: &mut Cursor<Bytes>, accept_compressed: bool) -> Result<DnsName> {
check_enough_data!(data, 1, "DnsName"); check_enough_data!(data, 1, "DnsName");
let start_pos = data.position() as usize; let start_pos = data.position() as usize;
let mut total_len : usize = 0; let mut total_len: usize = 0;
let mut label_offsets = SmallVec::new(); let mut label_offsets = SmallVec::new();
loop { loop {
check_enough_data!(data, 1, "DnsName label len"); check_enough_data!(data, 1, "DnsName label len");
let label_len = data.get_u8() as usize; let label_len = data.get_u8() as usize;
if 0 == label_len { if 0 == label_len {
let end_pos = data.position() as usize; let end_pos = data.position() as usize;
return Ok(DnsName{ return Ok(DnsName {
data: data.get_ref().slice(start_pos, end_pos), data: data.get_ref().slice(start_pos, end_pos),
label_offsets: LabelOffsets::Uncompressed(label_offsets), label_offsets: LabelOffsets::Uncompressed(label_offsets),
total_len: total_len as u8 + 1, total_len: total_len as u8 + 1,
}) });
} }
if label_len & 0xc0 == 0xc0 { if label_len & 0xc0 == 0xc0 {
// compressed label // compressed label
if !accept_compressed { failure::bail!("Invalid label compression {}", label_len) } if !accept_compressed {
failure::bail!("Invalid label compression {}", label_len)
}
check_enough_data!(data, 1, "DnsName compressed label target"); check_enough_data!(data, 1, "DnsName compressed label target");
// eat second part of compressed label // eat second part of compressed label
data.get_u8(); data.get_u8();
@ -71,13 +89,23 @@ pub fn deserialize_name(data: &mut Cursor<Bytes>, accept_compressed: bool) -> Re
let end_pos = data.position() as usize; let end_pos = data.position() as usize;
let data = data.get_ref().slice(0, end_pos); let data = data.get_ref().slice(0, end_pos);
return deserialize_name_compressed_cont(data, start_pos, label_offsets, total_len, label_len as u8); return deserialize_name_compressed_cont(
data,
start_pos,
label_offsets,
total_len,
label_len as u8,
);
} }
label_offsets.push(total_len as u8); label_offsets.push(total_len as u8);
if label_len > 63 { failure::bail!("Invalid label length {}", label_len) } if label_len > 63 {
failure::bail!("Invalid label length {}", label_len)
}
total_len += 1 + label_len; total_len += 1 + label_len;
// max len 255, but there also needs to be an empty label at the end // max len 255, but there also needs to be an empty label at the end
if total_len > 254 { failure::bail!{"DNS name too long"} } if total_len > 254 {
failure::bail! {"DNS name too long"}
}
check_enough_data!(data, label_len, "DnsName label"); check_enough_data!(data, label_len, "DnsName label");
data.advance(label_len); data.advance(label_len);
} }

View File

@ -1,11 +1,10 @@
use crate::errors::*; use crate::errors::*;
use crate::ser::text::{DnsTextContext, quoted}; use crate::ser::text::{quoted, DnsTextContext};
use super::{DnsName, DnsLabelRef}; use super::{DnsLabelRef, DnsName};
/// Parse text representation of a domain name /// Parse text representation of a domain name
pub fn parse_name(context: &DnsTextContext, value: &str) -> Result<DnsName> pub fn parse_name(context: &DnsTextContext, value: &str) -> Result<DnsName> {
{
let raw = value.as_bytes(); let raw = value.as_bytes();
let mut name = DnsName::new_root(); let mut name = DnsName::new_root();
if raw == b"." { if raw == b"." {
@ -25,23 +24,42 @@ pub fn parse_name(context: &DnsTextContext, value: &str) -> Result<DnsName>
name.push_back(DnsLabelRef::new(&label)?)?; name.push_back(DnsLabelRef::new(&label)?)?;
label.clear(); label.clear();
} else if raw[pos] == b'\\' { } else if raw[pos] == b'\\' {
failure::ensure!(pos + 1 < raw.len(), "unexpected end of name after backslash: {:?}", value); failure::ensure!(
if raw[pos+1] >= b'0' && raw[pos+1] <= b'9' { pos + 1 < raw.len(),
"unexpected end of name after backslash: {:?}",
value
);
if raw[pos + 1] >= b'0' && raw[pos + 1] <= b'9' {
// \ddd escape // \ddd escape
failure::ensure!(pos + 3 < raw.len(), "unexpected end of name after backslash with digit: {:?}", value); failure::ensure!(
failure::ensure!(raw[pos+2] >= b'0' && raw[pos+2] <= b'9' && raw[pos+3] >= b'0' && raw[pos+3] <= b'9', "expected three digits after backslash in name: {:?}", name); pos + 3 < raw.len(),
let d1 = (raw[pos+1] - b'0') as u32; "unexpected end of name after backslash with digit: {:?}",
let d2 = (raw[pos+2] - b'0') as u32; value
let d3 = (raw[pos+3] - b'0') as u32; );
failure::ensure!(
raw[pos + 2] >= b'0'
&& raw[pos + 2] <= b'9' && raw[pos + 3] >= b'0'
&& raw[pos + 3] <= b'9',
"expected three digits after backslash in name: {:?}",
name
);
let d1 = (raw[pos + 1] - b'0') as u32;
let d2 = (raw[pos + 2] - b'0') as u32;
let d3 = (raw[pos + 3] - b'0') as u32;
let v = d1 * 100 + d2 * 10 + d3; let v = d1 * 100 + d2 * 10 + d3;
failure::ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name); failure::ensure!(v < 256, "invalid escape in name, {} > 255: {:?}", v, name);
label.push(v as u8); label.push(v as u8);
} else { } else {
failure::ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name); failure::ensure!(!quoted::is_ascii_whitespace(raw[pos+1]), "whitespace cannot be escaped with backslash prefix; encode it as \\{:03} in: {:?}", raw[pos+1], name);
label.push(raw[pos+1]); label.push(raw[pos + 1]);
} }
} else { } else {
failure::ensure!(!quoted::is_ascii_whitespace(raw[pos]), "whitespace must be encoded as \\{:03} in: {:?}", raw[pos], name); failure::ensure!(
!quoted::is_ascii_whitespace(raw[pos]),
"whitespace must be encoded as \\{:03} in: {:?}",
raw[pos],
name
);
label.push(raw[pos]); label.push(raw[pos]);
} }
pos += 1; pos += 1;
@ -55,7 +73,9 @@ pub fn parse_name(context: &DnsTextContext, value: &str) -> Result<DnsName>
match context.origin() { match context.origin() {
Some(o) => { Some(o) => {
for l in o { name.push_back(l)?; } for l in o {
name.push_back(l)?;
}
}, },
None => failure::bail!("missing trailing dot without $ORIGIN"), None => failure::bail!("missing trailing dot without $ORIGIN"),
} }

View File

@ -1,10 +1,10 @@
use bytes::Bytes; use crate::errors::*;
use crate::ser::packet; use crate::ser::packet;
use crate::ser::packet::DnsPacketData; use crate::ser::packet::DnsPacketData;
use bytes::Bytes;
use std::io::Cursor; use std::io::Cursor;
use crate::errors::*;
use super::{DnsName, DnsCompressedName, DnsLabelRef, DisplayLabels, DisplayLabelsOptions}; use super::{DisplayLabels, DisplayLabelsOptions, DnsCompressedName, DnsLabelRef, DnsName};
/* /*
fn deserialize(bytes: &'static [u8]) -> Result<DnsName> { fn deserialize(bytes: &'static [u8]) -> Result<DnsName> {
@ -25,40 +25,20 @@ fn de_uncompressed(bytes: &'static [u8]) -> Result<DnsName> {
fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) { fn check_uncompressed_display(bytes: &'static [u8], txt: &str, label_count: u8) {
let name = de_uncompressed(bytes).unwrap(); let name = de_uncompressed(bytes).unwrap();
assert_eq!( assert_eq!(name.labels().count(), label_count as usize);
name.labels().count(), assert_eq!(format!("{}", name), txt);
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
} }
fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) { fn check_uncompressed_debug(bytes: &'static [u8], txt: &str) {
let name = de_uncompressed(bytes).unwrap(); let name = de_uncompressed(bytes).unwrap();
assert_eq!( assert_eq!(format!("{:?}", name), txt);
format!("{:?}", name),
txt
);
} }
#[test] #[test]
fn parse_and_display_name() { fn parse_and_display_name() {
check_uncompressed_display( check_uncompressed_display(b"\x07example\x03com\x00", "example.com.", 2);
b"\x07example\x03com\x00", check_uncompressed_display(b"\x07e!am.l\\\x03com\x00", "e\\033am\\.l\\\\.com.", 2);
"example.com.", check_uncompressed_debug(b"\x07e!am.l\\\x03com\x00", r#""e\\033am\\.l\\\\.com.""#);
2,
);
check_uncompressed_display(
b"\x07e!am.l\\\x03com\x00",
"e\\033am\\.l\\\\.com.",
2,
);
check_uncompressed_debug(
b"\x07e!am.l\\\x03com\x00",
r#""e\\033am\\.l\\\\.com.""#,
);
} }
#[test] #[test]
@ -67,9 +47,9 @@ fn parse_and_reverse_name() {
assert_eq!( assert_eq!(
format!( format!(
"{}", "{}",
DisplayLabels{ DisplayLabels {
labels: name.labels().rev(), labels: name.labels().rev(),
options: DisplayLabelsOptions{ options: DisplayLabelsOptions {
separator: " ", separator: " ",
trailing: false, trailing: false,
}, },
@ -84,41 +64,24 @@ fn modifications() {
let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap(); let mut name = de_uncompressed(b"\x07example\x03com\x00").unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "www.example.com.");
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "www.example.com.org.");
format!("{}", name),
"www.example.com.org."
);
name.pop_front(); name.pop_front();
assert_eq!( assert_eq!(format!("{}", name), "example.com.org.");
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "mx.example.com.org.");
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before, // the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within // make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0)); assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "mx.example.com.org.com.");
format!("{}", name),
"mx.example.com.org.com."
);
} }
fn de_compressed(bytes: &'static [u8], offset: usize) -> Result<DnsCompressedName> { fn de_compressed(bytes: &'static [u8], offset: usize) -> Result<DnsCompressedName> {
use bytes::Buf; use bytes::Buf;
@ -133,22 +96,13 @@ fn de_compressed(bytes: &'static [u8], offset: usize) -> Result<DnsCompressedNam
fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) { fn check_compressed_display(bytes: &'static [u8], offset: usize, txt: &str, label_count: u8) {
let name = de_compressed(bytes, offset).unwrap(); let name = de_compressed(bytes, offset).unwrap();
assert_eq!( assert_eq!(name.labels().count(), label_count as usize);
name.labels().count(), assert_eq!(format!("{}", name), txt);
label_count as usize
);
assert_eq!(
format!("{}", name),
txt
);
} }
fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) { fn check_compressed_debug(bytes: &'static [u8], offset: usize, txt: &str) {
let name = de_compressed(bytes, offset).unwrap(); let name = de_compressed(bytes, offset).unwrap();
assert_eq!( assert_eq!(format!("{:?}", name), txt);
format!("{:?}", name),
txt
);
} }
#[test] #[test]
@ -159,22 +113,21 @@ fn parse_invalid_compressed_name() {
#[test] #[test]
fn parse_and_display_compressed_name() { fn parse_and_display_compressed_name() {
check_compressed_display(b"\x03com\x00\x07example\xc0\x00", 5, "example.com.", 2);
check_compressed_display( check_compressed_display(
b"\x03com\x00\x07example\xc0\x00", 5, b"\x03com\x00\x07e!am.l\\\xc0\x00",
"example.com.", 5,
2,
);
check_compressed_display(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5,
"e\\033am\\.l\\\\.com.", "e\\033am\\.l\\\\.com.",
2, 2,
); );
check_compressed_debug( check_compressed_debug(
b"\x03com\x00\x07e!am.l\\\xc0\x00", 5, b"\x03com\x00\x07e!am.l\\\xc0\x00",
5,
r#""e\\033am\\.l\\\\.com.""#, r#""e\\033am\\.l\\\\.com.""#,
); );
check_compressed_display( check_compressed_display(
b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05", 15, b"\x03com\x00\x07example\xc0\x00\x03www\xc0\x05",
15,
"www.example.com.", "www.example.com.",
3, 3,
); );
@ -185,35 +138,20 @@ fn modifications_compressed() {
let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap(); let mut name = de_compressed(b"\x03com\x00\x07example\xc0\x00\xc0\x05", 15).unwrap();
name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap(); name.push_front(DnsLabelRef::new(b"www").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "www.example.com.");
format!("{}", name),
"www.example.com."
);
name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap(); name.push_back(DnsLabelRef::new(b"org").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "www.example.com.org.");
format!("{}", name),
"www.example.com.org."
);
name.pop_front(); name.pop_front();
assert_eq!( assert_eq!(format!("{}", name), "example.com.org.");
format!("{}", name),
"example.com.org."
);
name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap(); name.push_front(DnsLabelRef::new(b"mx").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "mx.example.com.org.");
format!("{}", name),
"mx.example.com.org."
);
// the "mx" label should fit into the place "www" used before, // the "mx" label should fit into the place "www" used before,
// make sure the buffer was reused and the name not moved within // make sure the buffer was reused and the name not moved within
assert_eq!(1, name.label_offsets.label_pos(0)); assert_eq!(1, name.label_offsets.label_pos(0));
name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap(); name.push_back(DnsLabelRef::new(b"com").unwrap()).unwrap();
assert_eq!( assert_eq!(format!("{}", name), "mx.example.com.org.com.");
format!("{}", name),
"mx.example.com.org.com."
);
} }

View File

@ -1,17 +1,21 @@
use bytes::{Bytes, Buf, BufMut};
use crate::common_types::Type; use crate::common_types::Type;
use data_encoding;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{
remaining_bytes, short_blob, write_short_blob, DnsPacketData, DnsPacketWriteContext,
};
use crate::ser::text::{
next_field, skip_whitespace, DnsTextContext, DnsTextData, DnsTextFormatter,
};
use bytes::{Buf, BufMut, Bytes};
use data_encoding;
use failure::{Fail, ResultExt}; use failure::{Fail, ResultExt};
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes, short_blob, write_short_blob};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, skip_whitespace, next_field};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
static WHITESPACE: &str = "\t\n\x0c\r "; // \f == \x0c formfeed static WHITESPACE: &str = "\t\n\x0c\r "; // \f == \x0c formfeed
lazy_static::lazy_static!{ lazy_static::lazy_static! {
static ref BASE32HEX_NOPAD_ALLOW_WS: data_encoding::Encoding = { static ref BASE32HEX_NOPAD_ALLOW_WS: data_encoding::Encoding = {
let mut spec = data_encoding::Specification::new(); let mut spec = data_encoding::Specification::new();
spec.symbols.push_str("0123456789ABCDEFGHIJKLMNOPQRSTUV"); spec.symbols.push_str("0123456789ABCDEFGHIJKLMNOPQRSTUV");
@ -79,7 +83,12 @@ impl DnsPacketData for NsecTypeBitmap {
let mut prev_window = None; let mut prev_window = None;
while data.has_remaining() { while data.has_remaining() {
let window_base = (data.get_u8() as u16) << 8; let window_base = (data.get_u8() as u16) << 8;
failure::ensure!(Some(window_base) > prev_window, "wrong nsec bitmap window order, {:?} <= {:?}", Some(window_base), prev_window); failure::ensure!(
Some(window_base) > prev_window,
"wrong nsec bitmap window order, {:?} <= {:?}",
Some(window_base),
prev_window
);
prev_window = Some(window_base); prev_window = Some(window_base);
check_enough_data!(data, 1, "nsec bitmap window length"); check_enough_data!(data, 1, "nsec bitmap window length");
let window_len = data.get_u8() as u16; let window_len = data.get_u8() as u16;
@ -89,16 +98,13 @@ impl DnsPacketData for NsecTypeBitmap {
let mut v = data.get_u8(); let mut v = data.get_u8();
for j in 0..8 { for j in 0..8 {
if 0 != v & 0x80 { if 0 != v & 0x80 {
set.insert(Type(window_base + i*8 + j)); set.insert(Type(window_base + i * 8 + j));
} }
v <<= 1; v <<= 1;
} }
} }
} }
Ok(NsecTypeBitmap{ Ok(NsecTypeBitmap { raw: raw, set: set })
raw: raw,
set: set,
})
} }
fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
@ -150,15 +156,21 @@ impl DnsPacketData for NextHashedOwnerName {
impl DnsTextData for NextHashedOwnerName { impl DnsTextData for NextHashedOwnerName {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> {
let field = next_field(data)?; let field = next_field(data)?;
let raw = BASE32HEX_NOPAD_ALLOW_WS.decode(field.as_bytes()) let raw = BASE32HEX_NOPAD_ALLOW_WS
.decode(field.as_bytes())
.with_context(|e| e.context(format!("invalid base32hex (no padding): {:?}", field)))?; .with_context(|e| e.context(format!("invalid base32hex (no padding): {:?}", field)))?;
failure::ensure!(raw.len() > 0, "NextHashedOwnerName must not be empty"); failure::ensure!(raw.len() > 0, "NextHashedOwnerName must not be empty");
failure::ensure!(raw.len() < 256, "NextHashedOwnerName field must be at most 255 bytes long"); failure::ensure!(
raw.len() < 256,
"NextHashedOwnerName field must be at most 255 bytes long"
);
Ok(NextHashedOwnerName(raw.into())) Ok(NextHashedOwnerName(raw.into()))
} }
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
if self.0.is_empty() { return Err(fmt::Error); } if self.0.is_empty() {
return Err(fmt::Error);
}
write!(f, "{}", BASE32HEX_NOPAD_ALLOW_WS.encode(&self.0)) write!(f, "{}", BASE32HEX_NOPAD_ALLOW_WS.encode(&self.0))
} }
} }

View File

@ -1,8 +1,8 @@
use bytes::{Bytes, Buf, BufMut};
use crate::common_types::Type; use crate::common_types::Type;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes}; use crate::ser::packet::{remaining_bytes, DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, skip_whitespace}; use crate::ser::text::{skip_whitespace, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::{Buf, BufMut, Bytes};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -67,10 +67,7 @@ impl DnsPacketData for NxtTypeBitmap {
current += 1; current += 1;
} }
} }
Ok(NxtTypeBitmap{ Ok(NxtTypeBitmap { raw: raw, set: set })
raw: raw,
set: set,
})
} }
fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {

View File

@ -1,8 +1,8 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use failure::Fail;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use crate::ser::text::{next_field, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use failure::Fail;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -38,11 +38,9 @@ impl DnsTextData for OptionalTTL {
*data = data_found; *data = data_found;
Ok(OptionalTTL(ttl)) Ok(OptionalTTL(ttl))
}, },
Err(e) => { Err(e) => Ok(OptionalTTL(context.last_ttl().ok_or_else(|| {
Ok(OptionalTTL(context.last_ttl() e.context("TTL not available in context, failed parsing optional TTL")
.ok_or_else(|| e.context("TTL not available in context, failed parsing optional TTL"))? })?)),
))
},
} }
} }

View File

@ -1,7 +1,9 @@
use bytes::{Bytes, Buf, BufMut};
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, short_blob, write_short_blob, remaining_bytes}; use crate::ser::packet::{
remaining_bytes, short_blob, write_short_blob, DnsPacketData, DnsPacketWriteContext,
};
use crate::ser::text::*; use crate::ser::text::*;
use bytes::{Buf, BufMut, Bytes};
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -43,7 +45,9 @@ impl DnsPacketData for LongText {
let mut texts = Vec::new(); let mut texts = Vec::new();
loop { loop {
texts.push(short_blob(data)?); texts.push(short_blob(data)?);
if !data.has_remaining() { break; } if !data.has_remaining() {
break;
}
} }
Ok(LongText(texts)) Ok(LongText(texts))
} }
@ -66,7 +70,10 @@ impl DnsTextData for LongText {
skip_whitespace(data); skip_whitespace(data);
while !data.is_empty() { while !data.is_empty() {
let part = next_quoted_field(data)?; let part = next_quoted_field(data)?;
failure::ensure!(part.len() < 256, "long text component must be at most 255 bytes long"); failure::ensure!(
part.len() < 256,
"long text component must be at most 255 bytes long"
);
result.push(part.into()); result.push(part.into());
} }
Ok(LongText(result)) Ok(LongText(result))

View File

@ -36,7 +36,11 @@ fn test_is_leap_year() {
fn month_day_of_year_since_march(month: u8) -> u16 { fn month_day_of_year_since_march(month: u8) -> u16 {
debug_assert!(month >= 1 && month <= 12); debug_assert!(month >= 1 && month <= 12);
let month_from_march = if month > 2 { month as u16 - 3} else { month as u16 + 9 }; let month_from_march = if month > 2 {
month as u16 - 3
} else {
month as u16 + 9
};
(153 * month_from_march + 2) / 5 (153 * month_from_march + 2) / 5
} }
@ -49,7 +53,11 @@ fn month_and_day_from_day_of_year_since_march(day_of_year: i32) -> (u8, u8) {
let month_from_march = month_from_march as u8; let month_from_march = month_from_march as u8;
let month = if month_from_march < 10 { month_from_march + 3 } else { month_from_march - 9 }; let month = if month_from_march < 10 {
month_from_march + 3
} else {
month_from_march - 9
};
debug_assert!(month >= 1 && month <= 12); debug_assert!(month >= 1 && month <= 12);
(month, day as u8) (month, day as u8)
@ -57,12 +65,18 @@ fn month_and_day_from_day_of_year_since_march(day_of_year: i32) -> (u8, u8) {
#[test] #[test]
fn test_month_day_of_year_since_march() { fn test_month_day_of_year_since_march() {
static MONTH_START: [u16; 12] = [306, 337, 0, 31, 61, 92, 122, 153, 184, 214, 245, 275, ]; static MONTH_START: [u16; 12] = [306, 337, 0, 31, 61, 92, 122, 153, 184, 214, 245, 275];
static DAYS_IN_MONTHS: [u8; 12] = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; static DAYS_IN_MONTHS: [u8; 12] = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
for (m, &s) in MONTH_START.iter().enumerate() { for (m, &s) in MONTH_START.iter().enumerate() {
assert_eq!(month_day_of_year_since_march(m as u8 + 1), s); assert_eq!(month_day_of_year_since_march(m as u8 + 1), s);
assert_eq!(month_and_day_from_day_of_year_since_march(s as i32), (m as u8 + 1, 1)); assert_eq!(
assert_eq!(month_and_day_from_day_of_year_since_march(s as i32 + DAYS_IN_MONTHS[m] as i32 - 1), (m as u8 + 1, DAYS_IN_MONTHS[m])); month_and_day_from_day_of_year_since_march(s as i32),
(m as u8 + 1, 1)
);
assert_eq!(
month_and_day_from_day_of_year_since_march(s as i32 + DAYS_IN_MONTHS[m] as i32 - 1),
(m as u8 + 1, DAYS_IN_MONTHS[m])
);
} }
assert_eq!(month_and_day_from_day_of_year_since_march(365), (2, 29)); assert_eq!(month_and_day_from_day_of_year_since_march(365), (2, 29));
} }
@ -136,7 +150,7 @@ fn split_days_since_march1_y0_into_era(days: i32) -> (i32, i32) {
} }
fn year_of_era_from_day_of_era(day_of_era: i32) -> i32 { fn year_of_era_from_day_of_era(day_of_era: i32) -> i32 {
let res = (day_of_era - day_of_era/1460 + day_of_era/36524 - day_of_era/146096) / 365; let res = (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
debug_assert!(res >= 0 && res <= 399); debug_assert!(res >= 0 && res <= 399);
res res
} }
@ -197,15 +211,21 @@ fn test_days_since_march1_y0() {
assert_eq!(days_since_march1_y0(0, 3, 1), 0); assert_eq!(days_since_march1_y0(0, 3, 1), 0);
assert_eq!(days_since_march1_y0(1, 3, 1), 365); assert_eq!(days_since_march1_y0(1, 3, 1), 365);
assert_eq!(days_since_march1_y0(1970, 1, 1), EPOCH_DAYS_SINCE_MARCH1_Y0); assert_eq!(days_since_march1_y0(1970, 1, 1), EPOCH_DAYS_SINCE_MARCH1_Y0);
assert_eq!(days_since_march1_y0(1970, 1, 1), assert_eq!(
/* regular days in years: */ 365 * 1969 days_since_march1_y0(1970, 1, 1),
/* regular days in years: */
365 * 1969
/* days from leap years: */ + (1969 / 4 - 15) /* days from leap years: */ + (1969 / 4 - 15)
/* days from march 1st year 0 to january 1st year 1: */ + 306); /* days from march 1st year 0 to january 1st year 1: */ + 306
);
assert_eq!(split_days_since_march1_y0(-366), (-1, 3, 1)); assert_eq!(split_days_since_march1_y0(-366), (-1, 3, 1));
assert_eq!(split_days_since_march1_y0(0), (0, 3, 1)); assert_eq!(split_days_since_march1_y0(0), (0, 3, 1));
assert_eq!(split_days_since_march1_y0(365), (1, 3, 1)); assert_eq!(split_days_since_march1_y0(365), (1, 3, 1));
assert_eq!(split_days_since_march1_y0(EPOCH_DAYS_SINCE_MARCH1_Y0), (1970, 1, 1)); assert_eq!(
split_days_since_march1_y0(EPOCH_DAYS_SINCE_MARCH1_Y0),
(1970, 1, 1)
);
} }
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
@ -223,7 +243,10 @@ impl Tm {
pub fn from_epoch(epoch: i64) -> Result<Self> { pub fn from_epoch(epoch: i64) -> Result<Self> {
let (day, time_of_day) = pos_div_rem64(epoch, 86400); let (day, time_of_day) = pos_div_rem64(epoch, 86400);
let days_since_march1_y0 = day + EPOCH_DAYS_SINCE_MARCH1_Y0 as i64; let days_since_march1_y0 = day + EPOCH_DAYS_SINCE_MARCH1_Y0 as i64;
failure::ensure!((days_since_march1_y0 as i32) as i64 == days_since_march1_y0, "days in epoch out of range"); failure::ensure!(
(days_since_march1_y0 as i32) as i64 == days_since_march1_y0,
"days in epoch out of range"
);
let days_since_march1_y0 = days_since_march1_y0 as i32; let days_since_march1_y0 = days_since_march1_y0 as i32;
let (year, month, day) = split_days_since_march1_y0(days_since_march1_y0); let (year, month, day) = split_days_since_march1_y0(days_since_march1_y0);
@ -233,7 +256,7 @@ impl Tm {
let (minute_of_day, second) = pos_div_rem(time_of_day as i32, 60); let (minute_of_day, second) = pos_div_rem(time_of_day as i32, 60);
let (hour, minute) = pos_div_rem(minute_of_day, 60); let (hour, minute) = pos_div_rem(minute_of_day, 60);
Ok(Tm{ Ok(Tm {
year, year,
month, month,
day, day,
@ -249,9 +272,7 @@ impl Tm {
} }
fn day_seconds(&self) -> u32 { fn day_seconds(&self) -> u32 {
self.second as u32 self.second as u32 + 60 * self.minute as u32 + 3600 * self.hour as u32
+ 60 * self.minute as u32
+ 3600 * self.hour as u32
} }
pub fn epoch(&self) -> i64 { pub fn epoch(&self) -> i64 {
@ -261,14 +282,24 @@ impl Tm {
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn parse_YYYYMMDDHHmmSS(s: &str) -> Result<Self> { pub fn parse_YYYYMMDDHHmmSS(s: &str) -> Result<Self> {
failure::ensure!(s.len() == 14, "Tm string must be exactly 14 digits long"); failure::ensure!(s.len() == 14, "Tm string must be exactly 14 digits long");
failure::ensure!(s.as_bytes().iter().all(|&b| b >= b'0' && b <= b'9'), "Tm string must be exactly 14 digits long"); failure::ensure!(
s.as_bytes().iter().all(|&b| b >= b'0' && b <= b'9'),
"Tm string must be exactly 14 digits long"
);
let year = s[0..4].parse::<i16>()?; let year = s[0..4].parse::<i16>()?;
failure::ensure!(year >= 1, "year must be >= 1"); failure::ensure!(year >= 1, "year must be >= 1");
failure::ensure!(year <= 9999, "year must be <= 9999"); failure::ensure!(year <= 9999, "year must be <= 9999");
fn p(s: &str, min: u8, max: u8, name: &'static str) -> crate::errors::Result<u8> { fn p(s: &str, min: u8, max: u8, name: &'static str) -> crate::errors::Result<u8> {
let v = s.parse::<u8>()?; let v = s.parse::<u8>()?;
failure::ensure!(v >= min && v <= max, "{} {} out of range {}-{}", name, v, min, max); failure::ensure!(
v >= min && v <= max,
"{} {} out of range {}-{}",
name,
v,
min,
max
);
Ok(v) Ok(v)
} }
@ -280,25 +311,44 @@ impl Tm {
if 2 == month { if 2 == month {
failure::ensure!(day < 30, "day {} out of range in february", day); failure::ensure!(day < 30, "day {} out of range in february", day);
failure::ensure!(is_leap_year(year) || day < 29, "day {} out of range in february (not a leap year)", day); failure::ensure!(
is_leap_year(year) || day < 29,
"day {} out of range in february (not a leap year)",
day
);
} else { } else {
static DAYS_IN_MONTHS: [u8; 12] = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; static DAYS_IN_MONTHS: [u8; 12] = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
let max_days = DAYS_IN_MONTHS[month as usize - 1]; let max_days = DAYS_IN_MONTHS[month as usize - 1];
failure::ensure!(day <= max_days, "day {} out of range for month {}", day, month); failure::ensure!(
day <= max_days,
"day {} out of range for month {}",
day,
month
);
} }
Ok(Tm{ year, month, day, hour, minute, second }) Ok(Tm {
year,
month,
day,
hour,
minute,
second,
})
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn format_YYYYMMDDHHmmSS<W>(&self, f: &mut W) -> fmt::Result pub fn format_YYYYMMDDHHmmSS<W>(&self, f: &mut W) -> fmt::Result
where where
W: fmt::Write + ?Sized W: fmt::Write + ?Sized,
{ {
if self.year < 0 || self.year > 9999 { return Err(fmt::Error); } if self.year < 0 || self.year > 9999 {
write!(f, "{:04}{:02}{:02}{:02}{:02}{:02}", return Err(fmt::Error);
self.year, self.month, self.day, }
self.hour, self.minute, self.second write!(
f,
"{:04}{:02}{:02}{:02}{:02}{:02}",
self.year, self.month, self.day, self.hour, self.minute, self.second
) )
} }
} }

View File

@ -1,7 +1,7 @@
use bytes::Bytes;
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use crate::ser::text::{next_field, DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -39,7 +39,9 @@ impl DnsTextData for Time {
} }
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
epoch::Tm::from_epoch(self.0 as i64).unwrap().format_YYYYMMDDHHmmSS(&mut*f.format_field()?) epoch::Tm::from_epoch(self.0 as i64)
.unwrap()
.format_YYYYMMDDHHmmSS(&mut *f.format_field()?)
} }
} }
@ -69,7 +71,9 @@ impl DnsTextData for TimeStrict {
} }
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
epoch::Tm::from_epoch(self.0 as i64).unwrap().format_YYYYMMDDHHmmSS(&mut*f.format_field()?) epoch::Tm::from_epoch(self.0 as i64)
.unwrap()
.format_YYYYMMDDHHmmSS(&mut *f.format_field()?)
} }
} }

View File

@ -1,7 +1,7 @@
use bytes::{Bytes, BufMut};
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes}; use crate::ser::packet::{remaining_bytes, DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::*; use crate::ser::text::*;
use bytes::{BufMut, Bytes};
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;

View File

@ -14,10 +14,16 @@ fn parse_rsa(data: &[u8]) -> crate::errors::Result<PublicKey> {
let exp_len: usize; let exp_len: usize;
let offset: usize; let offset: usize;
if data[0] == 0 { if data[0] == 0 {
failure::ensure!(data.len() >= 3, "RSA public key: unexpected end of data when decoding exponent length"); failure::ensure!(
data.len() >= 3,
"RSA public key: unexpected end of data when decoding exponent length"
);
exp_len = (data[1] as usize) << 8 + (data[2] as usize); exp_len = (data[1] as usize) << 8 + (data[2] as usize);
offset = 3; offset = 3;
failure::ensure!(exp_len >= 256, "RSA public key: exponent length in long form but too small"); failure::ensure!(
exp_len >= 256,
"RSA public key: exponent length in long form but too small"
);
} else { } else {
exp_len = data[0] as usize; exp_len = data[0] as usize;
offset = 1; offset = 1;
@ -25,23 +31,37 @@ fn parse_rsa(data: &[u8]) -> crate::errors::Result<PublicKey> {
assert!(exp_len > 0); // should be unreachable: 0 means two bytes, which are checked for >= 256 assert!(exp_len > 0); // should be unreachable: 0 means two bytes, which are checked for >= 256
failure::ensure!(exp_len <= RSA_BYTES_LIMIT, "RSA public key: exponent too long (limit: {} bits)", RSA_BITS_LIMIT); failure::ensure!(
exp_len <= RSA_BYTES_LIMIT,
"RSA public key: exponent too long (limit: {} bits)",
RSA_BITS_LIMIT
);
failure::ensure!(data.len() >= offset + exp_len, "RSA public key: unexpected end of data when reading exponent"); failure::ensure!(
failure::ensure!(data[offset] != 0, "RSA public key: leading zero in exponent"); data.len() >= offset + exp_len,
"RSA public key: unexpected end of data when reading exponent"
);
failure::ensure!(
data[offset] != 0,
"RSA public key: leading zero in exponent"
);
let exponent = BigUint::from_bytes_be(&data[offset..][..exp_len]); let exponent = BigUint::from_bytes_be(&data[offset..][..exp_len]);
let modulus_data = &data[offset..][exp_len..]; let modulus_data = &data[offset..][exp_len..];
failure::ensure!(modulus_data.len() <= RSA_BYTES_LIMIT, "RSA public key: modulus too long (limit: {} bits)", RSA_BITS_LIMIT); failure::ensure!(
modulus_data.len() <= RSA_BYTES_LIMIT,
"RSA public key: modulus too long (limit: {} bits)",
RSA_BITS_LIMIT
);
failure::ensure!(!modulus_data.is_empty(), "RSA public key: modulus empty"); failure::ensure!(!modulus_data.is_empty(), "RSA public key: modulus empty");
failure::ensure!(modulus_data[offset] != 0, "RSA public key: leading zero in modulus"); failure::ensure!(
modulus_data[offset] != 0,
"RSA public key: leading zero in modulus"
);
let modulus = BigUint::from_bytes_be(modulus_data); let modulus = BigUint::from_bytes_be(modulus_data);
Ok(PublicKey::RSA { Ok(PublicKey::RSA { exponent, modulus })
exponent,
modulus,
})
} }
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
@ -58,34 +78,53 @@ impl PublicKey {
pub fn parse(algorithm: DnsSecAlgorithm, data: &[u8]) -> crate::errors::Result<Self> { pub fn parse(algorithm: DnsSecAlgorithm, data: &[u8]) -> crate::errors::Result<Self> {
use DnsSecAlgorithmKnown::*; use DnsSecAlgorithmKnown::*;
let algorithm = algorithm.into_known().ok_or_else(|| failure::format_err!("Unknown algorithm"))?; let algorithm = algorithm
.into_known()
.ok_or_else(|| failure::format_err!("Unknown algorithm"))?;
match algorithm { match algorithm {
DELETE|INDIRECT|PRIVATEDNS|PRIVATEOID => failure::bail!("Algorithm {:?} not used with actual key", algorithm), DELETE | INDIRECT | PRIVATEDNS | PRIVATEOID => {
RSAMD5|RSASHA1|RSASHA1_NSEC3_SHA1|RSASHA256|RSASHA512 => parse_rsa(data), failure::bail!("Algorithm {:?} not used with actual key", algorithm)
DH|DSA|DSA_NSEC3_SHA1 => failure::bail!("Algorithm {:?} not supported", algorithm), },
RSAMD5 | RSASHA1 | RSASHA1_NSEC3_SHA1 | RSASHA256 | RSASHA512 => parse_rsa(data),
DH | DSA | DSA_NSEC3_SHA1 => failure::bail!("Algorithm {:?} not supported", algorithm),
ECDSAP256SHA256 => { ECDSAP256SHA256 => {
failure::ensure!(data.len() == 64, "Expected 64 bytes public key for ECDSAP256"); failure::ensure!(
data.len() == 64,
"Expected 64 bytes public key for ECDSAP256"
);
let mut x = [0u8; 32]; let mut x = [0u8; 32];
x.copy_from_slice(&data[..32]); x.copy_from_slice(&data[..32]);
let mut y = [0u8; 32]; let mut y = [0u8; 32];
y.copy_from_slice(&data[32..]); y.copy_from_slice(&data[32..]);
Ok(PublicKey::ECDSAP256 { xy: Box::new((x, y)) }) Ok(PublicKey::ECDSAP256 {
xy: Box::new((x, y)),
})
}, },
ECDSAP384SHA384 => { ECDSAP384SHA384 => {
failure::ensure!(data.len() == 96, "Expected 96 bytes public key for ECDSAP384"); failure::ensure!(
data.len() == 96,
"Expected 96 bytes public key for ECDSAP384"
);
let mut x = [0u8; 48]; let mut x = [0u8; 48];
x.copy_from_slice(&data[..48]); x.copy_from_slice(&data[..48]);
let mut y = [0u8; 48]; let mut y = [0u8; 48];
y.copy_from_slice(&data[48..]); y.copy_from_slice(&data[48..]);
Ok(PublicKey::ECDSAP384 { xy: Box::new((x, y)) }) Ok(PublicKey::ECDSAP384 {
xy: Box::new((x, y)),
})
}, },
ECC_GOST => { ECC_GOST => {
failure::ensure!(data.len() == 64, "Expected 64 bytes public key for ECC_GOST"); failure::ensure!(
data.len() == 64,
"Expected 64 bytes public key for ECC_GOST"
);
let mut x = [0u8; 32]; let mut x = [0u8; 32];
x.copy_from_slice(&data[..32]); x.copy_from_slice(&data[..32]);
let mut y = [0u8; 32]; let mut y = [0u8; 32];
y.copy_from_slice(&data[32..]); y.copy_from_slice(&data[32..]);
Ok(PublicKey::ECC_GOST { xy: Box::new((x, y)) }) Ok(PublicKey::ECC_GOST {
xy: Box::new((x, y)),
})
}, },
ED25519 => { ED25519 => {
failure::ensure!(data.len() == 32, "Expected 32 bytes public key for ED25519"); failure::ensure!(data.len() == 32, "Expected 32 bytes public key for ED25519");
@ -105,11 +144,11 @@ impl PublicKey {
pub fn bits(&self) -> Option<u32> { pub fn bits(&self) -> Option<u32> {
match self { match self {
PublicKey::RSA { modulus, .. } => Some(modulus.bits() as u32), PublicKey::RSA { modulus, .. } => Some(modulus.bits() as u32),
PublicKey::ECDSAP256 { .. } => Some(32*8), PublicKey::ECDSAP256 { .. } => Some(32 * 8),
PublicKey::ECDSAP384 { .. } => Some(48*8), PublicKey::ECDSAP384 { .. } => Some(48 * 8),
PublicKey::ECC_GOST { .. } => Some(32*8), PublicKey::ECC_GOST { .. } => Some(32 * 8),
PublicKey::ED25519 { .. } => Some(32*8), PublicKey::ED25519 { .. } => Some(32 * 8),
PublicKey::ED448 { .. } => Some(57*8), PublicKey::ED448 { .. } => Some(57 * 8),
} }
} }
} }

View File

@ -15,7 +15,7 @@ pub struct NotEnoughData {
impl NotEnoughData { impl NotEnoughData {
pub fn check(data: &mut io::Cursor<Bytes>, need: usize) -> Result<()> { pub fn check(data: &mut io::Cursor<Bytes>, need: usize) -> Result<()> {
if data.remaining() < need { if data.remaining() < need {
failure::bail!(NotEnoughData{ failure::bail!(NotEnoughData {
position: data.position(), position: data.position(),
data: data.get_ref().clone(), data: data.get_ref().clone(),
}) })
@ -26,17 +26,19 @@ impl NotEnoughData {
impl fmt::Display for NotEnoughData { impl fmt::Display for NotEnoughData {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "not enough data at position {} in {:?}", self.position, self.data) write!(
f,
"not enough data at position {} in {:?}",
self.position, self.data
)
} }
} }
impl failure::Fail for NotEnoughData {} impl failure::Fail for NotEnoughData {}
macro_rules! check_enough_data { macro_rules! check_enough_data {
($data:ident, $n:expr, $context:expr) => { ($data:ident, $n:expr, $context:expr) => {{
{
use $crate::_failure::ResultExt; use $crate::_failure::ResultExt;
$crate::errors::NotEnoughData::check($data, $n).context($context)?; $crate::errors::NotEnoughData::check($data, $n).context($context)?;
} }};
};
} }

View File

@ -1,17 +1,17 @@
#[doc(hidden)] #[doc(hidden)]
pub use failure as _failure; // re-export for macros pub use bytes as _bytes;
#[doc(hidden)] #[doc(hidden)]
pub use bytes as _bytes; // re-export for macros pub use failure as _failure; // re-export for macros // re-export for macros
extern crate self as dnsbox_base; extern crate self as dnsbox_base;
#[macro_use] #[macro_use]
pub mod errors; pub mod errors;
pub mod common_types;
#[cfg(feature = "crypto")] #[cfg(feature = "crypto")]
pub mod crypto; pub mod crypto;
pub mod common_types;
pub mod ser;
pub mod packet; pub mod packet;
pub mod records; pub mod records;
pub mod ser;
mod unsafe_ops; mod unsafe_ops;

View File

@ -1,11 +1,11 @@
use byteorder::ByteOrder; use crate::common_types::{types, Class, DnsCompressedName, Type};
use bytes::{Bytes, Buf, BufMut, BigEndian};
use crate::common_types::{Type, Class, DnsCompressedName, types};
use crate::errors::*; use crate::errors::*;
use crate::ser::RRData;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use std::io::Cursor;
use crate::records::registry::deserialize_rr_data; use crate::records::registry::deserialize_rr_data;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::RRData;
use byteorder::ByteOrder;
use bytes::{BigEndian, Buf, BufMut, Bytes};
use std::io::Cursor;
pub mod opt; pub mod opt;
@ -38,7 +38,11 @@ pub struct DnsHeaderFlags {
impl DnsPacketData for DnsHeaderFlags { impl DnsPacketData for DnsHeaderFlags {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> { fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
let raw = u16::deserialize(data)?; let raw = u16::deserialize(data)?;
let qr = if 0 == raw & 0x8000 { QueryResponse::Query } else { QueryResponse::Response }; let qr = if 0 == raw & 0x8000 {
QueryResponse::Query
} else {
QueryResponse::Response
};
let opcode = 0xf & (raw >> 11) as u8; let opcode = 0xf & (raw >> 11) as u8;
let authoritative_answer = 0 != raw & 0x0400; let authoritative_answer = 0 != raw & 0x0400;
let truncation = 0 != raw & 0x0200; let truncation = 0 != raw & 0x0200;
@ -48,7 +52,7 @@ impl DnsPacketData for DnsHeaderFlags {
let authentic_data = 0 != raw & 0x0020; let authentic_data = 0 != raw & 0x0020;
let checking_disabled = 0 != raw & 0x0010; let checking_disabled = 0 != raw & 0x0010;
let rcode = 0xf & raw as u8; let rcode = 0xf & raw as u8;
Ok(DnsHeaderFlags{ Ok(DnsHeaderFlags {
qr, qr,
opcode, opcode,
authoritative_answer, authoritative_answer,
@ -67,8 +71,7 @@ impl DnsPacketData for DnsHeaderFlags {
| match self.qr { | match self.qr {
QueryResponse::Query => 0, QueryResponse::Query => 0,
QueryResponse::Response => 1, QueryResponse::Response => 1,
} } | (((0xf & self.opcode) as u16) << 11)
| (((0xf & self.opcode) as u16) << 11)
| if self.authoritative_answer { 0x0400 } else { 0 } | if self.authoritative_answer { 0x0400 } else { 0 }
| if self.truncation { 0x0200 } else { 0 } | if self.truncation { 0x0200 } else { 0 }
| if self.recursion_desired { 0x0100 } else { 0 } | if self.recursion_desired { 0x0100 } else { 0 }
@ -76,8 +79,7 @@ impl DnsPacketData for DnsHeaderFlags {
| if self.reserved_bit9 { 0x0040 } else { 0 } | if self.reserved_bit9 { 0x0040 } else { 0 }
| if self.authentic_data { 0x0020 } else { 0 } | if self.authentic_data { 0x0020 } else { 0 }
| if self.checking_disabled { 0x0010 } else { 0 } | if self.checking_disabled { 0x0010 } else { 0 }
| (0xf & self.rcode) as u16 | (0xf & self.rcode) as u16;
;
flags.serialize(context, packet) flags.serialize(context, packet)
} }
} }
@ -123,9 +125,13 @@ impl DnsPacketData for Resource {
rrdata.advance(pos); rrdata.advance(pos);
let rd = deserialize_rr_data(ttl, class, rr_type, &mut rrdata)?; let rd = deserialize_rr_data(ttl, class, rr_type, &mut rrdata)?;
failure::ensure!(!rrdata.has_remaining(), "data remaining: {} bytes", rrdata.remaining()); failure::ensure!(
!rrdata.has_remaining(),
"data remaining: {} bytes",
rrdata.remaining()
);
Ok(Resource{ Ok(Resource {
name, name,
class, class,
ttl, ttl,
@ -204,7 +210,7 @@ impl DnsPacket {
impl Default for DnsPacket { impl Default for DnsPacket {
fn default() -> Self { fn default() -> Self {
DnsPacket{ DnsPacket {
id: 0, id: 0,
flags: DnsHeaderFlags::default(), flags: DnsHeaderFlags::default(),
question: Vec::new(), question: Vec::new(),
@ -222,10 +228,18 @@ impl DnsPacketData for DnsPacket {
let mut p = DnsPacket { let mut p = DnsPacket {
id: header.id, id: header.id,
flags: header.flags, flags: header.flags,
question: (0..header.qdcount).map(|_| Question::deserialize(data)).collect::<Result<Vec<_>>>()?, question: (0..header.qdcount)
answer: (0..header.ancount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?, .map(|_| Question::deserialize(data))
authority: (0..header.nscount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?, .collect::<Result<Vec<_>>>()?,
additional: (0..header.arcount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?, answer: (0..header.ancount)
.map(|_| Resource::deserialize(data))
.collect::<Result<Vec<_>>>()?,
authority: (0..header.nscount)
.map(|_| Resource::deserialize(data))
.collect::<Result<Vec<_>>>()?,
additional: (0..header.arcount)
.map(|_| Resource::deserialize(data))
.collect::<Result<Vec<_>>>()?,
opt: None, opt: None,
}; };
@ -249,9 +263,15 @@ impl DnsPacketData for DnsPacket {
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
failure::ensure!(self.question.len() < 0x1_0000, "too many question entries"); failure::ensure!(self.question.len() < 0x1_0000, "too many question entries");
failure::ensure!(self.answer.len() < 0x1_0000, "too many answer entries"); failure::ensure!(self.answer.len() < 0x1_0000, "too many answer entries");
failure::ensure!(self.authority.len() < 0x1_0000, "too many authority entries"); failure::ensure!(
failure::ensure!(self.additional.len() < 0x1_0000, "too many additional entries"); self.authority.len() < 0x1_0000,
let header = DnsHeader{ "too many authority entries"
);
failure::ensure!(
self.additional.len() < 0x1_0000,
"too many additional entries"
);
let header = DnsHeader {
id: self.id, id: self.id,
flags: self.flags, flags: self.flags,
qdcount: self.question.len() as u16, qdcount: self.question.len() as u16,

View File

@ -1,10 +1,10 @@
use byteorder::ByteOrder; use crate::common_types::{types, Class, DnsCompressedName};
use bytes::{Bytes, Buf, BufMut, BigEndian};
use crate::common_types::{DnsCompressedName, Class, types};
use crate::errors::*; use crate::errors::*;
use crate::packet::Resource; use crate::packet::Resource;
use crate::records::UnknownRecord; use crate::records::UnknownRecord;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, get_blob, remaining_bytes}; use crate::ser::packet::{get_blob, remaining_bytes, DnsPacketData, DnsPacketWriteContext};
use byteorder::ByteOrder;
use bytes::{BigEndian, Buf, BufMut, Bytes};
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
@ -65,7 +65,7 @@ pub enum DnsOption {
Unknown { Unknown {
code: u16, code: u16,
data: Bytes, data: Bytes,
} },
} }
impl DnsOption { impl DnsOption {
@ -74,7 +74,12 @@ impl DnsOption {
let source_prefix_length = u8::deserialize(data)?; let source_prefix_length = u8::deserialize(data)?;
let scope_prefix_length = u8::deserialize(data)?; let scope_prefix_length = u8::deserialize(data)?;
let addr_prefix_len = ((source_prefix_length + 7) / 8) as usize; let addr_prefix_len = ((source_prefix_length + 7) / 8) as usize;
failure::ensure!(scope_prefix_length <= source_prefix_length, "scope prefix {} > source prefix {}", scope_prefix_length, source_prefix_length); failure::ensure!(
scope_prefix_length <= source_prefix_length,
"scope prefix {} > source prefix {}",
scope_prefix_length,
source_prefix_length
);
let addr = match addr_family { let addr = match addr_family {
1 => { 1 => {
failure::ensure!(source_prefix_length <= 32, "invalid prefix for IPv4"); failure::ensure!(source_prefix_length <= 32, "invalid prefix for IPv4");
@ -100,7 +105,7 @@ impl DnsOption {
failure::bail!("unknown address family {}", addr_family); failure::bail!("unknown address family {}", addr_family);
}, },
}; };
Ok(DnsOption::ClientSubnet{ Ok(DnsOption::ClientSubnet {
source_prefix_length, source_prefix_length,
scope_prefix_length, scope_prefix_length,
addr, addr,
@ -110,24 +115,38 @@ impl DnsOption {
fn parse_opt(code: u16, opt_data: Bytes) -> Result<Self> { fn parse_opt(code: u16, opt_data: Bytes) -> Result<Self> {
let mut data = Cursor::new(opt_data); let mut data = Cursor::new(opt_data);
let result = (|| Ok(match code { let result = (|| {
Ok(match code {
0x0003 => DnsOption::NSID(remaining_bytes(&mut data)), 0x0003 => DnsOption::NSID(remaining_bytes(&mut data)),
0x0008 => DnsOption::parse_client_subnet(&mut data)?, 0x0008 => DnsOption::parse_client_subnet(&mut data)?,
_ => failure::bail!("unknown option {}", code), _ => failure::bail!("unknown option {}", code),
}))()?; })
})()?;
failure::ensure!(!data.has_remaining(), "option data remaining: {} bytes", data.remaining()); failure::ensure!(
!data.has_remaining(),
"option data remaining: {} bytes",
data.remaining()
);
Ok(result) Ok(result)
} }
fn write_opt_data(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn write_opt_data(
&self,
_context: &mut DnsPacketWriteContext,
packet: &mut Vec<u8>,
) -> Result<()> {
match *self { match *self {
DnsOption::NSID(ref id) => { DnsOption::NSID(ref id) => {
packet.reserve(id.len()); packet.reserve(id.len());
packet.put_slice(id); packet.put_slice(id);
}, },
DnsOption::ClientSubnet{source_prefix_length, scope_prefix_length, ref addr} => { DnsOption::ClientSubnet {
source_prefix_length,
scope_prefix_length,
ref addr,
} => {
let addr_prefix_len = ((source_prefix_length + 7) / 8) as usize; let addr_prefix_len = ((source_prefix_length + 7) / 8) as usize;
packet.reserve(4 + addr_prefix_len); packet.reserve(4 + addr_prefix_len);
packet.put_u16_be(match *addr { packet.put_u16_be(match *addr {
@ -148,7 +167,7 @@ impl DnsOption {
}, },
} }
}, },
DnsOption::Unknown{ref data, ..} => { DnsOption::Unknown { ref data, .. } => {
packet.reserve(data.len()); packet.reserve(data.len());
packet.put_slice(data); packet.put_slice(data);
}, },
@ -165,7 +184,7 @@ impl DnsPacketData for DnsOption {
let opt_data = get_blob(data, opt_len)?; let opt_data = get_blob(data, opt_len)?;
DnsOption::parse_opt(code, opt_data.clone()).or_else(|_| { DnsOption::parse_opt(code, opt_data.clone()).or_else(|_| {
Ok(DnsOption::Unknown{ Ok(DnsOption::Unknown {
code: code, code: code,
data: opt_data, data: opt_data,
}) })
@ -175,8 +194,8 @@ impl DnsPacketData for DnsOption {
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
let code: u16 = match *self { let code: u16 = match *self {
DnsOption::NSID(_) => 0x0003, DnsOption::NSID(_) => 0x0003,
DnsOption::ClientSubnet{..} => 0x0008, DnsOption::ClientSubnet { .. } => 0x0008,
DnsOption::Unknown{code, ..} => code, DnsOption::Unknown { code, .. } => code,
}; };
code.serialize(context, packet)?; code.serialize(context, packet)?;
@ -230,7 +249,9 @@ impl Opt {
let version = (r.ttl >> 16) as u8; let version = (r.ttl >> 16) as u8;
let flags = OptFlags(r.ttl as u16); let flags = OptFlags(r.ttl as u16);
if version > 0 { return Ok(Err(OptError::UnknownVersion)); } if version > 0 {
return Ok(Err(OptError::UnknownVersion));
}
let ur = match r.data.as_any().downcast_ref::<UnknownRecord>() { let ur = match r.data.as_any().downcast_ref::<UnknownRecord>() {
Some(ur) => ur, Some(ur) => ur,
@ -263,7 +284,7 @@ impl Opt {
let ttl = ((self.extended_rcode_high as u32) << 24) let ttl = ((self.extended_rcode_high as u32) << 24)
| ((self.version as u32) << 16) | ((self.version as u32) << 16)
| self.flags.0 as u32; | self.flags.0 as u32;
Ok(Resource{ Ok(Resource {
name: DnsCompressedName::new_root(), name: DnsCompressedName::new_root(),
class: Class(self.udp_payload_size), class: Class(self.udp_payload_size),
ttl: ttl, ttl: ttl,

View File

@ -1,7 +1,7 @@
mod weird_structs; pub mod registry;
mod structs; mod structs;
mod unknown; mod unknown;
pub mod registry; mod weird_structs;
pub use self::structs::*; pub use self::structs::*;
pub use self::unknown::*; pub use self::unknown::*;

View File

@ -1,33 +1,29 @@
#![allow(non_snake_case)] #![allow(non_snake_case)]
use bytes::Bytes; use crate::common_types::{classes, types, Class, DnsCompressedName, DnsName, Type};
use crate::common_types::{DnsName, DnsCompressedName, Class, Type, types, classes};
use crate::errors::*; use crate::errors::*;
use crate::packet::opt::{DnsOption, Opt};
use crate::packet::*; use crate::packet::*;
use crate::packet::opt::{Opt, DnsOption}; use crate::records::{registry, UnknownRecord, A};
use crate::records::{UnknownRecord, registry, A}; use crate::ser::packet::{deserialize_with, DnsPacketData};
use crate::ser::{RRData, text}; use crate::ser::{text, RRData};
use crate::ser::packet::{DnsPacketData, deserialize_with}; use bytes::Bytes;
use std::io::Cursor; use std::io::Cursor;
fn fake_packet(rrtype: Type, raw: &[u8]) -> Bytes { fn fake_packet(rrtype: Type, raw: &[u8]) -> Bytes {
let mut p = DnsPacket{ let mut p = DnsPacket {
question: vec![ question: vec![Question {
Question {
qname: ".".parse().unwrap(), qname: ".".parse().unwrap(),
qtype: rrtype, qtype: rrtype,
qclass: classes::IN, qclass: classes::IN,
} }],
], answer: vec![Resource {
answer: vec![
Resource {
name: "rec.test.".parse().unwrap(), name: "rec.test.".parse().unwrap(),
class: classes::IN, class: classes::IN,
ttl: 0, ttl: 0,
data: Box::new(UnknownRecord::new(rrtype, Bytes::from(raw))), data: Box::new(UnknownRecord::new(rrtype, Bytes::from(raw))),
} }],
], ..Default::default()
.. Default::default()
}; };
p.to_bytes().unwrap().into() p.to_bytes().unwrap().into()
@ -56,29 +52,24 @@ fn get_first_answer_rdata(packet: Bytes) -> Result<Bytes> {
} }
fn serialized_answer(rrdata: Box<dyn RRData>) -> Result<Bytes> { fn serialized_answer(rrdata: Box<dyn RRData>) -> Result<Bytes> {
let mut p = DnsPacket{ let mut p = DnsPacket {
question: vec![ question: vec![Question {
Question {
qname: ".".parse().unwrap(), qname: ".".parse().unwrap(),
qtype: rrdata.rr_type(), qtype: rrdata.rr_type(),
qclass: classes::IN, qclass: classes::IN,
} }],
], answer: vec![Resource {
answer: vec![
Resource {
name: "rec.test.".parse().unwrap(), name: "rec.test.".parse().unwrap(),
class: classes::IN, class: classes::IN,
ttl: 0, ttl: 0,
data: rrdata, data: rrdata,
} }],
], ..Default::default()
.. Default::default()
}; };
get_first_answer_rdata(p.to_bytes()?.into()) get_first_answer_rdata(p.to_bytes()?.into())
} }
fn check(q: Type, text_input: &'static str, canonic: Option<&'static str>, raw: &'static [u8]) { fn check(q: Type, text_input: &'static str, canonic: Option<&'static str>, raw: &'static [u8]) {
// Make sure the canonic representation is sound itself // Make sure the canonic representation is sound itself
if let Some(canonic) = canonic { if let Some(canonic) = canonic {
@ -93,9 +84,8 @@ fn check(q: Type, text_input: &'static str, canonic: Option<&'static str>, raw:
context.set_record_type(q); context.set_record_type(q);
context.set_last_ttl(3600); context.set_last_ttl(3600);
let d_zone: Box<dyn RRData> = text::parse_with(text_input, |data| { let d_zone: Box<dyn RRData> =
registry::parse_rr_data(&context, data) text::parse_with(text_input, |data| registry::parse_rr_data(&context, data)).unwrap();
}).unwrap();
let d_zone_text = d_zone.text().unwrap(); let d_zone_text = d_zone.text().unwrap();
// make sure we actually know the type and the text representation // make sure we actually know the type and the text representation
@ -117,7 +107,10 @@ fn check(q: Type, text_input: &'static str, canonic: Option<&'static str>, raw:
// pdns tests compare d_wire_text and canonic, but d_zone_text // pdns tests compare d_wire_text and canonic, but d_zone_text
// already matches canonic // already matches canonic
assert_eq!(d_zone_text, d_wire_text, "data parsed from zone doesn't match data from wire"); assert_eq!(
d_zone_text, d_wire_text,
"data parsed from zone doesn't match data from wire"
);
let zone_as_wire = serialized_answer(d_zone).unwrap(); let zone_as_wire = serialized_answer(d_zone).unwrap();
assert_eq!(zone_as_wire, raw); assert_eq!(zone_as_wire, raw);
@ -125,24 +118,16 @@ fn check(q: Type, text_input: &'static str, canonic: Option<&'static str>, raw:
#[test] #[test]
fn test_A() { fn test_A() {
check(types::A, check(types::A, "127.0.0.1", None, b"\x7F\x00\x00\x01");
"127.0.0.1",
None,
b"\x7F\x00\x00\x01",
);
} }
#[test] #[test]
fn test_NS() { fn test_NS() {
// local nameserver // local nameserver
check(types::NS, check(types::NS, "ns.rec.test.", None, b"\x02ns\xc0\x11");
"ns.rec.test.",
None,
b"\x02ns\xc0\x11",
);
// non-local nameserver // non-local nameserver
check(types::NS, check(
types::NS,
"ns.example.com.", "ns.example.com.",
None, None,
b"\x02ns\x07example\x03com\x00", b"\x02ns\x07example\x03com\x00",
@ -155,13 +140,10 @@ fn test_NS() {
#[test] #[test]
fn test_CNAME() { fn test_CNAME() {
// local alias // local alias
check(types::CNAME, check(types::CNAME, "name.rec.test.", None, b"\x04name\xc0\x11");
"name.rec.test.",
None,
b"\x04name\xc0\x11",
);
// non-local alias // non-local alias
check(types::CNAME, check(
types::CNAME,
"name.example.com.", "name.example.com.",
None, None,
b"\x04name\x07example\x03com\x00", b"\x04name\x07example\x03com\x00",
@ -215,13 +197,15 @@ fn test_SOA() {
fn test_MR() { fn test_MR() {
// BROKEN TESTS (2) (deprecated) // BROKEN TESTS (2) (deprecated)
// local name // local name
check(types::MR, check(
types::MR,
"newmailbox.rec.test.", "newmailbox.rec.test.",
None, None,
b"\x0anewmailbox\xc0\x11", b"\x0anewmailbox\xc0\x11",
); );
// non-local name // non-local name
check(types::MR, check(
types::MR,
"newmailbox.example.com.", "newmailbox.example.com.",
None, None,
b"\x0anewmailbox\x07example\x03com\x00", b"\x0anewmailbox\x07example\x03com\x00",
@ -231,13 +215,10 @@ fn test_MR() {
#[test] #[test]
fn test_PTR() { fn test_PTR() {
// local name // local name
check(types::PTR, check(types::PTR, "ptr.rec.test.", None, b"\x03ptr\xc0\x11");
"ptr.rec.test.",
None,
b"\x03ptr\xc0\x11",
);
// non-local name // non-local name
check(types::PTR, check(
types::PTR,
"ptr.example.com.", "ptr.example.com.",
None, None,
b"\x03ptr\x07example\x03com\x00", b"\x03ptr\x07example\x03com\x00",
@ -246,22 +227,26 @@ fn test_PTR() {
#[test] #[test]
fn test_HINFO() { fn test_HINFO() {
check(types::HINFO, check(
types::HINFO,
"\"i686\" \"Linux\"", "\"i686\" \"Linux\"",
None, None,
b"\x04i686\x05Linux", b"\x04i686\x05Linux",
); );
check(types::HINFO, check(
types::HINFO,
"i686 \"Linux\"", "i686 \"Linux\"",
Some("\"i686\" \"Linux\""), Some("\"i686\" \"Linux\""),
b"\x04i686\x05Linux", b"\x04i686\x05Linux",
); );
check(types::HINFO, check(
types::HINFO,
"\"i686\" Linux", "\"i686\" Linux",
Some("\"i686\" \"Linux\""), Some("\"i686\" \"Linux\""),
b"\x04i686\x05Linux", b"\x04i686\x05Linux",
); );
check(types::HINFO, check(
types::HINFO,
"i686 Linux", "i686 Linux",
Some("\"i686\" \"Linux\""), Some("\"i686\" \"Linux\""),
b"\x04i686\x05Linux", b"\x04i686\x05Linux",
@ -273,56 +258,53 @@ fn test_HINFO() {
#[test] #[test]
fn test_MX() { fn test_MX() {
// local name // local name
check(types::MX, check(
types::MX,
"10 mx.rec.test.", "10 mx.rec.test.",
None, None,
b"\x00\x0a\x02mx\xc0\x11", b"\x00\x0a\x02mx\xc0\x11",
); );
// non-local name // non-local name
check(types::MX, check(
types::MX,
"20 mx.example.com.", "20 mx.example.com.",
None, None,
b"\x00\x14\x02mx\x07example\x03com\x00", b"\x00\x14\x02mx\x07example\x03com\x00",
); );
// root label // root label
check(types::MX, check(types::MX, "20 .", None, b"\x00\x14\x00");
"20 .",
None,
b"\x00\x14\x00",
);
} }
#[test] #[test]
fn test_TXT() { fn test_TXT() {
check(types::TXT, check(types::TXT, "\"short text\"", None, b"\x0ashort text");
"\"short text\"",
None,
b"\x0ashort text",
);
check(types::TXT, check(types::TXT,
"\"long record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\" \"2222222222\"", "\"long record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\" \"2222222222\"",
None, None,
b"\xfflong record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\x0a2222222222", b"\xfflong record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\x0a2222222222",
); );
// autosplitting not supported // autosplitting not supported
/* /*
check(types::TXT, check(types::TXT,
"\"long record test 11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222222222\"", "\"long record test 11111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111112222222222\"",
Some("\"long record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\" \"2222222222\""), Some("\"long record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\" \"2222222222\""),
b"\xfflong record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\x0a2222222222", b"\xfflong record test 1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111\x0a2222222222",
); );
*/ */
check(types::TXT, check(
types::TXT,
"\"\\195\\133LAND ISLANDS\"", "\"\\195\\133LAND ISLANDS\"",
None, None,
b"\x0e\xc3\x85LAND ISLANDS", b"\x0e\xc3\x85LAND ISLANDS",
); );
check(types::TXT, check(
types::TXT,
"\"\u{00c5}LAND ISLANDS\"", "\"\u{00c5}LAND ISLANDS\"",
Some("\"\\195\\133LAND ISLANDS\""), Some("\"\\195\\133LAND ISLANDS\""),
b"\x0e\xc3\x85LAND ISLANDS", b"\x0e\xc3\x85LAND ISLANDS",
); );
check(types::TXT, check(
types::TXT,
"\"nonbreakingtxt\"", "\"nonbreakingtxt\"",
None, None,
b"\x0enonbreakingtxt", b"\x0enonbreakingtxt",
@ -332,13 +314,15 @@ fn test_TXT() {
#[test] #[test]
fn test_RP() { fn test_RP() {
// local name // local name
check(types::RP, check(
types::RP,
"admin.rec.test. admin-info.rec.test.", "admin.rec.test. admin-info.rec.test.",
None, None,
b"\x05admin\x03rec\x04test\x00\x0aadmin-info\x03rec\x04test\x00", b"\x05admin\x03rec\x04test\x00\x0aadmin-info\x03rec\x04test\x00",
); );
// non-local name // non-local name
check(types::RP, check(
types::RP,
"admin.example.com. admin-info.example.com.", "admin.example.com. admin-info.example.com.",
None, None,
b"\x05admin\x07example\x03com\x00\x0aadmin-info\x07example\x03com\x00", b"\x05admin\x07example\x03com\x00\x0aadmin-info\x07example\x03com\x00",
@ -348,13 +332,15 @@ fn test_RP() {
#[test] #[test]
fn test_AFSDB() { fn test_AFSDB() {
// local name // local name
check(types::AFSDB, check(
types::AFSDB,
"1 afs-server.rec.test.", "1 afs-server.rec.test.",
None, None,
b"\x00\x01\x0aafs-server\x03rec\x04test\x00", b"\x00\x01\x0aafs-server\x03rec\x04test\x00",
); );
// non-local name // non-local name
check(types::AFSDB, check(
types::AFSDB,
"1 afs-server.example.com.", "1 afs-server.example.com.",
None, None,
b"\x00\x01\x0aafs-server\x07example\x03com\x00", b"\x00\x01\x0aafs-server\x07example\x03com\x00",
@ -390,17 +376,20 @@ fn test_KEY() {
#[test] #[test]
fn test_AAAA() { fn test_AAAA() {
check(types::AAAA, check(
types::AAAA,
"fe80::250:56ff:fe9b:114", "fe80::250:56ff:fe9b:114",
None, None,
b"\xFE\x80\x00\x00\x00\x00\x00\x00\x02\x50\x56\xFF\xFE\x9B\x01\x14", b"\xFE\x80\x00\x00\x00\x00\x00\x00\x02\x50\x56\xFF\xFE\x9B\x01\x14",
); );
check(types::AAAA, check(
types::AAAA,
"2a02:1b8:10:2::151", "2a02:1b8:10:2::151",
None, None,
b"\x2a\x02\x01\xb8\x00\x10\x00\x02\x00\x00\x00\x00\x00\x00\x01\x51", b"\x2a\x02\x01\xb8\x00\x10\x00\x02\x00\x00\x00\x00\x00\x00\x01\x51",
); );
check(types::AAAA, check(
types::AAAA,
"::1", "::1",
None, None,
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
@ -409,22 +398,26 @@ fn test_AAAA() {
#[test] #[test]
fn test_LOC() { fn test_LOC() {
check(types::LOC, check(
types::LOC,
"32 7 19 S 116 2 25 E", "32 7 19 S 116 2 25 E",
Some("32 7 19.000 S 116 2 25.000 E 0.00m 1.00m 10000.00m 10.00m"), Some("32 7 19.000 S 116 2 25.000 E 0.00m 1.00m 10000.00m 10.00m"),
b"\x00\x12\x16\x13\x79\x1b\x7d\x28\x98\xe6\x48\x68\x00\x98\x96\x80", b"\x00\x12\x16\x13\x79\x1b\x7d\x28\x98\xe6\x48\x68\x00\x98\x96\x80",
); );
check(types::LOC, check(
types::LOC,
"32 7 19 S 116 2 25 E 10m", "32 7 19 S 116 2 25 E 10m",
Some("32 7 19.000 S 116 2 25.000 E 10.00m 1.00m 10000.00m 10.00m"), Some("32 7 19.000 S 116 2 25.000 E 10.00m 1.00m 10000.00m 10.00m"),
b"\x00\x12\x16\x13\x79\x1b\x7d\x28\x98\xe6\x48\x68\x00\x98\x9a\x68", b"\x00\x12\x16\x13\x79\x1b\x7d\x28\x98\xe6\x48\x68\x00\x98\x9a\x68",
); );
check(types::LOC, check(
types::LOC,
"42 21 54 N 71 06 18 W -24m 30m", "42 21 54 N 71 06 18 W -24m 30m",
Some("42 21 54.000 N 71 6 18.000 W -24.00m 30.00m 10000.00m 10.00m"), Some("42 21 54.000 N 71 6 18.000 W -24.00m 30.00m 10000.00m 10.00m"),
b"\x00\x33\x16\x13\x89\x17\x2d\xd0\x70\xbe\x15\xf0\x00\x98\x8d\x20", b"\x00\x33\x16\x13\x89\x17\x2d\xd0\x70\xbe\x15\xf0\x00\x98\x8d\x20",
); );
check(types::LOC, check(
types::LOC,
"42 21 43.952 N 71 5 6.344 W -24m 1m 200m", "42 21 43.952 N 71 5 6.344 W -24m 1m 200m",
Some("42 21 43.952 N 71 5 6.344 W -24.00m 1.00m 200.00m 10.00m"), Some("42 21 43.952 N 71 5 6.344 W -24.00m 1.00m 200.00m 10.00m"),
b"\x00\x12\x24\x13\x89\x17\x06\x90\x70\xbf\x2d\xd8\x00\x98\x8d\x20", b"\x00\x12\x24\x13\x89\x17\x06\x90\x70\xbf\x2d\xd8\x00\x98\x8d\x20",
@ -436,19 +429,22 @@ fn test_LOC() {
#[test] #[test]
fn test_SRV() { fn test_SRV() {
// local name // local name
check(types::SRV, check(
types::SRV,
"10 10 5060 sip.rec.test.", "10 10 5060 sip.rec.test.",
None, None,
b"\x00\x0a\x00\x0a\x13\xc4\x03sip\x03rec\x04test\x00", b"\x00\x0a\x00\x0a\x13\xc4\x03sip\x03rec\x04test\x00",
); );
// non-local name // non-local name
check(types::SRV, check(
types::SRV,
"10 10 5060 sip.example.com.", "10 10 5060 sip.example.com.",
None, None,
b"\x00\x0a\x00\x0a\x13\xc4\x03sip\x07example\x03com\x00", b"\x00\x0a\x00\x0a\x13\xc4\x03sip\x07example\x03com\x00",
); );
// root name // root name
check(types::SRV, check(
types::SRV,
"10 10 5060 .", "10 10 5060 .",
None, None,
b"\x00\x0a\x00\x0a\x13\xc4\x00", b"\x00\x0a\x00\x0a\x13\xc4\x00",
@ -457,12 +453,14 @@ fn test_SRV() {
#[test] #[test]
fn test_NAPTR() { fn test_NAPTR() {
check(types::NAPTR, check(
types::NAPTR,
"100 10 \"\" \"\" \"/urn:cid:.+@([^\\\\.]+\\\\.)(.*)$/\\\\2/i\" .", "100 10 \"\" \"\" \"/urn:cid:.+@([^\\\\.]+\\\\.)(.*)$/\\\\2/i\" .",
None, None,
b"\x00\x64\x00\x0a\x00\x00\x20/urn:cid:.+@([^\\.]+\\.)(.*)$/\\2/i\x00", b"\x00\x64\x00\x0a\x00\x00\x20/urn:cid:.+@([^\\.]+\\.)(.*)$/\\2/i\x00",
); );
check(types::NAPTR, check(
types::NAPTR,
"100 50 \"s\" \"http+I2L+I2C+I2R\" \"\" _http._tcp.rec.test.", "100 50 \"s\" \"http+I2L+I2C+I2R\" \"\" _http._tcp.rec.test.",
None, None,
b"\x00\x64\x00\x32\x01s\x10http+I2L+I2C+I2R\x00\x05_http\x04_tcp\x03rec\x04test\x00", b"\x00\x64\x00\x32\x01s\x10http+I2L+I2C+I2R\x00\x05_http\x04_tcp\x03rec\x04test\x00",
@ -471,7 +469,8 @@ fn test_NAPTR() {
#[test] #[test]
fn test_KX() { fn test_KX() {
check(types::KX, check(
types::KX,
"10 mail.rec.test.", "10 mail.rec.test.",
None, None,
b"\x00\x0a\x04mail\x03rec\x04test\x00", b"\x00\x0a\x04mail\x03rec\x04test\x00",
@ -507,13 +506,15 @@ fn test_DS() {
#[test] #[test]
fn test_SSHFP() { fn test_SSHFP() {
check(types::SSHFP, check(
types::SSHFP,
"1 1 aa65e3415a50d9b3519c2b17aceb815fc2538d88", "1 1 aa65e3415a50d9b3519c2b17aceb815fc2538d88",
None, None,
b"\x01\x01\xaa\x65\xe3\x41\x5a\x50\xd9\xb3\x51\x9c\x2b\x17\xac\xeb\x81\x5f\xc2\x53\x8d\x88", b"\x01\x01\xaa\x65\xe3\x41\x5a\x50\xd9\xb3\x51\x9c\x2b\x17\xac\xeb\x81\x5f\xc2\x53\x8d\x88",
); );
// as per RFC4025 // as per RFC4025
check(types::SSHFP, check(
types::SSHFP,
"1 1 aa65e3415a50d9b3519c2b17aceb815fc253 8d88", "1 1 aa65e3415a50d9b3519c2b17aceb815fc253 8d88",
Some("1 1 aa65e3415a50d9b3519c2b17aceb815fc2538d88"), Some("1 1 aa65e3415a50d9b3519c2b17aceb815fc2538d88"),
b"\x01\x01\xaa\x65\xe3\x41\x5a\x50\xd9\xb3\x51\x9c\x2b\x17\xac\xeb\x81\x5f\xc2\x53\x8d\x88", b"\x01\x01\xaa\x65\xe3\x41\x5a\x50\xd9\xb3\x51\x9c\x2b\x17\xac\xeb\x81\x5f\xc2\x53\x8d\x88",
@ -523,22 +524,20 @@ fn test_SSHFP() {
#[test] #[test]
fn test_IPSECKEY() { fn test_IPSECKEY() {
// as per RFC4025 // as per RFC4025
check(types::IPSECKEY, check(types::IPSECKEY, "255 0 0", None, b"\xff\x00\x00");
"255 0 0",
None,
b"\xff\x00\x00",
);
check(types::IPSECKEY, check(types::IPSECKEY,
"255 0 1 V19hwufL6LJARVIxzHDyGdvZ7dbQE0Kyl18yPIWj/sbCcsBbz7zO6Q2qgdzmWI3OvGNne2nxflhorhefKIMsUg==", "255 0 1 V19hwufL6LJARVIxzHDyGdvZ7dbQE0Kyl18yPIWj/sbCcsBbz7zO6Q2qgdzmWI3OvGNne2nxflhorhefKIMsUg==",
None, None,
b"\xff\x00\x01\x57\x5f\x61\xc2\xe7\xcb\xe8\xb2\x40\x45\x52\x31\xcc\x70\xf2\x19\xdb\xd9\xed\xd6\xd0\x13\x42\xb2\x97\x5f\x32\x3c\x85\xa3\xfe\xc6\xc2\x72\xc0\x5b\xcf\xbc\xce\xe9\x0d\xaa\x81\xdc\xe6\x58\x8d\xce\xbc\x63\x67\x7b\x69\xf1\x7e\x58\x68\xae\x17\x9f\x28\x83\x2c\x52", b"\xff\x00\x01\x57\x5f\x61\xc2\xe7\xcb\xe8\xb2\x40\x45\x52\x31\xcc\x70\xf2\x19\xdb\xd9\xed\xd6\xd0\x13\x42\xb2\x97\x5f\x32\x3c\x85\xa3\xfe\xc6\xc2\x72\xc0\x5b\xcf\xbc\xce\xe9\x0d\xaa\x81\xdc\xe6\x58\x8d\xce\xbc\x63\x67\x7b\x69\xf1\x7e\x58\x68\xae\x17\x9f\x28\x83\x2c\x52",
); );
check(types::IPSECKEY, check(
types::IPSECKEY,
"255 1 0 127.0.0.1", "255 1 0 127.0.0.1",
None, None,
b"\xff\x01\x00\x7f\x00\x00\x01", b"\xff\x01\x00\x7f\x00\x00\x01",
); );
check(types::IPSECKEY, check(
types::IPSECKEY,
"255 2 0 fe80::250:56ff:fe9b:114", "255 2 0 fe80::250:56ff:fe9b:114",
None, None,
b"\xff\x02\x00\xFE\x80\x00\x00\x00\x00\x00\x00\x02\x50\x56\xFF\xFE\x9B\x01\x14", b"\xff\x02\x00\xFE\x80\x00\x00\x00\x00\x00\x00\x02\x50\x56\xFF\xFE\x9B\x01\x14",
@ -573,7 +572,8 @@ fn test_RRSIG() {
#[test] #[test]
fn test_NSEC() { fn test_NSEC() {
check(types::NSEC, check(
types::NSEC,
"a.rec.test. A NS SOA MX AAAA RRSIG NSEC DNSKEY", "a.rec.test. A NS SOA MX AAAA RRSIG NSEC DNSKEY",
None, None,
b"\x01a\x03rec\x04test\x00\x00\x07\x62\x01\x00\x08\x00\x03\x80", b"\x01a\x03rec\x04test\x00\x00\x07\x62\x01\x00\x08\x00\x03\x80",
@ -619,7 +619,8 @@ fn test_NSEC3() {
#[test] #[test]
fn test_NSEC3PARAM() { fn test_NSEC3PARAM() {
check(types::NSEC3PARAM, check(
types::NSEC3PARAM,
"1 0 1 f00b", "1 0 1 f00b",
None, None,
b"\x01\x00\x00\x01\x02\xf0\x0b", b"\x01\x00\x00\x01\x02\xf0\x0b",
@ -751,7 +752,8 @@ fn test_OPENPGPKEY() {
#[test] #[test]
fn test_SPF() { fn test_SPF() {
check(types::SPF, check(
types::SPF,
"\"v=spf1 a:mail.rec.test ~all\"", "\"v=spf1 a:mail.rec.test ~all\"",
None, None,
b"\x1bv=spf1 a:mail.rec.test ~all", b"\x1bv=spf1 a:mail.rec.test ~all",
@ -760,7 +762,8 @@ fn test_SPF() {
#[test] #[test]
fn test_EUI48() { fn test_EUI48() {
check(types::EUI48, check(
types::EUI48,
"00-11-22-33-44-55", "00-11-22-33-44-55",
None, None,
b"\x00\x11\x22\x33\x44\x55", b"\x00\x11\x22\x33\x44\x55",
@ -769,7 +772,8 @@ fn test_EUI48() {
#[test] #[test]
fn test_EUI64() { fn test_EUI64() {
check(types::EUI64, check(
types::EUI64,
"00-11-22-33-44-55-66-77", "00-11-22-33-44-55-66-77",
None, None,
b"\x00\x11\x22\x33\x44\x55\x66\x77", b"\x00\x11\x22\x33\x44\x55\x66\x77",
@ -778,7 +782,8 @@ fn test_EUI64() {
#[test] #[test]
fn test_TKEY() { fn test_TKEY() {
check(types::TKEY, check(
types::TKEY,
"gss-tsig. 12345 12345 3 21 4 dGVzdA== 4 dGVzdA==", "gss-tsig. 12345 12345 3 21 4 dGVzdA== 4 dGVzdA==",
None, None,
b"\x08gss-tsig\x00\x00\x00\x30\x39\x00\x00\x30\x39\x00\x03\x00\x15\x00\x04test\x00\x04test", b"\x08gss-tsig\x00\x00\x00\x30\x39\x00\x00\x30\x39\x00\x03\x00\x15\x00\x04test\x00\x04test",
@ -815,7 +820,8 @@ fn test_URI() {
#[test] #[test]
fn test_CAA() { fn test_CAA() {
check(types::CAA, check(
types::CAA,
"0 issue \"example.net\"", "0 issue \"example.net\"",
None, None,
b"\x00\x05\x69\x73\x73\x75\x65\x65\x78\x61\x6d\x70\x6c\x65\x2e\x6e\x65\x74", b"\x00\x05\x69\x73\x73\x75\x65\x65\x78\x61\x6d\x70\x6c\x65\x2e\x6e\x65\x74",
@ -835,10 +841,14 @@ fn test_DLV() {
fn test_TYPE65226() { fn test_TYPE65226() {
let d1 = text::parse_with("\\# 3 414243", |data| { let d1 = text::parse_with("\\# 3 414243", |data| {
super::UnknownRecord::dns_parse(types::Type(65226), data) super::UnknownRecord::dns_parse(types::Type(65226), data)
}).unwrap(); })
.unwrap();
let d2 = super::UnknownRecord::new(types::Type(65226), Bytes::from_static(b"\x41\x42\x43")); let d2 = super::UnknownRecord::new(types::Type(65226), Bytes::from_static(b"\x41\x42\x43"));
assert_eq!(d1, d2); assert_eq!(d1, d2);
assert_eq!(d1.text().unwrap(), ("TYPE65226".into(), "\\# 3 414243".into())); assert_eq!(
d1.text().unwrap(),
("TYPE65226".into(), "\\# 3 414243".into())
);
} }
fn check_invalid_zone(q: Type, text_input: &str) { fn check_invalid_zone(q: Type, text_input: &str) {
@ -848,9 +858,7 @@ fn check_invalid_zone(q: Type, text_input: &str) {
context.set_record_type(q); context.set_record_type(q);
context.set_last_ttl(3600); context.set_last_ttl(3600);
text::parse_with(text_input, |data| { text::parse_with(text_input, |data| registry::parse_rr_data(&context, data)).unwrap_err();
registry::parse_rr_data(&context, data)
}).unwrap_err();
} }
fn check_invalid_wire(q: Type, raw: &'static [u8]) { fn check_invalid_wire(q: Type, raw: &'static [u8]) {
@ -873,16 +881,25 @@ fn test_invalid_data_checks() {
check_invalid_zone(types::AAAA, "23:00"); // time when this test was written check_invalid_zone(types::AAAA, "23:00"); // time when this test was written
check_invalid_zone(types::AAAA, "23:00::15::43"); // double compression check_invalid_zone(types::AAAA, "23:00::15::43"); // double compression
check_invalid_zone(types::AAAA, "2a23:00::15::"); // ditto check_invalid_zone(types::AAAA, "2a23:00::15::"); // ditto
check_invalid_wire(types::AAAA, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"); // truncated wire value check_invalid_wire(
// empty label, must be broken types::AAAA,
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff",
); // truncated wire value
// empty label, must be broken
check_invalid_zone(types::CNAME, "name..example.com."); check_invalid_zone(types::CNAME, "name..example.com.");
// overly large label (64), must be broken // overly large label (64), must be broken
check_invalid_zone(types::CNAME, "1234567890123456789012345678901234567890123456789012345678901234.example.com."); check_invalid_zone(
// local overly large name (256), must be broken types::CNAME,
"1234567890123456789012345678901234567890123456789012345678901234.example.com.",
);
// local overly large name (256), must be broken
check_invalid_zone(types::CNAME, "123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.12345678901234567890123456789012345678901234567890123.rec.test."); check_invalid_zone(types::CNAME, "123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.12345678901234567890123456789012345678901234567890123.rec.test.");
// non-local overly large name (256), must be broken // non-local overly large name (256), must be broken
check_invalid_zone(types::CNAME, "123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.12345678901234567890123456789012345678901234567890123456789012."); check_invalid_zone(types::CNAME, "123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.123456789012345678901234567890123456789012345678901234567890123.12345678901234567890123456789012345678901234567890123456789012.");
check_invalid_zone(types::SOA, "ns.rec.test hostmaster.test.rec 20130512010 3600 3600 604800 120"); // too long serial check_invalid_zone(
types::SOA,
"ns.rec.test hostmaster.test.rec 20130512010 3600 3600 604800 120",
); // too long serial
} }
#[test] #[test]
@ -895,42 +912,39 @@ fn test_opt_record_in() {
let opt = p.opt.unwrap().unwrap(); let opt = p.opt.unwrap().unwrap();
assert_eq!(opt.udp_payload_size, 1280); assert_eq!(opt.udp_payload_size, 1280);
assert_eq!(opt.options, vec![ assert_eq!(
DnsOption::NSID(Bytes::from_static(b"powerdns")), opt.options,
]); vec![DnsOption::NSID(Bytes::from_static(b"powerdns")),]
);
} }
#[test] #[test]
fn test_opt_record_out() { fn test_opt_record_out() {
let mut p = DnsPacket{ let mut p = DnsPacket {
id: 0xf001, id: 0xf001,
flags: DnsHeaderFlags { flags: DnsHeaderFlags {
recursion_desired: true, recursion_desired: true,
.. Default::default() ..Default::default()
}, },
question: vec![ question: vec![Question {
Question {
qname: "www.powerdns.com.".parse().unwrap(), qname: "www.powerdns.com.".parse().unwrap(),
qtype: types::A, qtype: types::A,
qclass: classes::IN, qclass: classes::IN,
} }],
], answer: vec![Resource {
answer: vec![
Resource {
name: "www.powerdns.com.".parse().unwrap(), name: "www.powerdns.com.".parse().unwrap(),
class: classes::IN, class: classes::IN,
ttl: 16, ttl: 16,
data: Box::new(A { addr: "127.0.0.1".parse().unwrap() }), data: Box::new(A {
} addr: "127.0.0.1".parse().unwrap(),
], }),
}],
opt: Some(Ok(Opt { opt: Some(Ok(Opt {
udp_payload_size: 1280, udp_payload_size: 1280,
options: vec![ options: vec![DnsOption::NSID(Bytes::from_static(b"powerdns"))],
DnsOption::NSID(Bytes::from_static(b"powerdns")), ..Default::default()
],
.. Default::default()
})), })),
.. Default::default() ..Default::default()
}; };
assert_eq!( assert_eq!(

View File

@ -4,16 +4,16 @@ use std::collections::HashMap;
use std::io::Cursor; use std::io::Cursor;
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::common_types::{Class, Type, types}; use crate::common_types::{types, Class, Type};
use crate::errors::*; use crate::errors::*;
use crate::records::structs; use crate::records::structs;
use crate::ser::{RRData, StaticRRData};
use crate::ser::text::DnsTextContext; use crate::ser::text::DnsTextContext;
use crate::ser::{RRData, StaticRRData};
// this should be enough for registered names // this should be enough for registered names
const TYPE_NAME_MAX_LEN: usize = 16; const TYPE_NAME_MAX_LEN: usize = 16;
lazy_static::lazy_static!{ lazy_static::lazy_static! {
static ref REGISTRY: Registry = Registry::init(); static ref REGISTRY: Registry = Registry::init();
} }
@ -22,7 +22,9 @@ fn registry() -> &'static Registry {
} }
pub(crate) fn lookup_type_name(name: &str) -> Option<Type> { pub(crate) fn lookup_type_name(name: &str) -> Option<Type> {
if name.len() >= TYPE_NAME_MAX_LEN { return None; } if name.len() >= TYPE_NAME_MAX_LEN {
return None;
}
let mut name_buf_storage = [0u8; TYPE_NAME_MAX_LEN]; let mut name_buf_storage = [0u8; TYPE_NAME_MAX_LEN];
let name_buf = &mut name_buf_storage[..name.len()]; let name_buf = &mut name_buf_storage[..name.len()];
name_buf.copy_from_slice(name.as_bytes()); name_buf.copy_from_slice(name.as_bytes());
@ -42,7 +44,12 @@ pub fn known_name_to_type(name: &str) -> Option<Type> {
Some(t) Some(t)
} }
pub fn deserialize_rr_data(ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Box<dyn RRData>> { pub fn deserialize_rr_data(
ttl: u32,
rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Box<dyn RRData>> {
let registry = registry(); let registry = registry();
match registry.type_parser.get(&rr_type) { match registry.type_parser.get(&rr_type) {
Some(p) => p.deserialize_rr_data(ttl, rr_class, rr_type, data), Some(p) => p.deserialize_rr_data(ttl, rr_class, rr_type, data),
@ -80,13 +87,25 @@ trait RRDataTypeParse: 'static {
TypeId::of::<Self>() TypeId::of::<Self>()
} }
fn deserialize_rr_data(&self, ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Box<dyn RRData>>; fn deserialize_rr_data(
&self,
ttl: u32,
rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Box<dyn RRData>>;
fn parse_rr_data(&self, context: &DnsTextContext, data: &mut &str) -> Result<Box<dyn RRData>>; fn parse_rr_data(&self, context: &DnsTextContext, data: &mut &str) -> Result<Box<dyn RRData>>;
} }
impl<T: RRData + 'static> RRDataTypeParse for TagRRDataType<T> { impl<T: RRData + 'static> RRDataTypeParse for TagRRDataType<T> {
fn deserialize_rr_data(&self, ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Box<dyn RRData>> { fn deserialize_rr_data(
&self,
ttl: u32,
rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Box<dyn RRData>> {
T::deserialize_rr_data(ttl, rr_class, rr_type, data).map(|d| Box::new(d) as _) T::deserialize_rr_data(ttl, rr_class, rr_type, data).map(|d| Box::new(d) as _)
} }
@ -122,8 +141,8 @@ impl Registry {
r.register_known::<structs::MB>(); r.register_known::<structs::MB>();
r.register_known::<structs::MG>(); r.register_known::<structs::MG>();
r.register_known::<structs::MR>(); r.register_known::<structs::MR>();
r.register_unknown("NULL" , types::NULL); r.register_unknown("NULL", types::NULL);
r.register_unknown("WKS" , types::WKS); r.register_unknown("WKS", types::WKS);
r.register_known::<structs::PTR>(); r.register_known::<structs::PTR>();
r.register_known::<structs::HINFO>(); r.register_known::<structs::HINFO>();
r.register_known::<structs::MINFO>(); r.register_known::<structs::MINFO>();
@ -131,10 +150,10 @@ impl Registry {
r.register_known::<structs::TXT>(); r.register_known::<structs::TXT>();
r.register_known::<structs::RP>(); r.register_known::<structs::RP>();
r.register_known::<structs::AFSDB>(); r.register_known::<structs::AFSDB>();
r.register_unknown("X25" , types::X25); r.register_unknown("X25", types::X25);
r.register_unknown("ISDN" , types::ISDN); r.register_unknown("ISDN", types::ISDN);
r.register_known::<structs::RT>(); r.register_known::<structs::RT>();
r.register_unknown("NSAP" , types::NSAP); r.register_unknown("NSAP", types::NSAP);
r.register_known::<structs::NSAP_PTR>(); r.register_known::<structs::NSAP_PTR>();
r.register_known::<structs::SIG>(); r.register_known::<structs::SIG>();
r.register_known::<structs::KEY>(); r.register_known::<structs::KEY>();
@ -143,17 +162,17 @@ impl Registry {
r.register_known::<structs::AAAA>(); r.register_known::<structs::AAAA>();
r.register_known::<structs::LOC>(); r.register_known::<structs::LOC>();
r.register_known::<structs::NXT>(); r.register_known::<structs::NXT>();
r.register_unknown("EID" , types::EID); r.register_unknown("EID", types::EID);
r.register_unknown("NIMLOC" , types::NIMLOC); r.register_unknown("NIMLOC", types::NIMLOC);
r.register_known::<structs::SRV>(); r.register_known::<structs::SRV>();
r.register_unknown("ATMA" , types::ATMA); r.register_unknown("ATMA", types::ATMA);
r.register_known::<structs::NAPTR>(); r.register_known::<structs::NAPTR>();
r.register_known::<structs::KX>(); r.register_known::<structs::KX>();
r.register_known::<structs::CERT>(); r.register_known::<structs::CERT>();
r.register_known::<structs::A6>(); r.register_known::<structs::A6>();
r.register_known::<structs::DNAME>(); r.register_known::<structs::DNAME>();
r.register_unknown("SINK" , types::SINK); r.register_unknown("SINK", types::SINK);
r.register_unknown("OPT" , types::OPT); r.register_unknown("OPT", types::OPT);
r.register_known::<structs::APL>(); r.register_known::<structs::APL>();
r.register_known::<structs::DS>(); r.register_known::<structs::DS>();
r.register_known::<structs::SSHFP>(); r.register_known::<structs::SSHFP>();
@ -166,39 +185,39 @@ impl Registry {
r.register_known::<structs::NSEC3PARAM>(); r.register_known::<structs::NSEC3PARAM>();
r.register_known::<structs::TLSA>(); r.register_known::<structs::TLSA>();
r.register_known::<structs::SMIMEA>(); r.register_known::<structs::SMIMEA>();
r.register_unknown("HIP" , types::HIP); r.register_unknown("HIP", types::HIP);
r.register_known::<structs::NINFO>(); r.register_known::<structs::NINFO>();
r.register_known::<structs::RKEY>(); r.register_known::<structs::RKEY>();
r.register_unknown("TALINK" , types::TALINK); r.register_unknown("TALINK", types::TALINK);
r.register_known::<structs::CDS>(); r.register_known::<structs::CDS>();
r.register_known::<structs::CDNSKEY>(); r.register_known::<structs::CDNSKEY>();
r.register_known::<structs::OPENPGPKEY>(); r.register_known::<structs::OPENPGPKEY>();
r.register_unknown("CSYNC" , types::CSYNC); r.register_unknown("CSYNC", types::CSYNC);
r.register_unknown("ZONEMD" , types::ZONEMD); r.register_unknown("ZONEMD", types::ZONEMD);
r.register_known::<structs::SPF>(); r.register_known::<structs::SPF>();
r.register_unknown("UINFO" , types::UINFO); r.register_unknown("UINFO", types::UINFO);
r.register_unknown("UID" , types::UID); r.register_unknown("UID", types::UID);
r.register_unknown("GID" , types::GID); r.register_unknown("GID", types::GID);
r.register_unknown("UNSPEC" , types::UNSPEC); r.register_unknown("UNSPEC", types::UNSPEC);
r.register_unknown("NID" , types::NID); r.register_unknown("NID", types::NID);
r.register_unknown("L32" , types::L32); r.register_unknown("L32", types::L32);
r.register_unknown("L64" , types::L64); r.register_unknown("L64", types::L64);
r.register_unknown("LP" , types::LP); r.register_unknown("LP", types::LP);
r.register_known::<structs::EUI48>(); r.register_known::<structs::EUI48>();
r.register_known::<structs::EUI64>(); r.register_known::<structs::EUI64>();
r.register_known::<structs::TKEY>(); r.register_known::<structs::TKEY>();
r.register_known::<structs::TSIG>(); r.register_known::<structs::TSIG>();
r.register_unknown("IXFR" , types::IXFR); r.register_unknown("IXFR", types::IXFR);
r.register_unknown("AXFR" , types::AXFR); r.register_unknown("AXFR", types::AXFR);
r.register_unknown("MAILB" , types::MAILB); r.register_unknown("MAILB", types::MAILB);
r.register_unknown("MAILA" , types::MAILA); r.register_unknown("MAILA", types::MAILA);
r.register_unknown("ANY" , types::ANY); r.register_unknown("ANY", types::ANY);
r.register_known::<structs::URI>(); r.register_known::<structs::URI>();
r.register_known::<structs::CAA>(); r.register_known::<structs::CAA>();
r.register_unknown("AVC" , types::AVC); r.register_unknown("AVC", types::AVC);
r.register_unknown("DOA" , types::DOA); r.register_unknown("DOA", types::DOA);
r.register_unknown("AMTRELAY" , types::AMTRELAY); r.register_unknown("AMTRELAY", types::AMTRELAY);
r.register_unknown("TA" , types::TA); r.register_unknown("TA", types::TA);
r.register_known::<structs::DLV>(); r.register_known::<structs::DLV>();
r.register_known::<structs::ALIAS>(); r.register_known::<structs::ALIAS>();
@ -213,9 +232,20 @@ impl Registry {
self.prev_type = Some(rrtype); self.prev_type = Some(rrtype);
let mut name: String = name.into(); let mut name: String = name.into();
name.make_ascii_uppercase(); name.make_ascii_uppercase();
assert!(!name.starts_with("TYPE"), "must not register generic name: {}", name); assert!(
assert!(name.len() <= TYPE_NAME_MAX_LEN, "name too long: {} - maybe you need to increase TYPE_NAME_MAX_LEN", name); !name.starts_with("TYPE"),
assert!(self.names_to_type.insert(name.clone().into_bytes(), rrtype).is_none()); "must not register generic name: {}",
name
);
assert!(
name.len() <= TYPE_NAME_MAX_LEN,
"name too long: {} - maybe you need to increase TYPE_NAME_MAX_LEN",
name
);
assert!(self
.names_to_type
.insert(name.clone().into_bytes(), rrtype)
.is_none());
self.type_names.insert(rrtype, name); self.type_names.insert(rrtype, name);
} }
@ -227,12 +257,16 @@ impl Registry {
let rrtype = T::TYPE; let rrtype = T::TYPE;
let name = T::NAME; let name = T::NAME;
self.register_name(name, rrtype); self.register_name(name, rrtype);
self.type_parser.insert(rrtype, Box::new(TagRRDataType::<T>(PhantomData))); self.type_parser
.insert(rrtype, Box::new(TagRRDataType::<T>(PhantomData)));
} }
fn check_registration<T: StaticRRData + Sync + 'static>(&self) { fn check_registration<T: StaticRRData + Sync + 'static>(&self) {
assert_eq!(self.names_to_type.get(T::NAME.as_bytes()), Some(&T::TYPE)); assert_eq!(self.names_to_type.get(T::NAME.as_bytes()), Some(&T::TYPE));
let p: &dyn RRDataTypeParse = &**self.type_parser.get(&T::TYPE).expect("no parser registered"); let p: &dyn RRDataTypeParse = &**self
.type_parser
.get(&T::TYPE)
.expect("no parser registered");
let tid = TypeId::of::<TagRRDataType<T>>(); let tid = TypeId::of::<TagRRDataType<T>>();
assert_eq!(p.type_id(), tid); assert_eq!(p.type_id(), tid);
} }

View File

@ -354,28 +354,42 @@ pub struct DNSKEY {
impl DNSKEY { impl DNSKEY {
fn alg1_tag(&self) -> u16 { fn alg1_tag(&self) -> u16 {
let key: &[u8] = &self.public_key; let key: &[u8] = &self.public_key;
if key.is_empty() { return 0; } // not enough data if key.is_empty() {
return 0;
} // not enough data
let pkey; let pkey;
if 0 == key[0] { if 0 == key[0] {
// two-byte length encoding of exponent // two-byte length encoding of exponent
if key.len() < 3 { return 0; } // not enough data if key.len() < 3 {
return 0;
} // not enough data
let explen = ((key[1] as u16) << 8) + (key[2] as u16); let explen = ((key[1] as u16) << 8) + (key[2] as u16);
if explen < 256 { return 0; } // should have used shorter length encoding if explen < 256 {
if key.len() < 3 + (explen as usize) { return 0; } // not enough data return 0;
} // should have used shorter length encoding
if key.len() < 3 + (explen as usize) {
return 0;
} // not enough data
pkey = &key[3 + (explen as usize)..]; pkey = &key[3 + (explen as usize)..];
} else { } else {
// one-byte length encoding of exponent // one-byte length encoding of exponent
let explen = key[0]; let explen = key[0];
if key.len() < 1 + (explen as usize) { return 0; } // not enough data if key.len() < 1 + (explen as usize) {
return 0;
} // not enough data
pkey = &key[1 + (explen as usize)..]; pkey = &key[1 + (explen as usize)..];
} }
if pkey.len() < 3 { return 0; } // not enough data if pkey.len() < 3 {
return 0;
} // not enough data
((pkey[pkey.len() - 3] as u16) << 8) + (pkey[pkey.len() - 3] as u16) ((pkey[pkey.len() - 3] as u16) << 8) + (pkey[pkey.len() - 3] as u16)
} }
/// calculate key tag /// calculate key tag
pub fn tag(&self) -> u16 { pub fn tag(&self) -> u16 {
if self.algorithm == DnsSecAlgorithm::RSAMD5 { return self.alg1_tag(); } if self.algorithm == DnsSecAlgorithm::RSAMD5 {
return self.alg1_tag();
}
let mut sum = 0u32; let mut sum = 0u32;
@ -395,7 +409,9 @@ impl DNSKEY {
#[cfg(feature = "crypto")] #[cfg(feature = "crypto")]
pub fn build_ds(&self, zone: &DnsName, algs: &[DnsSecDigestAlgorithmKnown]) -> Result<Vec<DS>> { pub fn build_ds(&self, zone: &DnsName, algs: &[DnsSecDigestAlgorithmKnown]) -> Result<Vec<DS>> {
if algs.is_empty() { return Ok(Vec::new()); } if algs.is_empty() {
return Ok(Vec::new());
}
use crate::ser::packet::DnsPacketWriteContext; use crate::ser::packet::DnsPacketWriteContext;
let mut ctx = DnsPacketWriteContext::new(); let mut ctx = DnsPacketWriteContext::new();
@ -406,12 +422,15 @@ impl DNSKEY {
let key_tag = self.tag(); let key_tag = self.tag();
Ok(algs.iter().map(|alg| DS { Ok(algs
.iter()
.map(|alg| DS {
key_tag, key_tag,
algorithm: self.algorithm, algorithm: self.algorithm,
digest_type: (*alg).into(), digest_type: (*alg).into(),
digest: HexRemainingBlob::new(crate::crypto::ds_hash(*alg, &bin)), digest: HexRemainingBlob::new(crate::crypto::ds_hash(*alg, &bin)),
}).collect()) })
.collect())
} }
#[cfg(feature = "crypto")] #[cfg(feature = "crypto")]
@ -476,7 +495,6 @@ pub struct NINFO {
pub text: LongText, pub text: LongText,
} }
#[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)] #[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)]
#[RRClass(ANY)] #[RRClass(ANY)]
pub struct RKEY { pub struct RKEY {
@ -515,7 +533,8 @@ impl CDNSKEY {
protocol: self.protocol, protocol: self.protocol,
algorithm: self.algorithm, algorithm: self.algorithm,
public_key: self.public_key.clone(), public_key: self.public_key.clone(),
}.tag() }
.tag()
} }
} }
@ -585,7 +604,7 @@ pub struct EUI48 {
#[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)] #[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)]
#[RRClass(ANY)] #[RRClass(ANY)]
pub struct EUI64 { pub struct EUI64 {
pub addr: EUI64Addr pub addr: EUI64Addr,
} }
#[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)] #[derive(Clone, PartialEq, Eq, Debug, DnsPacketData, DnsTextData, RRData)]

View File

@ -1,15 +1,15 @@
use bytes::{Bytes, Buf};
use crate::common_types::classes; use crate::common_types::classes;
use failure::ResultExt;
use crate::records::structs; use crate::records::structs;
use crate::ser::{StaticRRData, packet, text};
use crate::ser::packet::DnsPacketData; use crate::ser::packet::DnsPacketData;
use crate::ser::{packet, text, StaticRRData};
use bytes::{Buf, Bytes};
use failure::ResultExt;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
fn rrdata_de<T>(data: &'static [u8]) -> crate::errors::Result<T> fn rrdata_de<T>(data: &'static [u8]) -> crate::errors::Result<T>
where where
T: StaticRRData T: StaticRRData,
{ {
let mut data = Cursor::new(Bytes::from_static(data)); let mut data = Cursor::new(Bytes::from_static(data));
let result = T::deserialize_rr_data(3600, classes::IN, T::TYPE, &mut data)?; let result = T::deserialize_rr_data(3600, classes::IN, T::TYPE, &mut data)?;
@ -19,20 +19,18 @@ where
fn rrdata_parse<T>(data: &str) -> crate::errors::Result<T> fn rrdata_parse<T>(data: &str) -> crate::errors::Result<T>
where where
T: StaticRRData T: StaticRRData,
{ {
let mut ctx = text::DnsTextContext::new(); let mut ctx = text::DnsTextContext::new();
ctx.set_zone_class(classes::IN); ctx.set_zone_class(classes::IN);
ctx.set_record_type(T::TYPE); ctx.set_record_type(T::TYPE);
ctx.set_last_ttl(3600); ctx.set_last_ttl(3600);
text::parse_with(data, |data| { text::parse_with(data, |data| T::dns_parse_rr_data(&ctx, data))
T::dns_parse_rr_data(&ctx, data)
})
} }
fn check<T>(txt: &str, data: &'static [u8]) -> crate::errors::Result<()> fn check<T>(txt: &str, data: &'static [u8]) -> crate::errors::Result<()>
where where
T: StaticRRData + fmt::Debug + PartialEq T: StaticRRData + fmt::Debug + PartialEq,
{ {
let d1: T = rrdata_de(data).context("couldn't parse binary record")?; let d1: T = rrdata_de(data).context("couldn't parse binary record")?;
let d2: T = rrdata_parse(txt).context("couldn't parse text record")?; let d2: T = rrdata_parse(txt).context("couldn't parse text record")?;
@ -42,7 +40,7 @@ where
fn check2<T>(txt: &str, data: &'static [u8], canon: &str) -> crate::errors::Result<()> fn check2<T>(txt: &str, data: &'static [u8], canon: &str) -> crate::errors::Result<()>
where where
T: StaticRRData + fmt::Debug + PartialEq T: StaticRRData + fmt::Debug + PartialEq,
{ {
let d1: T = rrdata_de(data).context("couldn't parse binary record")?; let d1: T = rrdata_de(data).context("couldn't parse binary record")?;
let d2: T = rrdata_parse(txt).context("couldn't parse text record")?; let d2: T = rrdata_parse(txt).context("couldn't parse text record")?;
@ -52,8 +50,18 @@ where
let d2_text = d2.text().unwrap(); let d2_text = d2.text().unwrap();
let canon_text = (T::NAME.to_owned(), canon.into()); let canon_text = (T::NAME.to_owned(), canon.into());
failure::ensure!(d1_text == canon_text, "re-formatted binary record not equal to canonical representation: {:?} != {:?}", d1_text, canon_text); failure::ensure!(
failure::ensure!(d2_text == canon_text, "re-formatted text record not equal to canonical representation: {:?} != {:?}", d2_text, canon_text); d1_text == canon_text,
"re-formatted binary record not equal to canonical representation: {:?} != {:?}",
d1_text,
canon_text
);
failure::ensure!(
d2_text == canon_text,
"re-formatted text record not equal to canonical representation: {:?} != {:?}",
d2_text,
canon_text
);
Ok(()) Ok(())
} }
@ -70,7 +78,7 @@ fn test_mx() {
fn test_txt_for<T>() fn test_txt_for<T>()
where where
T: StaticRRData + fmt::Debug + PartialEq T: StaticRRData + fmt::Debug + PartialEq,
{ {
// at least one "segment" (which could be empty) // at least one "segment" (which could be empty)
check2::<T>(r#" "" "#, b"", r#""""#).unwrap_err(); check2::<T>(r#" "" "#, b"", r#""""#).unwrap_err();
@ -85,7 +93,9 @@ where
{ {
let mut s = String::new(); let mut s = String::new();
s.push('"'); s.push('"');
for _ in 0..256 { s.push('a'); } for _ in 0..256 {
s.push('a');
}
s.push('"'); s.push('"');
rrdata_parse::<T>(&s).unwrap_err(); rrdata_parse::<T>(&s).unwrap_err();
} }
@ -107,8 +117,16 @@ fn test_ds() {
fn test_nsec() { fn test_nsec() {
check::<structs::NSEC>("foo.bar. ", b"\x03foo\x03bar\x00").unwrap(); check::<structs::NSEC>("foo.bar. ", b"\x03foo\x03bar\x00").unwrap();
check::<structs::NSEC>("foo.bar. A NS ", b"\x03foo\x03bar\x00\x00\x01\x60").unwrap(); check::<structs::NSEC>("foo.bar. A NS ", b"\x03foo\x03bar\x00\x00\x01\x60").unwrap();
check::<structs::NSEC>("foo.bar. A NS SOA MX AAAA RRSIG NSEC DNSKEY ", b"\x03foo\x03bar\x00\x00\x07\x62\x01\x00\x08\x00\x03\x80").unwrap(); check::<structs::NSEC>(
check::<structs::NSEC>("foo.bar. A NS TYPE256 TYPE65280 ", b"\x03foo\x03bar\x00\x00\x01\x60\x01\x01\x80\xff\x01\x80").unwrap(); "foo.bar. A NS SOA MX AAAA RRSIG NSEC DNSKEY ",
b"\x03foo\x03bar\x00\x00\x07\x62\x01\x00\x08\x00\x03\x80",
)
.unwrap();
check::<structs::NSEC>(
"foo.bar. A NS TYPE256 TYPE65280 ",
b"\x03foo\x03bar\x00\x00\x01\x60\x01\x01\x80\xff\x01\x80",
)
.unwrap();
} }
#[test] #[test]
@ -121,15 +139,27 @@ fn test_dnskey() {
#[test] #[test]
fn test_nsec3() { fn test_nsec3() {
check::<structs::NSEC3>("1 2 300 - vs", b"\x01\x02\x01\x2c\x00\x01\xff").unwrap(); check::<structs::NSEC3>("1 2 300 - vs", b"\x01\x02\x01\x2c\x00\x01\xff").unwrap();
check::<structs::NSEC3>("1 2 300 - vs A NS", b"\x01\x02\x01\x2c\x00\x01\xff\x00\x01\x60").unwrap(); check::<structs::NSEC3>(
check::<structs::NSEC3>("1 2 300 ab vs A NS", b"\x01\x02\x01\x2c\x01\xab\x01\xff\x00\x01\x60").unwrap(); "1 2 300 - vs A NS",
b"\x01\x02\x01\x2c\x00\x01\xff\x00\x01\x60",
)
.unwrap();
check::<structs::NSEC3>(
"1 2 300 ab vs A NS",
b"\x01\x02\x01\x2c\x01\xab\x01\xff\x00\x01\x60",
)
.unwrap();
// invalid base32 texts // invalid base32 texts
rrdata_parse::<structs::NSEC3>("1 2 300 - v").unwrap_err(); rrdata_parse::<structs::NSEC3>("1 2 300 - v").unwrap_err();
rrdata_parse::<structs::NSEC3>("1 2 300 - vv").unwrap_err(); rrdata_parse::<structs::NSEC3>("1 2 300 - vv").unwrap_err();
// invalid (empty) next-hashed values // invalid (empty) next-hashed values
packet::deserialize_with(Bytes::from_static(b"\x01\x02\x01\x2c\x00\x00"), structs::NSEC3::deserialize).unwrap_err(); packet::deserialize_with(
Bytes::from_static(b"\x01\x02\x01\x2c\x00\x00"),
structs::NSEC3::deserialize,
)
.unwrap_err();
} }
#[test] #[test]
@ -151,5 +181,9 @@ fn test_apl() {
check::<structs::APL>("!1:0.0.0.0/0", b"\x00\x01\x00\x80").unwrap(); check::<structs::APL>("!1:0.0.0.0/0", b"\x00\x01\x00\x80").unwrap();
check::<structs::APL>("2:::/0", b"\x00\x02\x00\x00").unwrap(); check::<structs::APL>("2:::/0", b"\x00\x02\x00\x00").unwrap();
check::<structs::APL>("!2:::/0", b"\x00\x02\x00\x80").unwrap(); check::<structs::APL>("!2:::/0", b"\x00\x02\x00\x80").unwrap();
check::<structs::APL>("1:192.0.2.0/24 !2:2001:db8::/32", b"\x00\x01\x18\x03\xc0\x00\x02\x00\x02\x20\x84\x20\x01\x0d\xb8").unwrap(); check::<structs::APL>(
"1:192.0.2.0/24 !2:2001:db8::/32",
b"\x00\x01\x18\x03\xc0\x00\x02\x00\x02\x20\x84\x20\x01\x0d\xb8",
)
.unwrap();
} }

View File

@ -1,11 +1,11 @@
use bytes::{Bytes, BufMut};
use crate::common_types::*;
use crate::common_types::binary::HEXLOWER_PERMISSIVE_ALLOW_WS; use crate::common_types::binary::HEXLOWER_PERMISSIVE_ALLOW_WS;
use crate::common_types::*;
use crate::errors::*; use crate::errors::*;
use failure::{ResultExt, Fail}; use crate::ser::packet::{remaining_bytes, DnsPacketWriteContext};
use crate::ser::packet::{DnsPacketWriteContext, remaining_bytes}; use crate::ser::text::{next_field, DnsTextContext, DnsTextFormatter};
use crate::ser::{RRData, RRDataPacket, RRDataText}; use crate::ser::{RRData, RRDataPacket, RRDataText};
use crate::ser::text::{DnsTextFormatter, DnsTextContext, next_field}; use bytes::{BufMut, Bytes};
use failure::{Fail, ResultExt};
use std::borrow::Cow; use std::borrow::Cow;
use std::fmt; use std::fmt;
use std::io::Cursor; use std::io::Cursor;
@ -38,9 +38,15 @@ impl UnknownRecord {
let field = next_field(data).context("generic record data length")?; let field = next_field(data).context("generic record data length")?;
let len: usize = field.parse()?; let len: usize = field.parse()?;
let result = HEXLOWER_PERMISSIVE_ALLOW_WS.decode(data.as_bytes()) let result = HEXLOWER_PERMISSIVE_ALLOW_WS
.decode(data.as_bytes())
.with_context(|e| e.context(format!("invalid hex: {:?}", data)))?; .with_context(|e| e.context(format!("invalid hex: {:?}", data)))?;
failure::ensure!(len == result.len(), "length {} doesn't match length of encoded data {}", len, result.len()); failure::ensure!(
len == result.len(),
"length {} doesn't match length of encoded data {}",
len,
result.len()
);
*data = ""; // read all data *data = ""; // read all data
Ok(UnknownRecord { Ok(UnknownRecord {
@ -51,7 +57,12 @@ impl UnknownRecord {
} }
impl RRDataPacket for UnknownRecord { impl RRDataPacket for UnknownRecord {
fn deserialize_rr_data(_ttl: u32, _rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Self> { fn deserialize_rr_data(
_ttl: u32,
_rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Self> {
UnknownRecord::deserialize(rr_type, data) UnknownRecord::deserialize(rr_type, data)
} }
@ -59,7 +70,11 @@ impl RRDataPacket for UnknownRecord {
self.rr_type self.rr_type
} }
fn serialize_rr_data(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize_rr_data(
&self,
_context: &mut DnsPacketWriteContext,
packet: &mut Vec<u8>,
) -> Result<()> {
packet.reserve(self.raw.len()); packet.reserve(self.raw.len());
packet.put_slice(&self.raw); packet.put_slice(&self.raw);
Ok(()) Ok(())
@ -81,7 +96,12 @@ impl RRDataText for UnknownRecord {
/// this must never fail unless the underlying buffer fails. /// this must never fail unless the underlying buffer fails.
fn dns_format_rr_data(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format_rr_data(&self, f: &mut DnsTextFormatter) -> fmt::Result {
write!(f, "\\# {} {}", self.raw.len(), HEXLOWER_PERMISSIVE_ALLOW_WS.encode(&self.raw)) write!(
f,
"\\# {} {}",
self.raw.len(),
HEXLOWER_PERMISSIVE_ALLOW_WS.encode(&self.raw)
)
} }
fn rr_type_txt(&self) -> Cow<'static, str> { fn rr_type_txt(&self) -> Cow<'static, str> {

View File

@ -1,10 +1,10 @@
use bytes::{Bytes, Buf, BufMut};
use crate::errors::*;
use crate::common_types::*; use crate::common_types::*;
use failure::ResultExt; use crate::errors::*;
use crate::ser::packet::{get_blob, remaining_bytes, DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{next_field, DnsTextContext, DnsTextData, DnsTextFormatter};
use crate::ser::RRData; use crate::ser::RRData;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes, get_blob}; use bytes::{Buf, BufMut, Bytes};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use failure::ResultExt;
use std::fmt; use std::fmt;
use std::io::Read; use std::io::Read;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
@ -17,10 +17,7 @@ use std::net::{Ipv4Addr, Ipv6Addr};
#[RRClass(ANY)] #[RRClass(ANY)]
pub enum LOC { pub enum LOC {
Version0(LOC0), Version0(LOC0),
UnknownVersion{ UnknownVersion { version: u8, data: Bytes },
version: u8,
data: Bytes,
},
} }
impl DnsPacketData for LOC { impl DnsPacketData for LOC {
@ -29,7 +26,7 @@ impl DnsPacketData for LOC {
if 0 == version { if 0 == version {
Ok(LOC::Version0(DnsPacketData::deserialize(data)?)) Ok(LOC::Version0(DnsPacketData::deserialize(data)?))
} else { } else {
Ok(LOC::UnknownVersion{ Ok(LOC::UnknownVersion {
version: version, version: version,
data: remaining_bytes(data), data: remaining_bytes(data),
}) })
@ -43,7 +40,7 @@ impl DnsPacketData for LOC {
packet.put_u8(0); packet.put_u8(0);
l0.serialize(context, packet) l0.serialize(context, packet)
}, },
LOC::UnknownVersion{version, ref data} => { LOC::UnknownVersion { version, ref data } => {
packet.reserve(data.len() + 1); packet.reserve(data.len() + 1);
packet.put_u8(version); packet.put_u8(version);
packet.put_slice(data); packet.put_slice(data);
@ -56,53 +53,85 @@ impl DnsPacketData for LOC {
impl DnsTextData for LOC { impl DnsTextData for LOC {
fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> Result<Self> { fn dns_parse(_context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let degrees_latitude = next_field(data)?.parse::<u8>()?; let degrees_latitude = next_field(data)?.parse::<u8>()?;
failure::ensure!(degrees_latitude <= 90, "degrees latitude out of range: {}", degrees_latitude); failure::ensure!(
degrees_latitude <= 90,
"degrees latitude out of range: {}",
degrees_latitude
);
let mut minutes_latitude = 0; let mut minutes_latitude = 0;
let mut seconds_latitude = 0.0; let mut seconds_latitude = 0.0;
let mut field = next_field(data)?; let mut field = next_field(data)?;
if field != "N" && field != "n" && field != "S" && field != "s" { if field != "N" && field != "n" && field != "S" && field != "s" {
minutes_latitude = field.parse::<u8>()?; minutes_latitude = field.parse::<u8>()?;
failure::ensure!(minutes_latitude < 60, "minutes latitude out of range: {}", minutes_latitude); failure::ensure!(
minutes_latitude < 60,
"minutes latitude out of range: {}",
minutes_latitude
);
field = next_field(data)?; field = next_field(data)?;
if field != "N" && field != "n" && field != "S" && field != "s" { if field != "N" && field != "n" && field != "S" && field != "s" {
seconds_latitude = field.parse::<f32>()?; seconds_latitude = field.parse::<f32>()?;
failure::ensure!(seconds_latitude >= 0.0 && seconds_latitude < 60.0, "seconds latitude out of range: {}", seconds_latitude); failure::ensure!(
seconds_latitude >= 0.0 && seconds_latitude < 60.0,
"seconds latitude out of range: {}",
seconds_latitude
);
field = next_field(data)?; field = next_field(data)?;
} }
} }
let latitude_off = (3600_000 * degrees_latitude as u32) + (60_000 * minutes_latitude as u32) + (1_000.0 * seconds_latitude).round() as u32; let latitude_off = (3600_000 * degrees_latitude as u32)
+ (60_000 * minutes_latitude as u32)
+ (1_000.0 * seconds_latitude).round() as u32;
failure::ensure!(latitude_off <= 3600_000 * 180, "latitude out of range"); failure::ensure!(latitude_off <= 3600_000 * 180, "latitude out of range");
let latitude = match field { let latitude = match field {
"N"|"n" => 0x8000_0000 + latitude_off, "N" | "n" => 0x8000_0000 + latitude_off,
"S"|"s" => 0x8000_0000 - latitude_off, "S" | "s" => 0x8000_0000 - latitude_off,
_ => failure::bail!("invalid latitude orientation [NS]: {}", field), _ => failure::bail!("invalid latitude orientation [NS]: {}", field),
}; };
let degrees_longitude = next_field(data)?.parse::<u8>()?; let degrees_longitude = next_field(data)?.parse::<u8>()?;
failure::ensure!(degrees_longitude <= 180, "degrees longitude out of range: {}", degrees_longitude); failure::ensure!(
degrees_longitude <= 180,
"degrees longitude out of range: {}",
degrees_longitude
);
let mut minutes_longitude = 0; let mut minutes_longitude = 0;
let mut seconds_longitude = 0.0; let mut seconds_longitude = 0.0;
let mut field = next_field(data)?; let mut field = next_field(data)?;
if field != "E" && field != "e" && field != "W" && field != "w" { if field != "E" && field != "e" && field != "W" && field != "w" {
minutes_longitude = field.parse::<u8>()?; minutes_longitude = field.parse::<u8>()?;
failure::ensure!(minutes_longitude < 60, "minutes longitude out of range: {}", minutes_longitude); failure::ensure!(
minutes_longitude < 60,
"minutes longitude out of range: {}",
minutes_longitude
);
field = next_field(data)?; field = next_field(data)?;
if field != "E" && field != "e" && field != "W" && field != "w" { if field != "E" && field != "e" && field != "W" && field != "w" {
seconds_longitude = field.parse::<f32>()?; seconds_longitude = field.parse::<f32>()?;
failure::ensure!(seconds_longitude >= 0.0 && seconds_longitude < 60.0, "seconds longitude out of range: {}", seconds_longitude); failure::ensure!(
seconds_longitude >= 0.0 && seconds_longitude < 60.0,
"seconds longitude out of range: {}",
seconds_longitude
);
field = next_field(data)?; field = next_field(data)?;
} }
} }
let longitude_off = (3600_000 * degrees_longitude as u32) + (60_000 * minutes_longitude as u32) + (1_000.0 * seconds_longitude).round() as u32; let longitude_off = (3600_000 * degrees_longitude as u32)
+ (60_000 * minutes_longitude as u32)
+ (1_000.0 * seconds_longitude).round() as u32;
failure::ensure!(longitude_off <= 3600_000 * 180, "longitude out of range"); failure::ensure!(longitude_off <= 3600_000 * 180, "longitude out of range");
let longitude = match field { let longitude = match field {
"E"|"e" => 0x8000_0000 + longitude_off, "E" | "e" => 0x8000_0000 + longitude_off,
"W"|"w" => 0x8000_0000 - longitude_off, "W" | "w" => 0x8000_0000 - longitude_off,
_ => failure::bail!("invalid longitude orientation [EW]: {}", field), _ => failure::bail!("invalid longitude orientation [EW]: {}", field),
}; };
fn trim_unit_m(s: &str) -> &str { fn trim_unit_m(s: &str) -> &str {
if s.ends_with('m') { &s[..s.len()-1] } else { s } if s.ends_with('m') {
&s[..s.len() - 1]
} else {
s
}
} }
fn parse_precision(s: &str) -> Result<u8> { fn parse_precision(s: &str) -> Result<u8> {
@ -112,13 +141,23 @@ impl DnsTextData for LOC {
let mut dec_point = None; let mut dec_point = None;
for &b in s.as_bytes() { for &b in s.as_bytes() {
if b == b'.' { if b == b'.' {
failure::ensure!(dec_point.is_none(), "invalid precision (double decimal point): {:?}", s); failure::ensure!(
dec_point.is_none(),
"invalid precision (double decimal point): {:?}",
s
);
dec_point = Some(0); dec_point = Some(0);
continue; continue;
} }
failure::ensure!(b >= b'0' && b <= b'9', "invalid precision (invalid character): {:?}", s); failure::ensure!(
b >= b'0' && b <= b'9',
"invalid precision (invalid character): {:?}",
s
);
if let Some(ref mut dp) = dec_point { if let Some(ref mut dp) = dec_point {
if *dp == 2 { continue; } // ignore following digits if *dp == 2 {
continue;
} // ignore following digits
*dp += 1; *dp += 1;
} }
let d = b - b'0'; let d = b - b'0';
@ -138,7 +177,10 @@ impl DnsTextData for LOC {
Ok(field) => { Ok(field) => {
let f_altitude = trim_unit_m(field).parse::<f64>()?; let f_altitude = trim_unit_m(field).parse::<f64>()?;
let altitude = (f_altitude * 100.0 + 10000000.0).round() as i64; let altitude = (f_altitude * 100.0 + 10000000.0).round() as i64;
failure::ensure!(altitude > 0 && (altitude as u32) as i64 == altitude, "altitude out of range"); failure::ensure!(
altitude > 0 && (altitude as u32) as i64 == altitude,
"altitude out of range"
);
altitude as u32 altitude as u32
}, },
// standard requires the field, but the example parser doesn't.. // standard requires the field, but the example parser doesn't..
@ -160,7 +202,7 @@ impl DnsTextData for LOC {
Err(_) => 0x13, // 1e3 cm = 10m */ Err(_) => 0x13, // 1e3 cm = 10m */
}; };
Ok(LOC::Version0(LOC0{ Ok(LOC::Version0(LOC0 {
size, size,
horizontal_precision, horizontal_precision,
vertical_precision, vertical_precision,
@ -178,10 +220,14 @@ impl DnsTextData for LOC {
const MAX_LAT_OFFSET: u32 = 3600_000 * 180; const MAX_LAT_OFFSET: u32 = 3600_000 * 180;
const MAX_LON_OFFSET: u32 = 3600_000 * 180; const MAX_LON_OFFSET: u32 = 3600_000 * 180;
const LATLON_MID: u32 = 0x8000_0000; const LATLON_MID: u32 = 0x8000_0000;
if this.latitude < LATLON_MID - MAX_LAT_OFFSET || this.latitude > LATLON_MID + MAX_LAT_OFFSET { if this.latitude < LATLON_MID - MAX_LAT_OFFSET
|| this.latitude > LATLON_MID + MAX_LAT_OFFSET
{
return Err(fmt::Error); return Err(fmt::Error);
} }
if this.longitude < LATLON_MID - MAX_LON_OFFSET || this.longitude > LATLON_MID + MAX_LON_OFFSET { if this.longitude < LATLON_MID - MAX_LON_OFFSET
|| this.longitude > LATLON_MID + MAX_LON_OFFSET
{
return Err(fmt::Error); return Err(fmt::Error);
} }
@ -190,7 +236,10 @@ impl DnsTextData for LOC {
// if the leading digit is 0, the exponent must be 0 too. // if the leading digit is 0, the exponent must be 0 too.
(v > 0x00 && v < 0x10) || (v >> 4) > 9 || (v & 0xf) > 9 (v > 0x00 && v < 0x10) || (v >> 4) > 9 || (v & 0xf) > 9
} }
if is_invalid_prec(this.size) || is_invalid_prec(this.horizontal_precision) || is_invalid_prec(this.vertical_precision) { if is_invalid_prec(this.size)
|| is_invalid_prec(this.horizontal_precision)
|| is_invalid_prec(this.vertical_precision)
{
return Err(fmt::Error); return Err(fmt::Error);
} }
@ -227,7 +276,8 @@ impl DnsTextData for LOC {
write!(f, "{:0<width$}.00m", m, width = e as usize - 1) write!(f, "{:0<width$}.00m", m, width = e as usize - 1)
} else if e == 1 { } else if e == 1 {
write!(f, ".{}0m", m) write!(f, ".{}0m", m)
} else { // e == 0 } else {
// e == 0
write!(f, ".0{}m", m) write!(f, ".0{}m", m)
} }
} }
@ -262,8 +312,8 @@ pub struct A6 {
impl DnsPacketData for A6 { impl DnsPacketData for A6 {
fn deserialize(data: &mut ::std::io::Cursor<Bytes>) -> Result<Self> { fn deserialize(data: &mut ::std::io::Cursor<Bytes>) -> Result<Self> {
let prefix: u8 = DnsPacketData::deserialize(data) let prefix: u8 =
.context("failed parsing field A6::prefix")?; DnsPacketData::deserialize(data).context("failed parsing field A6::prefix")?;
failure::ensure!(prefix <= 128, "invalid A6::prefix {}", prefix); failure::ensure!(prefix <= 128, "invalid A6::prefix {}", prefix);
let suffix_offset = (prefix / 8) as usize; let suffix_offset = (prefix / 8) as usize;
debug_assert!(suffix_offset <= 16); debug_assert!(suffix_offset <= 16);
@ -310,15 +360,15 @@ impl DnsPacketData for A6 {
impl DnsTextData for A6 { impl DnsTextData for A6 {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> { fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result<Self> {
let prefix: u8 = DnsTextData::dns_parse(context, data) let prefix: u8 =
.context("failed parsing field A6::prefix")?; DnsTextData::dns_parse(context, data).context("failed parsing field A6::prefix")?;
failure::ensure!(prefix <= 128, "invalid A6::prefix {}", prefix); failure::ensure!(prefix <= 128, "invalid A6::prefix {}", prefix);
let suffix_offset = (prefix / 8) as usize; let suffix_offset = (prefix / 8) as usize;
debug_assert!(suffix_offset <= 16); debug_assert!(suffix_offset <= 16);
let suffix: Ipv6Addr = DnsTextData::dns_parse(context, data) let suffix: Ipv6Addr =
.context("failed parsing field A6::suffix")?; DnsTextData::dns_parse(context, data).context("failed parsing field A6::suffix")?;
// clear prefix bits // clear prefix bits
let mut suffix = suffix.octets(); let mut suffix = suffix.octets();
@ -332,8 +382,10 @@ impl DnsTextData for A6 {
let suffix = Ipv6Addr::from(suffix); let suffix = Ipv6Addr::from(suffix);
let prefix_name = if !data.is_empty() { let prefix_name = if !data.is_empty() {
Some(DnsTextData::dns_parse(context, data) Some(
.context("failed parsing field A6::prefix_name")?) DnsTextData::dns_parse(context, data)
.context("failed parsing field A6::prefix_name")?,
)
} else { } else {
None None
}; };
@ -376,22 +428,34 @@ impl DnsPacketData for APL {
fn deserialize(data: &mut ::std::io::Cursor<Bytes>) -> Result<Self> { fn deserialize(data: &mut ::std::io::Cursor<Bytes>) -> Result<Self> {
let mut items = Vec::new(); let mut items = Vec::new();
while data.has_remaining() { while data.has_remaining() {
let family: u16 = DnsPacketData::deserialize(data) let family: u16 =
.context("failed parsing APL::ADDRESSFAMILY")?; DnsPacketData::deserialize(data).context("failed parsing APL::ADDRESSFAMILY")?;
failure::ensure!(family == 1 || family == 2, "unknown APL::ADDRESSFAMILY {}", family); failure::ensure!(
let prefix: u8 = DnsPacketData::deserialize(data) family == 1 || family == 2,
.context("failed parsing field APL::PREFIX")?; "unknown APL::ADDRESSFAMILY {}",
let afd_length: u8 = DnsPacketData::deserialize(data) family
.context("failed parsing field APL::AFDLENGTH")?; );
let prefix: u8 =
DnsPacketData::deserialize(data).context("failed parsing field APL::PREFIX")?;
let afd_length: u8 =
DnsPacketData::deserialize(data).context("failed parsing field APL::AFDLENGTH")?;
let negation = 0 != (afd_length & 0x80); let negation = 0 != (afd_length & 0x80);
let afd_length = afd_length & 0x7f; let afd_length = afd_length & 0x7f;
let data = get_blob(data, afd_length as usize)?; let data = get_blob(data, afd_length as usize)?;
failure::ensure!(!data.ends_with(b"\0"), "APL::AFDPART ends with trailing zero"); failure::ensure!(
!data.ends_with(b"\0"),
"APL::AFDPART ends with trailing zero"
);
let address = if family == 1 { let address = if family == 1 {
failure::ensure!(prefix <= 32, "invalid APL::prefix {} for IPv4", prefix); failure::ensure!(prefix <= 32, "invalid APL::prefix {} for IPv4", prefix);
failure::ensure!((afd_length as u32) * 8 < (prefix as u32) + 7, "APL::AFDPART too long {} for prefix {}", afd_length, prefix); failure::ensure!(
(afd_length as u32) * 8 < (prefix as u32) + 7,
"APL::AFDPART too long {} for prefix {}",
afd_length,
prefix
);
assert!(afd_length <= 4); assert!(afd_length <= 4);
let mut buf = [0u8; 4]; let mut buf = [0u8; 4];
buf[..data.len()].copy_from_slice(&data); buf[..data.len()].copy_from_slice(&data);
@ -399,7 +463,12 @@ impl DnsPacketData for APL {
} else { } else {
assert!(family == 2); assert!(family == 2);
failure::ensure!(prefix <= 128, "invalid APL::prefix {} for IPv6", prefix); failure::ensure!(prefix <= 128, "invalid APL::prefix {} for IPv6", prefix);
failure::ensure!((afd_length as u32) * 8 < (prefix as u32) + 7, "AFD::AFDPART too long {} for prefix {}", afd_length, prefix); failure::ensure!(
(afd_length as u32) * 8 < (prefix as u32) + 7,
"AFD::AFDPART too long {} for prefix {}",
afd_length,
prefix
);
assert!(afd_length <= 16); assert!(afd_length <= 16);
let mut buf = [0u8; 16]; let mut buf = [0u8; 16];
buf[..data.len()].copy_from_slice(&data); buf[..data.len()].copy_from_slice(&data);
@ -408,10 +477,7 @@ impl DnsPacketData for APL {
use cidr::Cidr; use cidr::Cidr;
let prefix = cidr::IpCidr::new(address, prefix)?; let prefix = cidr::IpCidr::new(address, prefix)?;
items.push(AplItem { items.push(AplItem { prefix, negation })
prefix,
negation,
})
} }
Ok(APL { items }) Ok(APL { items })
} }
@ -427,13 +493,17 @@ impl DnsPacketData for APL {
match &item.prefix { match &item.prefix {
cidr::IpCidr::V4(p) => { cidr::IpCidr::V4(p) => {
let addr = p.first_address().octets(); let addr = p.first_address().octets();
while l > 0 && addr[l as usize -1] == 0 { l -= 1; } while l > 0 && addr[l as usize - 1] == 0 {
l -= 1;
}
packet.put_u8(l | negation_flag); packet.put_u8(l | negation_flag);
packet.extend_from_slice(&addr[..l as usize]); packet.extend_from_slice(&addr[..l as usize]);
}, },
cidr::IpCidr::V6(p) => { cidr::IpCidr::V6(p) => {
let addr = p.first_address().octets(); let addr = p.first_address().octets();
while l > 0 && addr[l as usize -1] == 0 { l -= 1; } while l > 0 && addr[l as usize - 1] == 0 {
l -= 1;
}
packet.put_u8(l | negation_flag); packet.put_u8(l | negation_flag);
packet.extend_from_slice(&addr[..l as usize]); packet.extend_from_slice(&addr[..l as usize]);
}, },
@ -453,7 +523,7 @@ impl DnsTextData for APL {
(false, item) (false, item)
}; };
let (afi, prefix) = match content.find(':') { let (afi, prefix) = match content.find(':') {
Some(colon) => (&content[..colon], &content[colon+1..]), Some(colon) => (&content[..colon], &content[colon + 1..]),
None => failure::bail!("no colon in APL item: {:?}", item), None => failure::bail!("no colon in APL item: {:?}", item),
}; };
let afi = afi.parse::<u16>()?; let afi = afi.parse::<u16>()?;
@ -462,10 +532,7 @@ impl DnsTextData for APL {
2 => prefix.parse::<cidr::Ipv6Cidr>()?.into(), 2 => prefix.parse::<cidr::Ipv6Cidr>()?.into(),
_ => failure::bail!("Unknown address family {} in item: {:?}", afi, item), _ => failure::bail!("Unknown address family {} in item: {:?}", afi, item),
}; };
items.push(AplItem { items.push(AplItem { prefix, negation });
prefix,
negation,
});
} }
*data = ""; *data = "";
Ok(APL { items }) Ok(APL { items })
@ -482,7 +549,6 @@ impl DnsTextData for APL {
} }
} }
#[derive(Clone, PartialEq, Eq, Debug)] #[derive(Clone, PartialEq, Eq, Debug)]
pub enum IpsecKeyGateway { pub enum IpsecKeyGateway {
None, None,
@ -494,19 +560,19 @@ pub enum IpsecKeyGateway {
#[derive(Clone, PartialEq, Eq, Debug, RRData)] #[derive(Clone, PartialEq, Eq, Debug, RRData)]
#[RRClass(ANY)] #[RRClass(ANY)]
pub enum IPSECKEY { pub enum IPSECKEY {
Known{ Known {
precedence: u8, precedence: u8,
algorithm: u8, algorithm: u8,
gateway: IpsecKeyGateway, gateway: IpsecKeyGateway,
public_key: Base64RemainingBlob, public_key: Base64RemainingBlob,
}, },
UnknownGateway{ UnknownGateway {
precedence: u8, precedence: u8,
gateway_type: u8, gateway_type: u8,
algorithm: u8, algorithm: u8,
// length of gateway is unknown, can't split gateway and public key // length of gateway is unknown, can't split gateway and public key
remaining: Bytes, remaining: Bytes,
} },
} }
impl DnsPacketData for IPSECKEY { impl DnsPacketData for IPSECKEY {
@ -519,14 +585,16 @@ impl DnsPacketData for IPSECKEY {
1 => IpsecKeyGateway::Ipv4(Ipv4Addr::deserialize(data)?), 1 => IpsecKeyGateway::Ipv4(Ipv4Addr::deserialize(data)?),
2 => IpsecKeyGateway::Ipv6(Ipv6Addr::deserialize(data)?), 2 => IpsecKeyGateway::Ipv6(Ipv6Addr::deserialize(data)?),
3 => IpsecKeyGateway::Name(DnsName::deserialize(data)?), 3 => IpsecKeyGateway::Name(DnsName::deserialize(data)?),
_ => return Ok(IPSECKEY::UnknownGateway{ _ => {
return Ok(IPSECKEY::UnknownGateway {
precedence, precedence,
gateway_type, gateway_type,
algorithm, algorithm,
remaining: remaining_bytes(data), remaining: remaining_bytes(data),
}), })
},
}; };
Ok(IPSECKEY::Known{ Ok(IPSECKEY::Known {
precedence, precedence,
algorithm, algorithm,
gateway, gateway,
@ -536,7 +604,12 @@ impl DnsPacketData for IPSECKEY {
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
match *self { match *self {
IPSECKEY::Known{precedence, algorithm, ref gateway, ref public_key} => { IPSECKEY::Known {
precedence,
algorithm,
ref gateway,
ref public_key,
} => {
packet.reserve(3); packet.reserve(3);
packet.put_u8(precedence); packet.put_u8(precedence);
let gateway_type: u8 = match *gateway { let gateway_type: u8 = match *gateway {
@ -555,13 +628,18 @@ impl DnsPacketData for IPSECKEY {
}; };
public_key.serialize(context, packet)?; public_key.serialize(context, packet)?;
}, },
IPSECKEY::UnknownGateway{precedence, gateway_type, algorithm, ref remaining} => { IPSECKEY::UnknownGateway {
precedence,
gateway_type,
algorithm,
ref remaining,
} => {
packet.reserve(3 + remaining.len()); packet.reserve(3 + remaining.len());
packet.put_u8(precedence); packet.put_u8(precedence);
packet.put_u8(gateway_type); packet.put_u8(gateway_type);
packet.put_u8(algorithm); packet.put_u8(algorithm);
packet.put_slice(remaining); packet.put_slice(remaining);
} },
} }
Ok(()) Ok(())
} }
@ -579,7 +657,7 @@ impl DnsTextData for IPSECKEY {
3 => IpsecKeyGateway::Name(DnsName::dns_parse(context, data)?), 3 => IpsecKeyGateway::Name(DnsName::dns_parse(context, data)?),
_ => failure::bail!("unknown gateway type {} for IPSECKEY", gateway_type), _ => failure::bail!("unknown gateway type {} for IPSECKEY", gateway_type),
}; };
Ok(IPSECKEY::Known{ Ok(IPSECKEY::Known {
precedence, precedence,
algorithm, algorithm,
gateway, gateway,
@ -589,7 +667,12 @@ impl DnsTextData for IPSECKEY {
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result { fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result {
match *self { match *self {
IPSECKEY::Known{precedence, algorithm, ref gateway, ref public_key} => { IPSECKEY::Known {
precedence,
algorithm,
ref gateway,
ref public_key,
} => {
let gateway_type: u8 = match *gateway { let gateway_type: u8 = match *gateway {
IpsecKeyGateway::None => 0, IpsecKeyGateway::None => 0,
IpsecKeyGateway::Ipv4(_) => 1, IpsecKeyGateway::Ipv4(_) => 1,
@ -606,7 +689,7 @@ impl DnsTextData for IPSECKEY {
public_key.dns_format(f)?; public_key.dns_format(f)?;
Ok(()) Ok(())
}, },
IPSECKEY::UnknownGateway{..} => Err(fmt::Error), IPSECKEY::UnknownGateway { .. } => Err(fmt::Error),
} }
} }
} }

View File

@ -1,7 +1,7 @@
mod rrdata;
pub mod packet; pub mod packet;
mod rrdata;
pub mod text; pub mod text;
pub use self::packet::DnsPacketWriteContext; pub use self::packet::DnsPacketWriteContext;
pub use self::rrdata::{RRDataPacket, RRDataText, RRData, StaticRRData}; pub use self::rrdata::{RRData, RRDataPacket, RRDataText, StaticRRData};
pub use self::text::DnsTextContext; pub use self::text::DnsTextContext;

View File

@ -1,5 +1,5 @@
use bytes::{Bytes, Buf, BufMut};
use crate::errors::*; use crate::errors::*;
use bytes::{Buf, BufMut, Bytes};
use std::io::Cursor; use std::io::Cursor;
mod std_impls; mod std_impls;
@ -20,7 +20,11 @@ where
{ {
let mut c = Cursor::new(data); let mut c = Cursor::new(data);
let result = parser(&mut c)?; let result = parser(&mut c)?;
failure::ensure!(!c.has_remaining(), "data remaining: {} bytes", c.remaining()); failure::ensure!(
!c.has_remaining(),
"data remaining: {} bytes",
c.remaining()
);
Ok(result) Ok(result)
} }
@ -47,7 +51,10 @@ pub fn short_blob(data: &mut Cursor<Bytes>) -> Result<Bytes> {
} }
pub fn write_short_blob(data: &[u8], packet: &mut Vec<u8>) -> Result<()> { pub fn write_short_blob(data: &[u8], packet: &mut Vec<u8>) -> Result<()> {
failure::ensure!(data.len() < 256, "short blob must be at most 255 bytes long"); failure::ensure!(
data.len() < 256,
"short blob must be at most 255 bytes long"
);
packet.reserve(data.len() + 1); packet.reserve(data.len() + 1);
packet.put_u8(data.len() as u8); packet.put_u8(data.len() as u8);
packet.put_slice(data); packet.put_slice(data);

View File

@ -1,6 +1,6 @@
use bytes::{Bytes, Buf, BufMut};
use crate::errors::*; use crate::errors::*;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use bytes::{Buf, BufMut, Bytes};
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
use std::mem::size_of; use std::mem::size_of;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
@ -101,7 +101,7 @@ mod tests {
use crate::errors::*; use crate::errors::*;
fn deserialize<T: super::DnsPacketData>(data: &'static [u8]) -> Result<T> { fn deserialize<T: super::DnsPacketData>(data: &'static [u8]) -> Result<T> {
use bytes::{Bytes,Buf}; use bytes::{Buf, Bytes};
use std::io::Cursor; use std::io::Cursor;
let mut c = Cursor::new(Bytes::from_static(data)); let mut c = Cursor::new(Bytes::from_static(data));
let result = T::deserialize(&mut c)?; let result = T::deserialize(&mut c)?;

View File

@ -1,6 +1,6 @@
use bytes::{BufMut}; use crate::common_types::name::{DnsCompressedName, DnsLabelRef, DnsName};
use crate::errors::*; use crate::errors::*;
use crate::common_types::name::{DnsName, DnsCompressedName, DnsLabelRef}; use bytes::BufMut;
// only points to uncompressed labels; if a label of a name is stored, // only points to uncompressed labels; if a label of a name is stored,
// all following labels must be stored too, even if their pos >= 0x4000. // all following labels must be stored too, even if their pos >= 0x4000.
@ -16,7 +16,7 @@ impl LabelEntry {
fn label_ref<'a>(&self, packet: &'a Vec<u8>) -> DnsLabelRef<'a> { fn label_ref<'a>(&self, packet: &'a Vec<u8>) -> DnsLabelRef<'a> {
let p = self.pos as usize; let p = self.pos as usize;
let len = packet[p] as usize; let len = packet[p] as usize;
DnsLabelRef::new(&packet[p+1..][..len]).unwrap() DnsLabelRef::new(&packet[p + 1..][..len]).unwrap()
} }
fn next(&self, labels: &Vec<LabelEntry>) -> Option<Self> { fn next(&self, labels: &Vec<LabelEntry>) -> Option<Self> {
@ -28,13 +28,19 @@ impl LabelEntry {
} }
} }
fn matches(&self, packet: &Vec<u8>, labels: &Vec<LabelEntry>, name: &DnsName, min: u8) -> Option<u8> { fn matches(
&self,
packet: &Vec<u8>,
labels: &Vec<LabelEntry>,
name: &DnsName,
min: u8,
) -> Option<u8> {
'outer: for i in 0..min { 'outer: for i in 0..min {
if name.label_ref(i) != self.label_ref(packet) { if name.label_ref(i) != self.label_ref(packet) {
continue; continue;
} }
let mut l = *self; let mut l = *self;
for j in i+1..name.label_count() { for j in i + 1..name.label_count() {
l = match l.next(labels) { l = match l.next(labels) {
None => continue 'outer, None => continue 'outer,
Some(l) => l, Some(l) => l,
@ -86,7 +92,12 @@ fn write_canonical_name(packet: &mut Vec<u8>, name: &DnsName) {
packet.put_u8(0); packet.put_u8(0);
} }
fn write_label_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, label: DnsLabelRef, next_entry: usize) { fn write_label_remember(
packet: &mut Vec<u8>,
labels: &mut Vec<LabelEntry>,
label: DnsLabelRef,
next_entry: usize,
) {
labels.push(LabelEntry { labels.push(LabelEntry {
pos: packet.len(), pos: packet.len(),
next_entry: next_entry, next_entry: next_entry,
@ -110,7 +121,6 @@ impl Default for LabelWriteMethod {
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct DnsPacketWriteContext { pub struct DnsPacketWriteContext {
labels: LabelWriteMethod, labels: LabelWriteMethod,
} }
impl DnsPacketWriteContext { impl DnsPacketWriteContext {
@ -137,7 +147,11 @@ impl DnsPacketWriteContext {
self.labels = LabelWriteMethod::Canonical; self.labels = LabelWriteMethod::Canonical;
} }
pub(crate) fn write_uncompressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> { pub(crate) fn write_uncompressed_name(
&mut self,
packet: &mut Vec<u8>,
name: &DnsName,
) -> Result<()> {
// for now we don't remember labels of these names. // for now we don't remember labels of these names.
// //
// if we did: would we want to check whether a suffix is already // if we did: would we want to check whether a suffix is already
@ -147,7 +161,11 @@ impl DnsPacketWriteContext {
Ok(()) Ok(())
} }
pub(crate) fn write_canonical_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> { pub(crate) fn write_canonical_name(
&mut self,
packet: &mut Vec<u8>,
name: &DnsName,
) -> Result<()> {
match self.labels { match self.labels {
LabelWriteMethod::Uncompressed | LabelWriteMethod::Compressed(_) => { LabelWriteMethod::Uncompressed | LabelWriteMethod::Compressed(_) => {
// uncompressed // uncompressed
@ -157,10 +175,14 @@ impl DnsPacketWriteContext {
write_canonical_name(packet, name); write_canonical_name(packet, name);
}, },
} }
return Ok(()) return Ok(());
} }
pub(crate) fn write_compressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsCompressedName) -> Result<()> { pub(crate) fn write_compressed_name(
&mut self,
packet: &mut Vec<u8>,
name: &DnsCompressedName,
) -> Result<()> {
// for DNSSEC we need to write it canonical // for DNSSEC we need to write it canonical
if name.is_root() { if name.is_root() {
write_name(packet, name); write_name(packet, name);
@ -175,7 +197,7 @@ impl DnsPacketWriteContext {
LabelWriteMethod::Compressed(ref mut labels) => labels, LabelWriteMethod::Compressed(ref mut labels) => labels,
LabelWriteMethod::Canonical => { LabelWriteMethod::Canonical => {
write_canonical_name(packet, name); write_canonical_name(packet, name);
return Ok(()) return Ok(());
}, },
}; };
@ -183,7 +205,9 @@ impl DnsPacketWriteContext {
let mut best_match_len = name.label_count(); let mut best_match_len = name.label_count();
for (e_ndx, e) in (labels as &Vec<LabelEntry>).into_iter().enumerate() { for (e_ndx, e) in (labels as &Vec<LabelEntry>).into_iter().enumerate() {
if e.pos >= 0x4000 { break; } // this and following labels can't be used for compression if e.pos >= 0x4000 {
break;
} // this and following labels can't be used for compression
if let Some(l) = e.matches(packet, labels, name, best_match_len) { if let Some(l) = e.matches(packet, labels, name, best_match_len) {
debug_assert!(l < best_match_len); debug_assert!(l < best_match_len);
best_match_len = l; best_match_len = l;
@ -207,7 +231,12 @@ impl DnsPacketWriteContext {
write_label_remember(packet, labels, name.label_ref(i), n); write_label_remember(packet, labels, name.label_ref(i), n);
} }
// the next label following is at e_ndx // the next label following is at e_ndx
write_label_remember(packet, labels, name.label_ref(best_match_len-1), e_ndx); write_label_remember(
packet,
labels,
name.label_ref(best_match_len - 1),
e_ndx,
);
} else { } else {
// no need to remember, can't be used for compression // no need to remember, can't be used for compression
for i in 0..best_match_len { for i in 0..best_match_len {
@ -232,7 +261,7 @@ impl DnsPacketWriteContext {
} }
// the next label is the TLD // the next label is the TLD
let n = labels.len(); // point to itself let n = labels.len(); // point to itself
write_label_remember(packet, labels, name.label_ref(best_match_len-1), n); write_label_remember(packet, labels, name.label_ref(best_match_len - 1), n);
// terminate name // terminate name
packet.reserve(1); packet.reserve(1);
packet.put_u8(0); packet.put_u8(0);
@ -240,7 +269,7 @@ impl DnsPacketWriteContext {
// no need to remember, can't be used for compression // no need to remember, can't be used for compression
write_name(packet, name); write_name(packet, name);
} }
} },
} }
Ok(()) Ok(())

View File

@ -1,9 +1,9 @@
use bytes::Bytes; use crate::common_types::{classes, Class, Type};
use crate::common_types::{Class, Type, classes};
use crate::errors::*; use crate::errors::*;
use crate::records::UnknownRecord; use crate::records::UnknownRecord;
use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext}; use crate::ser::packet::{DnsPacketData, DnsPacketWriteContext};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext}; use crate::ser::text::{DnsTextContext, DnsTextData, DnsTextFormatter};
use bytes::Bytes;
use std::any::Any; use std::any::Any;
use std::borrow::Cow; use std::borrow::Cow;
use std::fmt; use std::fmt;
@ -12,21 +12,39 @@ use std::io::Cursor;
pub use dnsbox_derive::RRData; pub use dnsbox_derive::RRData;
pub trait RRDataPacket { pub trait RRDataPacket {
fn deserialize_rr_data(ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Self> fn deserialize_rr_data(
ttl: u32,
rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Self>
where where
Self: Sized, Self: Sized;
;
fn rr_type(&self) -> Type; fn rr_type(&self) -> Type;
fn serialize_rr_data(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()>; fn serialize_rr_data(
&self,
context: &mut DnsPacketWriteContext,
packet: &mut Vec<u8>,
) -> Result<()>;
} }
impl<T: DnsPacketData + StaticRRData> RRDataPacket for T { impl<T: DnsPacketData + StaticRRData> RRDataPacket for T {
fn deserialize_rr_data(_ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor<Bytes>) -> Result<Self> { fn deserialize_rr_data(
_ttl: u32,
rr_class: Class,
rr_type: Type,
data: &mut Cursor<Bytes>,
) -> Result<Self> {
failure::ensure!(rr_type == T::TYPE, "type mismatch"); failure::ensure!(rr_type == T::TYPE, "type mismatch");
if T::CLASS != classes::ANY { if T::CLASS != classes::ANY {
failure::ensure!(rr_class == T::CLASS, "class mismatch: got {}, need {}", rr_class, T::CLASS); failure::ensure!(
rr_class == T::CLASS,
"class mismatch: got {}, need {}",
rr_class,
T::CLASS
);
} }
T::deserialize(data) T::deserialize(data)
} }
@ -35,7 +53,11 @@ impl<T: DnsPacketData + StaticRRData> RRDataPacket for T {
T::TYPE T::TYPE
} }
fn serialize_rr_data(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> { fn serialize_rr_data(
&self,
context: &mut DnsPacketWriteContext,
packet: &mut Vec<u8>,
) -> Result<()> {
self.serialize(context, packet) self.serialize(context, packet)
} }
} }
@ -43,8 +65,7 @@ impl<T: DnsPacketData + StaticRRData> RRDataPacket for T {
pub trait RRDataText { pub trait RRDataText {
fn dns_parse_rr_data(context: &DnsTextContext, data: &mut &str) -> Result<Self> fn dns_parse_rr_data(context: &DnsTextContext, data: &mut &str) -> Result<Self>
where where
Self: Sized, Self: Sized;
;
// format might fail if there is no (known) text representation. // format might fail if there is no (known) text representation.
fn dns_format_rr_data(&self, f: &mut DnsTextFormatter) -> fmt::Result; fn dns_format_rr_data(&self, f: &mut DnsTextFormatter) -> fmt::Result;
@ -58,9 +79,16 @@ impl<T: DnsTextData + StaticRRData> RRDataText for T {
Self: Sized, Self: Sized,
{ {
failure::ensure!(context.record_type() == Some(T::TYPE), "type mismatch"); failure::ensure!(context.record_type() == Some(T::TYPE), "type mismatch");
let rr_class = context.zone_class().expect("require zone CLASS to parse record"); let rr_class = context
.zone_class()
.expect("require zone CLASS to parse record");
if T::CLASS != classes::ANY { if T::CLASS != classes::ANY {
failure::ensure!(rr_class == T::CLASS, "class mismatch: got {}, need {}", rr_class, T::CLASS); failure::ensure!(
rr_class == T::CLASS,
"class mismatch: got {}, need {}",
rr_class,
T::CLASS
);
} }
failure::ensure!(context.last_ttl().is_some(), "require TTL to parse record"); failure::ensure!(context.last_ttl().is_some(), "require TTL to parse record");
T::dns_parse(context, data) T::dns_parse(context, data)
@ -84,9 +112,7 @@ pub trait RRData: RRDataPacket + RRDataText + fmt::Debug + 'static {
fn text(&self) -> Result<(String, String)> { fn text(&self) -> Result<(String, String)> {
let mut buf = String::new(); let mut buf = String::new();
match self.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf)) { match self.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf)) {
Ok(()) => { Ok(()) => return Ok((self.rr_type_txt().into(), buf)),
return Ok((self.rr_type_txt().into(), buf))
},
Err(_) => (), Err(_) => (),
} }
let mut raw = Vec::new(); let mut raw = Vec::new();
@ -94,7 +120,8 @@ pub trait RRData: RRDataPacket + RRDataText + fmt::Debug + 'static {
let ur = UnknownRecord::new(self.rr_type(), raw.into()); let ur = UnknownRecord::new(self.rr_type(), raw.into());
// formatting UnknownRecord should not fail // formatting UnknownRecord should not fail
buf.clear(); buf.clear();
ur.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf)).expect("formatting UnknownRecord must not fail"); ur.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf))
.expect("formatting UnknownRecord must not fail");
Ok((ur.rr_type_txt().into(), buf)) Ok((ur.rr_type_txt().into(), buf))
} }
} }

View File

@ -1,8 +1,8 @@
use crate::common_types; use crate::common_types;
use std::fmt; use std::fmt;
mod std_impls;
pub mod quoted; pub mod quoted;
mod std_impls;
pub use dnsbox_derive::DnsTextData; pub use dnsbox_derive::DnsTextData;
@ -12,7 +12,9 @@ pub fn skip_whitespace(data: &mut &str) {
pub fn next_field<'a>(data: &mut &'a str) -> crate::errors::Result<&'a str> { pub fn next_field<'a>(data: &mut &'a str) -> crate::errors::Result<&'a str> {
*data = (*data).trim_start(); *data = (*data).trim_start();
if data.is_empty() { failure::bail!("missing field"); } if data.is_empty() {
failure::bail!("missing field");
}
match data.find(char::is_whitespace) { match data.find(char::is_whitespace) {
None => { None => {
let result = *data; let result = *data;
@ -29,7 +31,9 @@ pub fn next_field<'a>(data: &mut &'a str) -> crate::errors::Result<&'a str> {
pub fn next_quoted_field(data: &mut &str) -> crate::errors::Result<Vec<u8>> { pub fn next_quoted_field(data: &mut &str) -> crate::errors::Result<Vec<u8>> {
*data = (*data).trim_start(); *data = (*data).trim_start();
if data.is_empty() { failure::bail!("missing field"); } if data.is_empty() {
failure::bail!("missing field");
}
let result = quoted::UnquoteIterator::new(data).collect::<Result<Vec<_>, _>>()?; let result = quoted::UnquoteIterator::new(data).collect::<Result<Vec<_>, _>>()?;
Ok(result) Ok(result)
@ -91,7 +95,7 @@ impl<'a> DnsTextFormatter<'a> {
pub fn format_field<'b>(&'b mut self) -> Result<DnsTextFormatField<'a, 'b>, fmt::Error> { pub fn format_field<'b>(&'b mut self) -> Result<DnsTextFormatField<'a, 'b>, fmt::Error> {
self.next_field()?; self.next_field()?;
Ok(DnsTextFormatField{ inner: self }) Ok(DnsTextFormatField { inner: self })
} }
pub fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result { pub fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result {
@ -208,8 +212,7 @@ impl DnsTextContext {
pub trait DnsTextData { pub trait DnsTextData {
fn dns_parse(context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self> fn dns_parse(context: &DnsTextContext, data: &mut &str) -> crate::errors::Result<Self>
where where
Self: Sized, Self: Sized;
;
// format might fail if there is no (known) text representation. // format might fail if there is no (known) text representation.
fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result; fn dns_format(&self, f: &mut DnsTextFormatter) -> fmt::Result;
} }
@ -221,7 +224,11 @@ where
let mut data = data; let mut data = data;
let result = parser(&mut data)?; let result = parser(&mut data)?;
let data = data.trim(); let data = data.trim();
failure::ensure!(data.is_empty(), "didn't parse complete text, remaining: {:?}", data); failure::ensure!(
data.is_empty(),
"didn't parse complete text, remaining: {:?}",
data
);
Ok(result) Ok(result)
} }

View File

@ -15,19 +15,19 @@ impl ::std::ops::Deref for EncodedByte {
pub struct EncodeIterator<'a> { pub struct EncodeIterator<'a> {
encode_whitespace: bool, encode_whitespace: bool,
data: &'a [u8] data: &'a [u8],
} }
impl<'a> EncodeIterator<'a> { impl<'a> EncodeIterator<'a> {
pub fn new_quoted(value: &'a [u8]) -> Self { pub fn new_quoted(value: &'a [u8]) -> Self {
EncodeIterator{ EncodeIterator {
encode_whitespace: false, encode_whitespace: false,
data: value, data: value,
} }
} }
pub fn new_encode_whitespace(value: &'a [u8]) -> Self { pub fn new_encode_whitespace(value: &'a [u8]) -> Self {
EncodeIterator{ EncodeIterator {
encode_whitespace: true, encode_whitespace: true,
data: value, data: value,
} }
@ -38,7 +38,9 @@ impl<'a> Iterator for EncodeIterator<'a> {
type Item = EncodedByte; type Item = EncodedByte;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.data.is_empty() { return None; } if self.data.is_empty() {
return None;
}
let b = self.data[0]; let b = self.data[0];
self.data = &self.data[1..]; self.data = &self.data[1..];
if b < 32 || b > 127 || (self.encode_whitespace && is_ascii_whitespace(b)) { if b < 32 || b > 127 || (self.encode_whitespace && is_ascii_whitespace(b)) {
@ -46,18 +48,18 @@ impl<'a> Iterator for EncodeIterator<'a> {
let d1 = b / 100; let d1 = b / 100;
let d2 = (b / 10) % 10; let d2 = (b / 10) % 10;
let d3 = b % 10; let d3 = b % 10;
Some(EncodedByte{ Some(EncodedByte {
storage: [b'\\', b'0' + d1, b'0' + d2, b'0' + d3], storage: [b'\\', b'0' + d1, b'0' + d2, b'0' + d3],
used: 4, used: 4,
}) })
} else if b == b'"' || b == b'\\' { } else if b == b'"' || b == b'\\' {
// `\c` // `\c`
Some(EncodedByte{ Some(EncodedByte {
storage: [b'\\', b, 0, 0], storage: [b'\\', b, 0, 0],
used: 2, used: 2,
}) })
} else { } else {
Some(EncodedByte{ Some(EncodedByte {
storage: [b, 0, 0, 0], storage: [b, 0, 0, 0],
used: 1, used: 1,
}) })
@ -74,7 +76,11 @@ pub struct UnquoteError {
impl fmt::Display for UnquoteError { impl fmt::Display for UnquoteError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "unquote error at position {} in {:?}: {}", self.position, self.data, self.msg) write!(
f,
"unquote error at position {} in {:?}: {}",
self.position, self.data, self.msg
)
} }
} }
@ -98,7 +104,7 @@ impl<'a, 'b: 'a> UnquoteIterator<'a, 'b> {
} }
fn err<T>(&mut self, msg: &'static str) -> Option<Result<T, UnquoteError>> { fn err<T>(&mut self, msg: &'static str) -> Option<Result<T, UnquoteError>> {
Some(Err(UnquoteError{ Some(Err(UnquoteError {
data: (*self.data).into(), data: (*self.data).into(),
position: self.pos, position: self.pos,
msg: msg, msg: msg,
@ -123,7 +129,9 @@ impl<'a, 'b: 'a> Iterator for UnquoteIterator<'a, 'b> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let raw = self.data.as_bytes(); let raw = self.data.as_bytes();
if raw.is_empty() { return self.err("empty input"); } if raw.is_empty() {
return self.err("empty input");
}
if 0 == self.pos { if 0 == self.pos {
// check for starting quote: // check for starting quote:
@ -144,12 +152,12 @@ impl<'a, 'b: 'a> Iterator for UnquoteIterator<'a, 'b> {
if raw[self.pos] == b'"' { if raw[self.pos] == b'"' {
if self.quoted { if self.quoted {
// either followed by end-of-string or a whitespace // either followed by end-of-string or a whitespace
if self.pos+1 < raw.len() && !is_ascii_whitespace(raw[self.pos+1]) { if self.pos + 1 < raw.len() && !is_ascii_whitespace(raw[self.pos + 1]) {
return self.err("quote in the middle of quoted string"); return self.err("quote in the middle of quoted string");
} }
// eat terminating quote // eat terminating quote
// pos+1 is obviously a good utf-8 boundary // pos+1 is obviously a good utf-8 boundary
*self.data = self.data[self.pos+1..].trim_start(); *self.data = self.data[self.pos + 1..].trim_start();
return None; return None;
} else { } else {
return self.err("quote in the middle of unquoted string"); return self.err("quote in the middle of unquoted string");
@ -159,9 +167,11 @@ impl<'a, 'b: 'a> Iterator for UnquoteIterator<'a, 'b> {
*self.data = self.data[self.pos..].trim_start(); *self.data = self.data[self.pos..].trim_start();
return None; return None;
} else if raw[self.pos] == b'\\' { } else if raw[self.pos] == b'\\' {
if self.pos + 1 >= raw.len() { return self.err("unexpected end of string after backslash"); } if self.pos + 1 >= raw.len() {
if raw[self.pos+1] < b'0' || raw[self.pos+1] > b'9' { return self.err("unexpected end of string after backslash");
let result = raw[self.pos+1]; }
if raw[self.pos + 1] < b'0' || raw[self.pos + 1] > b'9' {
let result = raw[self.pos + 1];
if !self.quoted && is_ascii_whitespace(result) { if !self.quoted && is_ascii_whitespace(result) {
return self.err("(escaped) whitespace not allowed in unquoted field"); return self.err("(escaped) whitespace not allowed in unquoted field");
} }
@ -169,16 +179,24 @@ impl<'a, 'b: 'a> Iterator for UnquoteIterator<'a, 'b> {
return Some(Ok(result)); return Some(Ok(result));
} }
// otherwise require 3 decimal digits // otherwise require 3 decimal digits
if self.pos + 3 >= raw.len() { return self.err("unexpected end of string after backslash with decimal"); } if self.pos + 3 >= raw.len() {
return self.err("unexpected end of string after backslash with decimal");
}
// raw[self.pos+1] already checked for digit above // raw[self.pos+1] already checked for digit above
if raw[self.pos+2] < b'0' || raw[self.pos+2] > b'9' || raw[self.pos+3] < b'0' || raw[self.pos+3] > b'9' { if raw[self.pos + 2] < b'0'
|| raw[self.pos + 2] > b'9'
|| raw[self.pos + 3] < b'0'
|| raw[self.pos + 3] > b'9'
{
return self.err("expecting 3 digits after backslash with decimal"); return self.err("expecting 3 digits after backslash with decimal");
} }
let d1 = raw[self.pos+1] - b'0'; let d1 = raw[self.pos + 1] - b'0';
let d2 = raw[self.pos+2] - b'0'; let d2 = raw[self.pos + 2] - b'0';
let d3 = raw[self.pos+3] - b'0'; let d3 = raw[self.pos + 3] - b'0';
let val = (d1 as u32 * 100) + (d2 as u32 * 10) + (d3 as u32); let val = (d1 as u32 * 100) + (d2 as u32 * 10) + (d3 as u32);
if val > 255 { return self.err("invalid decimal escape"); } if val > 255 {
return self.err("invalid decimal escape");
}
self.pos += 4; self.pos += 4;
Some(Ok(val as u8)) Some(Ok(val as u8))
} else { } else {
@ -194,17 +212,11 @@ mod tests {
use crate::ser::text::{next_quoted_field, quote}; use crate::ser::text::{next_quoted_field, quote};
fn check_quote(data: &[u8], quoted: &str) { fn check_quote(data: &[u8], quoted: &str) {
assert_eq!( assert_eq!(quote(data), quoted);
quote(data),
quoted
);
} }
fn check_unquote(mut input: &str, data: &[u8]) { fn check_unquote(mut input: &str, data: &[u8]) {
assert_eq!( assert_eq!(next_quoted_field(&mut input).unwrap(), data);
next_quoted_field(&mut input).unwrap(),
data
);
assert!(input.is_empty()); assert!(input.is_empty());
} }

View File

@ -1,6 +1,6 @@
use crate::ser::text::{next_field, DnsTextContext, DnsTextData, DnsTextFormatter};
use std::fmt; use std::fmt;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use crate::ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field};
impl DnsTextData for () { impl DnsTextData for () {
fn dns_parse(_context: &DnsTextContext, _data: &mut &str) -> crate::errors::Result<Self> { fn dns_parse(_context: &DnsTextContext, _data: &mut &str) -> crate::errors::Result<Self> {
@ -92,30 +92,22 @@ mod tests {
fn test_ipv6() { fn test_ipv6() {
assert_eq!( assert_eq!(
deserialize::<Ipv6Addr>("FEDC:BA98:7654:3210:FEDC:BA98:7654:3210").unwrap(), deserialize::<Ipv6Addr>("FEDC:BA98:7654:3210:FEDC:BA98:7654:3210").unwrap(),
Ipv6Addr::new( Ipv6Addr::new(0xfedc, 0xba98, 0x7654, 0x3210, 0xfedc, 0xba98, 0x7654, 0x3210)
0xfedc, 0xba98, 0x7654, 0x3210, 0xfedc, 0xba98, 0x7654, 0x3210
)
); );
assert_eq!( assert_eq!(
deserialize::<Ipv6Addr>("1080::8:800:200C:417A").unwrap(), deserialize::<Ipv6Addr>("1080::8:800:200C:417A").unwrap(),
Ipv6Addr::new( Ipv6Addr::new(0x1080, 0, 0, 0, 0x8, 0x800, 0x200c, 0x417a)
0x1080, 0, 0, 0, 0x8, 0x800, 0x200c, 0x417a
)
); );
assert_eq!( assert_eq!(
deserialize::<Ipv6Addr>("::13.1.68.3").unwrap(), deserialize::<Ipv6Addr>("::13.1.68.3").unwrap(),
Ipv6Addr::new( Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0x0d01, 0x4403)
0, 0, 0, 0, 0, 0, 0x0d01, 0x4403
)
); );
assert_eq!( assert_eq!(
deserialize::<Ipv6Addr>("::FFFF:129.144.52.38").unwrap(), deserialize::<Ipv6Addr>("::FFFF:129.144.52.38").unwrap(),
Ipv6Addr::new( Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0x8190, 0x3426)
0, 0, 0, 0, 0, 0xffff, 0x8190, 0x3426
)
); );
} }

View File

@ -1,8 +1,6 @@
#[cfg(not(feature = "no-unsafe"))] #[cfg(not(feature = "no-unsafe"))]
pub fn from_utf8_unchecked(v: &[u8]) -> &str { pub fn from_utf8_unchecked(v: &[u8]) -> &str {
unsafe { unsafe { ::std::str::from_utf8_unchecked(v) }
::std::str::from_utf8_unchecked(v)
}
} }
#[cfg(feature = "no-unsafe")] #[cfg(feature = "no-unsafe")]

View File

@ -1,6 +1,10 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
fn derive_impl(s: &synstructure::Structure, parse_fields: TokenStream, serialize_fields: TokenStream) -> TokenStream { fn derive_impl(
s: &synstructure::Structure,
parse_fields: TokenStream,
serialize_fields: TokenStream,
) -> TokenStream {
s.gen_impl(quote!{ s.gen_impl(quote!{
#[allow(unused_imports)] #[allow(unused_imports)]
use dnsbox_base::_failure::ResultExt; use dnsbox_base::_failure::ResultExt;
@ -23,14 +27,14 @@ fn derive_impl(s: &synstructure::Structure, parse_fields: TokenStream, serialize
fn derive_unit(s: &synstructure::Structure) -> TokenStream { fn derive_unit(s: &synstructure::Structure) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
derive_impl(s, quote!{#name}, quote!{}) derive_impl(s, quote! {#name}, quote! {})
} }
fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> TokenStream { fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
let mut parse_fields = quote!{}; let mut parse_fields = quote! {};
let mut serialize_fields = quote!{}; let mut serialize_fields = quote! {};
for field in &fields.named { for field in &fields.named {
let field_name = field.ident.as_ref().unwrap(); let field_name = field.ident.as_ref().unwrap();
@ -45,37 +49,46 @@ fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> Token
}); });
} }
derive_impl(s, quote!{#name{ #parse_fields }}, serialize_fields) derive_impl(s, quote! {#name{ #parse_fields }}, serialize_fields)
} }
fn derive_unnamed(s: &synstructure::Structure, fields: &syn::FieldsUnnamed) -> TokenStream { fn derive_unnamed(s: &synstructure::Structure, fields: &syn::FieldsUnnamed) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
let mut parse_fields = quote!{}; let mut parse_fields = quote! {};
let mut serialize_fields = quote!{}; let mut serialize_fields = quote! {};
for field in 0..fields.unnamed.len() { for field in 0..fields.unnamed.len() {
let field = syn::Index::from(field); let field = syn::Index::from(field);
parse_fields.extend(quote!{ parse_fields.extend(quote! {
DnsPacketData::deserialize(_data) DnsPacketData::deserialize(_data)
.with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), #field, e))?, .with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), #field, e))?,
}); });
serialize_fields.extend(quote!{ serialize_fields.extend(quote! {
self.#field.serialize(_context, _packet) self.#field.serialize(_context, _packet)
.with_context(|e| format!("failed serializing field {}::{}: {}", stringify!(#name), #field, e))?; .with_context(|e| format!("failed serializing field {}::{}: {}", stringify!(#name), #field, e))?;
}); });
} }
derive_impl(s, quote!{#name(#parse_fields)}, serialize_fields) derive_impl(s, quote! {#name(#parse_fields)}, serialize_fields)
} }
pub fn derive(s: synstructure::Structure) -> TokenStream { pub fn derive(s: synstructure::Structure) -> TokenStream {
let ast = s.ast(); let ast = s.ast();
match &ast.data { match &ast.data {
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Unit, .. }) => derive_unit(&s), syn::Data::Struct(syn::DataStruct {
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Named(fields), .. }) => derive_named(&s, fields), fields: syn::Fields::Unit,
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Unnamed(fields), .. }) => derive_unnamed(&s, fields), ..
}) => derive_unit(&s),
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Named(fields),
..
}) => derive_named(&s, fields),
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Unnamed(fields),
..
}) => derive_unnamed(&s, fields),
_ => panic!("Deriving DnsPacketData not supported for non struct types"), _ => panic!("Deriving DnsPacketData not supported for non struct types"),
} }
} }

View File

@ -1,6 +1,10 @@
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
fn derive_impl(s: &synstructure::Structure, parse_fields: TokenStream, format_fields: TokenStream) -> TokenStream { fn derive_impl(
s: &synstructure::Structure,
parse_fields: TokenStream,
format_fields: TokenStream,
) -> TokenStream {
s.gen_impl(quote!{ s.gen_impl(quote!{
#[allow(unused_imports)] #[allow(unused_imports)]
use dnsbox_base::_failure::ResultExt as _; use dnsbox_base::_failure::ResultExt as _;
@ -23,14 +27,14 @@ fn derive_impl(s: &synstructure::Structure, parse_fields: TokenStream, format_fi
fn derive_unit(s: &synstructure::Structure) -> TokenStream { fn derive_unit(s: &synstructure::Structure) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
derive_impl(s, quote!{#name}, quote!{}) derive_impl(s, quote! {#name}, quote! {})
} }
fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> TokenStream { fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
let mut parse_fields = quote!{}; let mut parse_fields = quote! {};
let mut format_fields = quote!{}; let mut format_fields = quote! {};
for field in &fields.named { for field in &fields.named {
let field_name = field.ident.as_ref().unwrap(); let field_name = field.ident.as_ref().unwrap();
@ -39,41 +43,50 @@ fn derive_named(s: &synstructure::Structure, fields: &syn::FieldsNamed) -> Token
.with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), stringify!(#field_name), e))?, .with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), stringify!(#field_name), e))?,
}); });
format_fields.extend(quote!{ format_fields.extend(quote! {
DnsTextData::dns_format(&self.#field_name, f)?; DnsTextData::dns_format(&self.#field_name, f)?;
}); });
} }
derive_impl(s, quote!{#name{ #parse_fields }}, format_fields) derive_impl(s, quote! {#name{ #parse_fields }}, format_fields)
} }
fn derive_unnamed(s: &synstructure::Structure, fields: &syn::FieldsUnnamed) -> TokenStream { fn derive_unnamed(s: &synstructure::Structure, fields: &syn::FieldsUnnamed) -> TokenStream {
let name = &s.ast().ident; let name = &s.ast().ident;
let mut parse_fields = quote!{}; let mut parse_fields = quote! {};
let mut format_fields = quote!{}; let mut format_fields = quote! {};
for field in 0..fields.unnamed.len() { for field in 0..fields.unnamed.len() {
let field = syn::Index::from(field); let field = syn::Index::from(field);
parse_fields.extend(quote!{ parse_fields.extend(quote! {
DnsTextData::dns_parse(_context, _data) DnsTextData::dns_parse(_context, _data)
.with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), #field, e))?, .with_context(|e| format!("failed parsing field {}::{}: {}", stringify!(#name), #field, e))?,
}); });
format_fields.extend(quote!{ format_fields.extend(quote! {
DnsTextData::dns_format(&self.#field, f)?; DnsTextData::dns_format(&self.#field, f)?;
}); });
} }
derive_impl(s, quote!{#name(#parse_fields)}, format_fields) derive_impl(s, quote! {#name(#parse_fields)}, format_fields)
} }
pub fn derive(s: synstructure::Structure) -> TokenStream { pub fn derive(s: synstructure::Structure) -> TokenStream {
let ast = s.ast(); let ast = s.ast();
match &ast.data { match &ast.data {
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Unit, .. }) => derive_unit(&s), syn::Data::Struct(syn::DataStruct {
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Named(fields), .. }) => derive_named(&s, fields), fields: syn::Fields::Unit,
syn::Data::Struct(syn::DataStruct{ fields: syn::Fields::Unnamed(fields), .. }) => derive_unnamed(&s, fields), ..
}) => derive_unit(&s),
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Named(fields),
..
}) => derive_named(&s, fields),
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Unnamed(fields),
..
}) => derive_unnamed(&s, fields),
_ => panic!("Deriving DnsTextData not supported for non struct types"), _ => panic!("Deriving DnsTextData not supported for non struct types"),
} }
} }

View File

@ -1,4 +1,4 @@
#![recursion_limit="128"] #![recursion_limit = "128"]
#[macro_use] #[macro_use]
extern crate synstructure; extern crate synstructure;
@ -6,11 +6,11 @@ extern crate synstructure;
extern crate quote; extern crate quote;
extern crate proc_macro2; extern crate proc_macro2;
mod rrdata;
mod dns_packet_data; mod dns_packet_data;
mod dns_text_data; mod dns_text_data;
mod native_enum; mod native_enum;
mod native_flags; mod native_flags;
mod rrdata;
decl_derive!([DnsPacketData] => dns_packet_data::derive); decl_derive!([DnsPacketData] => dns_packet_data::derive);
decl_derive!([DnsTextData] => dns_text_data::derive); decl_derive!([DnsTextData] => dns_text_data::derive);
@ -20,19 +20,20 @@ decl_attribute!([native_flags] => native_flags::attribute_native_flags);
fn attr_get_single_list_arg(attr_meta: &syn::Meta) -> proc_macro2::TokenStream { fn attr_get_single_list_arg(attr_meta: &syn::Meta) -> proc_macro2::TokenStream {
match attr_meta { match attr_meta {
syn::Meta::Word(_) => { syn::Meta::Word(_) => panic!("{:?} attribute requires an argument", attr_meta.name()),
panic!("{:?} attribute requires an argument", attr_meta.name())
},
syn::Meta::List(l) => { syn::Meta::List(l) => {
if l.nested.len() != 1 { if l.nested.len() != 1 {
panic!("{:?} attribute requires exactly one argument", attr_meta.name()); panic!(
"{:?} attribute requires exactly one argument",
attr_meta.name()
);
} }
let arg = *l.nested.first().unwrap().value(); let arg = *l.nested.first().unwrap().value();
quote!{#arg} quote! {#arg}
}, },
syn::Meta::NameValue(nv) => { syn::Meta::NameValue(nv) => {
let lit = &nv.lit; let lit = &nv.lit;
quote!{#lit} quote! {#lit}
}, },
} }
} }

View File

@ -5,7 +5,11 @@ pub fn attribute_native_enum(
structure: synstructure::Structure, structure: synstructure::Structure,
) -> TokenStream { ) -> TokenStream {
let ast = structure.ast(); let ast = structure.ast();
let in_attrs = ast.attrs.iter().map(|a| quote!{#a}).collect::<TokenStream>(); let in_attrs = ast
.attrs
.iter()
.map(|a| quote! {#a})
.collect::<TokenStream>();
let in_vis = &ast.vis; let in_vis = &ast.vis;
let name = &ast.ident; let name = &ast.ident;
@ -19,7 +23,10 @@ pub fn attribute_native_enum(
{ {
let variants = &enumdata.variants; let variants = &enumdata.variants;
let doc_str = format!("Known enum variants of [`{}`]\n\n[`{}`]: struct.{}.html\n", name, name, name); let doc_str = format!(
"Known enum variants of [`{}`]\n\n[`{}`]: struct.{}.html\n",
name, name, name
);
known_enum = quote! { known_enum = quote! {
#[doc = #doc_str] #[doc = #doc_str]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
@ -34,10 +41,20 @@ pub fn attribute_native_enum(
let mut consts = TokenStream::new(); let mut consts = TokenStream::new();
let mut convert = TokenStream::new(); let mut convert = TokenStream::new();
for variant in &enumdata.variants { for variant in &enumdata.variants {
let variant_attrs = variant.attrs.iter().map(|a| quote!{#a}).collect::<TokenStream>(); let variant_attrs = variant
.attrs
.iter()
.map(|a| quote! {#a})
.collect::<TokenStream>();
let disc_name = &variant.ident; let disc_name = &variant.ident;
let disc = variant.discriminant.as_ref().map(|(_, d)| quote!{#d}).unwrap_or_else(|| quote!{ let disc = variant
.discriminant
.as_ref()
.map(|(_, d)| quote! {#d})
.unwrap_or_else(|| {
quote! {
#known_name::#disc_name as #native_type #known_name::#disc_name as #native_type
}
}); });
consts.extend(quote! { consts.extend(quote! {
#variant_attrs #variant_attrs

View File

@ -5,11 +5,18 @@ pub fn attribute_native_flags(
structure: synstructure::Structure, structure: synstructure::Structure,
) -> TokenStream { ) -> TokenStream {
let ast = structure.ast(); let ast = structure.ast();
let in_attrs = ast.attrs.iter().map(|a| quote!{#a}).collect::<TokenStream>(); let in_attrs = ast
.attrs
.iter()
.map(|a| quote! {#a})
.collect::<TokenStream>();
let in_vis = &ast.vis; let in_vis = &ast.vis;
let name = &ast.ident; let name = &ast.ident;
let hidden_impl_mod = syn::Ident::new(&format!("_{}_hidden_native_flags", name), proc_macro2::Span::call_site()); let hidden_impl_mod = syn::Ident::new(
&format!("_{}_hidden_native_flags", name),
proc_macro2::Span::call_site(),
);
let enumdata = match &ast.data { let enumdata = match &ast.data {
syn::Data::Enum(de) => de, syn::Data::Enum(de) => de,
@ -20,14 +27,22 @@ pub fn attribute_native_flags(
let mut known_mask = TokenStream::new(); let mut known_mask = TokenStream::new();
let mut dbg = TokenStream::new(); let mut dbg = TokenStream::new();
for variant in &enumdata.variants { for variant in &enumdata.variants {
let variant_attrs = variant.attrs.iter().map(|a| quote!{#a}).collect::<TokenStream>(); let variant_attrs = variant
.attrs
.iter()
.map(|a| quote! {#a})
.collect::<TokenStream>();
let disc_name = &variant.ident; let disc_name = &variant.ident;
let disc = variant.discriminant.as_ref().map(|(_, d)| d).expect("all variants need explicit bitmask discriminants"); let disc = variant
.discriminant
.as_ref()
.map(|(_, d)| d)
.expect("all variants need explicit bitmask discriminants");
consts.extend(quote! { consts.extend(quote! {
#variant_attrs #variant_attrs
pub const #disc_name: Flag = Flag { mask: #disc }; pub const #disc_name: Flag = Flag { mask: #disc };
}); });
known_mask.extend(quote!{ known_mask.extend(quote! {
| #disc | #disc
}); });
dbg.extend(quote! { dbg.extend(quote! {

View File

@ -1,6 +1,6 @@
use crate::attr_get_single_list_arg; use crate::attr_get_single_list_arg;
#[derive(Clone,Debug)] #[derive(Clone, Debug)]
enum StructAttribute { enum StructAttribute {
RRTypeName(proc_macro2::TokenStream), RRTypeName(proc_macro2::TokenStream),
RRClass(proc_macro2::TokenStream), RRClass(proc_macro2::TokenStream),
@ -49,11 +49,12 @@ pub fn rrdata_derive(s: synstructure::Structure) -> proc_macro2::TokenStream {
let name_str = name_str.unwrap_or_else(|| { let name_str = name_str.unwrap_or_else(|| {
let name_str = format!("{}", name); let name_str = format!("{}", name);
quote!{#name_str} quote! {#name_str}
}); });
let rr_class = rr_class.unwrap_or_else(|| quote!{ANY}); let rr_class = rr_class.unwrap_or_else(|| quote! {ANY});
let test_mod_name = syn::Ident::new(&format!("test_rr_{}", name), proc_macro2::Span::call_site()); let test_mod_name =
syn::Ident::new(&format!("test_rr_{}", name), proc_macro2::Span::call_site());
let impl_rrdata = s.unbound_impl( let impl_rrdata = s.unbound_impl(
quote!(::dnsbox_base::ser::RRData), quote!(::dnsbox_base::ser::RRData),