diff --git a/lib/dnsbox-base/src/common_types/binary.rs b/lib/dnsbox-base/src/common_types/binary.rs index f9203af..215d3ee 100644 --- a/lib/dnsbox-base/src/common_types/binary.rs +++ b/lib/dnsbox-base/src/common_types/binary.rs @@ -1,8 +1,8 @@ -use bytes::Bytes; +use bytes::{Bytes, BufMut}; use data_encoding::{self, HEXLOWER_PERMISSIVE}; use errors::*; use failure::{Fail, ResultExt}; -use ser::packet::{DnsPacketData, remaining_bytes, short_blob}; +use ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes, short_blob, write_short_blob}; use ser::text::*; use std::fmt; use std::io::Cursor; @@ -37,6 +37,10 @@ impl DnsPacketData for HexShortBlob { fn deserialize(data: &mut Cursor) -> Result { Ok(HexShortBlob(short_blob(data)?)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + write_short_blob(&self.0, packet) + } } impl DnsTextData for HexShortBlob { @@ -73,6 +77,12 @@ impl DnsPacketData for Base64RemainingBlob { fn deserialize(data: &mut Cursor) -> Result { Ok(Base64RemainingBlob(remaining_bytes(data))) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.0.len()); + packet.put_slice(&self.0); + Ok(()) + } } impl DnsTextData for Base64RemainingBlob { @@ -98,6 +108,12 @@ impl DnsPacketData for HexRemainingBlob { fn deserialize(data: &mut Cursor) -> Result { Ok(HexRemainingBlob(remaining_bytes(data))) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.0.len()); + packet.put_slice(&self.0); + Ok(()) + } } impl DnsTextData for HexRemainingBlob { diff --git a/lib/dnsbox-base/src/common_types/classes.rs b/lib/dnsbox-base/src/common_types/classes.rs index 4200d65..bcfae07 100644 --- a/lib/dnsbox-base/src/common_types/classes.rs +++ b/lib/dnsbox-base/src/common_types/classes.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use errors::*; -use ser::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use std::fmt; use std::io::Cursor; @@ -201,6 +201,10 @@ impl DnsPacketData for Class { fn deserialize(data: &mut Cursor) -> Result { Ok(Class(DnsPacketData::deserialize(data)?)) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + self.0.serialize(context, packet) + } } impl DnsTextData for Class { diff --git a/lib/dnsbox-base/src/common_types/name/mod.rs b/lib/dnsbox-base/src/common_types/name/mod.rs index 93bde52..36b7e22 100644 --- a/lib/dnsbox-base/src/common_types/name/mod.rs +++ b/lib/dnsbox-base/src/common_types/name/mod.rs @@ -3,7 +3,7 @@ use bytes::Bytes; use errors::*; -use ser::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use smallvec::SmallVec; use std::fmt; use std::io::Cursor; @@ -29,7 +29,7 @@ enum LabelOffset { #[derive(Clone,PartialEq,Eq,PartialOrd,Ord,Hash,Debug)] enum LabelOffsets { Uncompressed(SmallVec<[u8;16]>), - Compressed(usize, SmallVec<[LabelOffset;8]>), + Compressed(usize, SmallVec<[LabelOffset;4]>), } impl LabelOffsets { @@ -65,7 +65,7 @@ impl LabelOffsets { /// /// Uses the "original" raw representation for storage (i.e. can share /// memory with a parsed packet) -#[derive(Clone,Hash)] +#[derive(Clone)] pub struct DnsName { // in uncompressed form always includes terminating null octect; // but even in uncompressed form can include unused bytes at the @@ -198,6 +198,10 @@ impl DnsPacketData for DnsName { fn deserialize(data: &mut Cursor) -> Result { DnsName::parse_name(data, false) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + context.write_uncompressed_name(packet, self) + } } /// Similar to `DnsName`, but allows using compressed labels in the @@ -264,6 +268,10 @@ impl DnsPacketData for DnsCompressedName { fn deserialize(data: &mut Cursor) -> Result { Ok(DnsCompressedName(DnsName::parse_name(data, true)?)) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + context.write_compressed_name(packet, self) + } } /// Iterator type for [`DnsName::labels`] diff --git a/lib/dnsbox-base/src/common_types/nsec.rs b/lib/dnsbox-base/src/common_types/nsec.rs index e50f2ce..61a46ce 100644 --- a/lib/dnsbox-base/src/common_types/nsec.rs +++ b/lib/dnsbox-base/src/common_types/nsec.rs @@ -1,9 +1,9 @@ -use bytes::{Bytes, Buf}; +use bytes::{Bytes, Buf, BufMut}; use common_types::Type; use data_encoding; use errors::*; use failure::{Fail, ResultExt}; -use ser::packet::{DnsPacketData, remaining_bytes, short_blob}; +use ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes, short_blob, write_short_blob}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, skip_whitespace, next_field}; use std::collections::BTreeSet; use std::fmt; @@ -100,6 +100,12 @@ impl DnsPacketData for NsecTypeBitmap { set: set, }) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.raw.len()); + packet.put_slice(&self.raw); + Ok(()) + } } impl DnsTextData for NsecTypeBitmap { @@ -134,6 +140,11 @@ impl DnsPacketData for NextHashedOwnerName { ensure!(text.len() > 0, "NextHashedOwnerName must not be empty"); Ok(NextHashedOwnerName(text)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + ensure!(self.0.len() > 0, "NextHashedOwnerName must not be empty"); + write_short_blob(&self.0, packet) + } } impl DnsTextData for NextHashedOwnerName { diff --git a/lib/dnsbox-base/src/common_types/nxt.rs b/lib/dnsbox-base/src/common_types/nxt.rs index 1b86ecf..40209aa 100644 --- a/lib/dnsbox-base/src/common_types/nxt.rs +++ b/lib/dnsbox-base/src/common_types/nxt.rs @@ -1,7 +1,7 @@ -use bytes::{Bytes, Buf}; +use bytes::{Bytes, Buf, BufMut}; use common_types::Type; use errors::*; -use ser::packet::{DnsPacketData, remaining_bytes}; +use ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, skip_whitespace}; use std::collections::BTreeSet; use std::fmt; @@ -72,6 +72,12 @@ impl DnsPacketData for NxtTypeBitmap { set: set, }) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.raw.len()); + packet.put_slice(&self.raw); + Ok(()) + } } impl DnsTextData for NxtTypeBitmap { diff --git a/lib/dnsbox-base/src/common_types/text.rs b/lib/dnsbox-base/src/common_types/text.rs index 5642ba8..30dcbda 100644 --- a/lib/dnsbox-base/src/common_types/text.rs +++ b/lib/dnsbox-base/src/common_types/text.rs @@ -1,6 +1,6 @@ -use bytes::{Bytes, Buf}; +use bytes::{Bytes, Buf, BufMut}; use errors::*; -use ser::packet::{DnsPacketData, short_blob, remaining_bytes}; +use ser::packet::{DnsPacketData, DnsPacketWriteContext, short_blob, write_short_blob, remaining_bytes}; use ser::text::*; use std::fmt; use std::io::Cursor; @@ -13,6 +13,10 @@ impl DnsPacketData for ShortText { fn deserialize(data: &mut Cursor) -> Result { Ok(ShortText(short_blob(data)?)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + write_short_blob(&self.0, packet) + } } impl DnsTextData for ShortText { @@ -43,6 +47,14 @@ impl DnsPacketData for LongText { } Ok(LongText(texts)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + ensure!(self.0.len() > 0, "empty LongText not allowed"); + for t in &self.0 { + write_short_blob(t, packet)?; + } + Ok(()) + } } impl DnsTextData for LongText { @@ -77,6 +89,10 @@ impl DnsPacketData for UnquotedShortText { fn deserialize(data: &mut Cursor) -> Result { Ok(UnquotedShortText(short_blob(data)?)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + write_short_blob(&self.0, packet) + } } impl DnsTextData for UnquotedShortText { @@ -102,6 +118,12 @@ impl DnsPacketData for RemainingText { fn deserialize(data: &mut Cursor) -> Result { Ok(RemainingText(remaining_bytes(data))) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.0.len()); + packet.put_slice(&self.0); + Ok(()) + } } impl DnsTextData for RemainingText { diff --git a/lib/dnsbox-base/src/common_types/time.rs b/lib/dnsbox-base/src/common_types/time.rs index 5deaed0..1649331 100644 --- a/lib/dnsbox-base/src/common_types/time.rs +++ b/lib/dnsbox-base/src/common_types/time.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use errors::*; -use ser::packet::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use std::fmt; use std::io::Cursor; @@ -15,6 +15,10 @@ impl DnsPacketData for Time { fn deserialize(data: &mut Cursor) -> Result { Ok(Time(DnsPacketData::deserialize(data)?)) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + self.0.serialize(context, packet) + } } impl DnsTextData for Time { diff --git a/lib/dnsbox-base/src/common_types/types.rs b/lib/dnsbox-base/src/common_types/types.rs index 3ca06a4..c1514df 100644 --- a/lib/dnsbox-base/src/common_types/types.rs +++ b/lib/dnsbox-base/src/common_types/types.rs @@ -3,7 +3,7 @@ use bytes::Bytes; use errors::*; use records::registry::{lookup_type_to_name, lookup_type_name}; -use ser::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext, next_field}; use std::borrow::Cow; use std::fmt; @@ -573,6 +573,10 @@ impl DnsPacketData for Type { fn deserialize(data: &mut Cursor) -> Result { Ok(Type(DnsPacketData::deserialize(data)?)) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + self.0.serialize(context, packet) + } } impl DnsTextData for Type { diff --git a/lib/dnsbox-base/src/common_types/uri.rs b/lib/dnsbox-base/src/common_types/uri.rs index 993f7aa..e4a28e9 100644 --- a/lib/dnsbox-base/src/common_types/uri.rs +++ b/lib/dnsbox-base/src/common_types/uri.rs @@ -1,6 +1,6 @@ -use bytes::Bytes; +use bytes::{Bytes, BufMut}; use errors::*; -use ser::packet::{DnsPacketData, remaining_bytes}; +use ser::packet::{DnsPacketData, DnsPacketWriteContext, remaining_bytes}; use ser::text::*; use std::fmt; use std::io::Cursor; @@ -21,6 +21,13 @@ impl DnsPacketData for UriText { ensure!(!raw.is_empty(), "URI must not be empty"); Ok(UriText(raw)) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + ensure!(self.0.is_empty(), "URI must not be empty"); + packet.reserve(self.0.len()); + packet.put_slice(&self.0); + Ok(()) + } } impl DnsTextData for UriText { diff --git a/lib/dnsbox-base/src/records/unknown.rs b/lib/dnsbox-base/src/records/unknown.rs index 8bb79f8..31f2014 100644 --- a/lib/dnsbox-base/src/records/unknown.rs +++ b/lib/dnsbox-base/src/records/unknown.rs @@ -1,9 +1,9 @@ -use bytes::Bytes; +use bytes::{Bytes, BufMut}; use common_types::*; use common_types::binary::HEXLOWER_PERMISSIVE_ALLOW_WS; use errors::*; use failure::{ResultExt, Fail}; -use ser::packet::remaining_bytes; +use ser::packet::{DnsPacketWriteContext, remaining_bytes}; use ser::{RRData, RRDataPacket, RRDataText}; use ser::text::{DnsTextFormatter, DnsTextContext, next_field}; use std::borrow::Cow; @@ -54,6 +54,12 @@ impl RRDataPacket for UnknownRecord { fn rr_type(&self) -> Type { self.rr_type } + + fn serialize_rr_data(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(self.raw.len()); + packet.put_slice(&self.raw); + Ok(()) + } } impl RRDataText for UnknownRecord { @@ -69,7 +75,7 @@ impl RRDataText for UnknownRecord { UnknownRecord::dns_parse(t, data) } - // format might fail if there is no (known) text representation. + /// this must never fail unless the underlying buffer fails. fn dns_format_rr_data(&self, f: &mut DnsTextFormatter) -> fmt::Result { write!(f, "\\# {} {}", self.raw.len(), HEXLOWER_PERMISSIVE_ALLOW_WS.encode(&self.raw)) } diff --git a/lib/dnsbox-base/src/records/weird_structs.rs b/lib/dnsbox-base/src/records/weird_structs.rs index fdb288d..0200528 100644 --- a/lib/dnsbox-base/src/records/weird_structs.rs +++ b/lib/dnsbox-base/src/records/weird_structs.rs @@ -1,7 +1,8 @@ -use bytes::{Bytes, Buf}; +use bytes::{Bytes, Buf, BufMut}; +use errors::*; use common_types::*; use failure::ResultExt; -use ser::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext}; use std::fmt; use std::io::Read; @@ -22,7 +23,7 @@ pub enum LOC { } impl DnsPacketData for LOC { - fn deserialize(data: &mut ::std::io::Cursor) -> ::errors::Result { + fn deserialize(data: &mut ::std::io::Cursor) -> Result { let version: u8 = DnsPacketData::deserialize(data)?; if 0 == version { Ok(LOC::Version0(DnsPacketData::deserialize(data)?)) @@ -33,10 +34,26 @@ impl DnsPacketData for LOC { }) } } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + match *self { + LOC::Version0(ref l0) => { + packet.reserve(1); + packet.put_u8(0); + l0.serialize(context, packet) + }, + LOC::UnknownVersion{version, ref data} => { + packet.reserve(data.len() + 1); + packet.put_u8(version); + packet.put_slice(data); + Ok(()) + }, + } + } } impl DnsTextData for LOC { - fn dns_parse(_context: &DnsTextContext, _data: &mut &str) -> ::errors::Result { + fn dns_parse(_context: &DnsTextContext, _data: &mut &str) -> Result { unimplemented!() } @@ -67,7 +84,7 @@ pub struct A6 { } impl DnsPacketData for A6 { - fn deserialize(data: &mut ::std::io::Cursor) -> ::errors::Result { + fn deserialize(data: &mut ::std::io::Cursor) -> Result { let prefix: u8 = DnsPacketData::deserialize(data) .context("failed parsing field A6::prefix")?; ensure!(prefix <= 128, "invalid A6::prefix {}", prefix); @@ -98,10 +115,24 @@ impl DnsPacketData for A6 { prefix_name, }) } + + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + let suffix_offset = (self.prefix / 8) as usize; + debug_assert!(suffix_offset <= 16); + let suffix = self.dirty_suffix.octets(); + let suffix_data = &suffix[suffix_offset..]; + packet.reserve(1 /* prefix */ + suffix_data.len()); + packet.put_u8(self.prefix); + packet.put_slice(suffix_data); + if let Some(ref n) = self.prefix_name { + n.serialize(context, packet)?; + } + Ok(()) + } } impl DnsTextData for A6 { - fn dns_parse(context: &DnsTextContext, data: &mut &str) -> ::errors::Result { + fn dns_parse(context: &DnsTextContext, data: &mut &str) -> Result { let prefix: u8 = DnsTextData::dns_parse(context, data) .context("failed parsing field A6::prefix")?; ensure!(prefix <= 128, "invalid A6::prefix {}", prefix); diff --git a/lib/dnsbox-base/src/ser/mod.rs b/lib/dnsbox-base/src/ser/mod.rs index 174d01c..dfb5a46 100644 --- a/lib/dnsbox-base/src/ser/mod.rs +++ b/lib/dnsbox-base/src/ser/mod.rs @@ -2,6 +2,4 @@ pub mod packet; pub mod text; mod rrdata; -pub use self::packet::DnsPacketData; -pub use self::text::DnsTextData; pub use self::rrdata::{RRDataPacket, RRDataText, RRData, StaticRRData}; diff --git a/lib/dnsbox-base/src/ser/packet/mod.rs b/lib/dnsbox-base/src/ser/packet/mod.rs index 78ce4d1..8814e82 100644 --- a/lib/dnsbox-base/src/ser/packet/mod.rs +++ b/lib/dnsbox-base/src/ser/packet/mod.rs @@ -1,11 +1,15 @@ -use bytes::{Bytes, Buf}; +use bytes::{Bytes, Buf, BufMut}; use errors::*; use std::io::Cursor; mod std_impls; +mod write; + +pub use self::write::*; pub trait DnsPacketData: Sized { fn deserialize(data: &mut Cursor) -> Result; + fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()>; } pub fn deserialize_with(data: Bytes, parser: F) -> Result @@ -37,3 +41,11 @@ pub fn short_blob(data: &mut Cursor) -> Result { data.advance(blob_len); Ok(blob) } + +pub fn write_short_blob(data: &[u8], packet: &mut Vec) -> Result<()> { + ensure!(data.len() < 256, "short blob must be at most 255 bytes long"); + packet.reserve(data.len() + 1); + packet.put_u8(data.len() as u8); + packet.put_slice(data); + Ok(()) +} diff --git a/lib/dnsbox-base/src/ser/packet/std_impls.rs b/lib/dnsbox-base/src/ser/packet/std_impls.rs index 4c507fa..4d134bc 100644 --- a/lib/dnsbox-base/src/ser/packet/std_impls.rs +++ b/lib/dnsbox-base/src/ser/packet/std_impls.rs @@ -1,6 +1,6 @@ -use bytes::{Bytes,Buf,BigEndian}; +use bytes::{Bytes, Buf, BufMut, BigEndian}; use errors::*; -use ser::packet::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use std::io::{Cursor, Read}; use std::mem::size_of; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -10,6 +10,12 @@ impl DnsPacketData for u8 { check_enough_data!(data, size_of::(), "u8"); Ok(data.get_u8()) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(size_of::()); + packet.put_u8(*self); + Ok(()) + } } impl DnsPacketData for u16 { @@ -17,6 +23,12 @@ impl DnsPacketData for u16 { check_enough_data!(data, size_of::(), "u16"); Ok(data.get_u16::()) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(size_of::()); + packet.put_u16::(*self); + Ok(()) + } } impl DnsPacketData for u32 { @@ -24,6 +36,12 @@ impl DnsPacketData for u32 { check_enough_data!(data, size_of::(), "u32"); Ok(data.get_u32::()) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + packet.reserve(size_of::()); + packet.put_u32::(*self); + Ok(()) + } } impl DnsPacketData for Ipv4Addr { @@ -31,6 +49,14 @@ impl DnsPacketData for Ipv4Addr { check_enough_data!(data, size_of::(), "Ipv4Addr"); Ok(data.get_u32::().into()) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + let data = self.octets(); + debug_assert!(data.len() == 4); + packet.reserve(data.len()); + packet.put_slice(&data); + Ok(()) + } } impl DnsPacketData for Ipv6Addr { @@ -40,6 +66,14 @@ impl DnsPacketData for Ipv6Addr { data.read_exact(&mut buf)?; Ok(buf.into()) } + + fn serialize(&self, _context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + let data = self.octets(); + debug_assert!(data.len() == 16); + packet.reserve(data.len()); + packet.put_slice(&data); + Ok(()) + } } #[cfg(test)] diff --git a/lib/dnsbox-base/src/ser/packet/write.rs b/lib/dnsbox-base/src/ser/packet/write.rs new file mode 100644 index 0000000..638fbaf --- /dev/null +++ b/lib/dnsbox-base/src/ser/packet/write.rs @@ -0,0 +1,184 @@ +use bytes::{BufMut, BigEndian}; +use errors::*; +use common_types::name::{DnsName, DnsCompressedName, DnsLabelRef}; + +// only points to uncompressed labels; if a label of a name is stored, +// all following labels must be stored too, even if their pos >= 0x4000. +// +// the entries are ordered by pos. +#[derive(Clone, Copy, Debug, Default)] +struct LabelEntry { + pos: usize, // serialized at position in packet + next_entry: usize, // offset in labels vector; points to itself for TLD +} + +impl LabelEntry { + fn label_ref<'a>(&self, packet: &'a Vec) -> DnsLabelRef<'a> { + let p = self.pos as usize; + let len = packet[p] as usize; + DnsLabelRef::new(&packet[p+1..][..len]).unwrap() + } + + fn next(&self, labels: &Vec) -> Option { + let next = labels[self.next_entry as usize]; + if next.pos == self.pos { + None + } else { + Some(next) + } + } + + fn matches(&self, packet: &Vec, labels: &Vec, name: &DnsName, min: u8) -> Option { + 'outer: for i in 0..min { + if name.label_ref(i) != self.label_ref(packet) { + continue; + } + let mut l = *self; + for j in i+1..name.label_count() { + l = match l.next(labels) { + None => continue 'outer, + Some(l) => l, + }; + if name.label_ref(j) != l.label_ref(packet) { + continue 'outer; + } + } + match l.next(labels) { + None => return Some(i), + Some(_) => (), + }; + } + None + } +} + +fn write_label(packet: &mut Vec, label: DnsLabelRef) { + let l = label.len(); + debug_assert!(l < 64); + packet.reserve(l as usize + 1); + packet.put_u8(l); + packet.put_slice(label.as_raw()); +} + +fn write_name(packet: &mut Vec, name: &DnsName) { + for label in name { + write_label(packet, label); + } + packet.reserve(1); + packet.put_u8(0); +} + +fn write_label_remember(packet: &mut Vec, labels: &mut Vec, label: DnsLabelRef, next_entry: usize) { + labels.push(LabelEntry { + pos: packet.len(), + next_entry: next_entry, + }); + write_label(packet, label); +} + +#[derive(Clone, Debug, Default)] +pub struct DnsPacketWriteContext { + labels: Option>, +} + +impl DnsPacketWriteContext { + pub fn new() -> Self { + Default::default() + } + + pub fn enable_compression(&mut self) { + self.labels = Some(Vec::new()); + } + + pub fn write_uncompressed_name(&mut self, packet: &mut Vec, name: &DnsName) -> Result<()> { + // for now we don't remember labels of these names. + // + // if we did: would we want to check whether a suffix is already + // known before we store a new variant? the list could grow big + // with duplicates... + write_name(packet, name); + Ok(()) + } + + pub fn write_compressed_name(&mut self, packet: &mut Vec, name: &DnsCompressedName) -> Result<()> { + if name.is_root() { + write_name(packet, name); + return Ok(()); + } + + let labels = match self.labels { + Some(ref mut labels) => labels, + None => { + // compression disabled + write_name(packet, name); + return Ok(()); + } + }; + + let mut best_match = None; + let mut best_match_len = name.label_count(); + + for (e_ndx, e) in (labels as &Vec).into_iter().enumerate() { + if e.pos >= 0x4000 { break; } // this and following labels can't be used for compression + if let Some(l) = e.matches(packet, labels, name, best_match_len) { + debug_assert!(l < best_match_len); + best_match_len = l; + best_match = Some(e_ndx); + if best_match_len == 0 { + // can't improve + break; + } + } + } + + match best_match { + Some(e_ndx) => { + // found compressable suffix + if best_match_len > 0 { + // but not for complete name, need to write some labels + if packet.len() < 0x4000 { + // remember labels + for i in 0..best_match_len - 1 { + let n = labels.len() + 1; // next label follows directly + write_label_remember(packet, labels, name.label_ref(i), n); + } + // the next label following is at e_ndx + write_label_remember(packet, labels, name.label_ref(best_match_len-1), e_ndx); + } else { + // no need to remember, can't be used for compression + for i in 0..best_match_len { + write_label(packet, name.label_ref(i)); + } + } + } + let p = labels[e_ndx].pos; + debug_assert!(p < 0x4000); + packet.reserve(2); + packet.put_u16::(0xc000 | p as u16); + }, + None => { + // no suffix written already + debug_assert!(best_match_len > 0); + debug_assert!(best_match_len == name.label_count()); + if packet.len() < 0x4000 { + // remember all labels for the name + for i in 0..best_match_len - 1 { + let n = labels.len() + 1; // next label follows directly + write_label_remember(packet, labels, name.label_ref(i), n); + } + // the next label is the TLD + let n = labels.len(); // point to itself + write_label_remember(packet, labels, name.label_ref(best_match_len-1), n); + // terminate name + packet.reserve(1); + packet.put_u8(0); + } else { + // no need to remember, can't be used for compression + write_name(packet, name); + } + } + } + + Ok(()) + } +} \ No newline at end of file diff --git a/lib/dnsbox-base/src/ser/rrdata.rs b/lib/dnsbox-base/src/ser/rrdata.rs index 7bf9482..2aeeda6 100644 --- a/lib/dnsbox-base/src/ser/rrdata.rs +++ b/lib/dnsbox-base/src/ser/rrdata.rs @@ -1,11 +1,12 @@ use bytes::Bytes; use common_types::{Class, Type, classes}; use errors::*; -use ser::DnsPacketData; +use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use ser::text::{DnsTextData, DnsTextFormatter, DnsTextContext}; use std::borrow::Cow; use std::fmt; use std::io::Cursor; +use records::UnknownRecord; pub trait RRDataPacket { fn deserialize_rr_data(ttl: u32, rr_class: Class, rr_type: Type, data: &mut Cursor) -> Result @@ -14,6 +15,8 @@ pub trait RRDataPacket { ; fn rr_type(&self) -> Type; + + fn serialize_rr_data(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()>; } impl RRDataPacket for T { @@ -28,6 +31,10 @@ impl RRDataPacket for T { fn rr_type(&self) -> Type { T::TYPE } + + fn serialize_rr_data(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { + self.serialize(context, packet) + } } pub trait RRDataText { @@ -42,7 +49,10 @@ pub trait RRDataText { fn rr_type_txt(&self) -> Cow<'static, str>; // (type, rrdata) - fn text(&self) -> Result<(String, String)> { + fn text(&self) -> Result<(String, String)> + where + Self: RRDataPacket, + { let mut buf = String::new(); match self.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf)) { Ok(()) => { @@ -50,8 +60,13 @@ pub trait RRDataText { }, Err(_) => (), } + let mut raw = Vec::new(); + self.serialize_rr_data(&mut DnsPacketWriteContext::new(), &mut raw)?; + let ur = UnknownRecord::new(self.rr_type(), raw.into()); + // formatting UnknownRecord should not fail buf.clear(); - unimplemented!() + self.dns_format_rr_data(&mut DnsTextFormatter::new(&mut buf)).unwrap(); + Ok((ur.rr_type_txt().into(), buf)) } } diff --git a/lib/dnsbox-derive/src/dns_packet_data.rs b/lib/dnsbox-derive/src/dns_packet_data.rs index 2032ff0..8a63e28 100644 --- a/lib/dnsbox-derive/src/dns_packet_data.rs +++ b/lib/dnsbox-derive/src/dns_packet_data.rs @@ -23,23 +23,37 @@ pub fn build(ast: &syn::DeriveInput) -> quote::Tokens { let name = &ast.ident; let mut parse_fields = quote!{}; + let mut serialize_fields = quote!{}; for field in fields { let field_name = field.ident.as_ref().unwrap(); - let ctx_msg = format!("failed parsing field {}::{}", name, field_name); + let parse_ctx_msg = format!("failed parsing field {}::{}", name, field_name); + let serialize_ctx_msg = format!("failed serializing field {}::{}", name, field_name); parse_fields = quote!{#parse_fields - #field_name: DnsPacketData::deserialize(_data).context(#ctx_msg)?, + #field_name: DnsPacketData::deserialize(_data).context(#parse_ctx_msg)?, + }; + + serialize_fields = quote!{#serialize_fields + self.#field_name.serialize(_context, _packet).context(#serialize_ctx_msg)?; }; } quote!{ - impl ::dnsbox_base::ser::DnsPacketData for #name { + impl ::dnsbox_base::ser::packet::DnsPacketData for #name { #[allow(unused_imports)] fn deserialize(_data: &mut ::std::io::Cursor<::dnsbox_base::bytes::Bytes>) -> ::dnsbox_base::errors::Result { use ::dnsbox_base::failure::ResultExt; - use ::dnsbox_base::ser::DnsPacketData; + use ::dnsbox_base::ser::packet::DnsPacketData; Ok(#name{ #parse_fields }) } + + #[allow(unused_imports)] + fn serialize(&self, _context: &mut ::dnsbox_base::ser::packet::DnsPacketWriteContext, _packet: &mut Vec) -> ::dnsbox_base::errors::Result<()> { + use ::dnsbox_base::failure::ResultExt; + use ::dnsbox_base::ser::packet::DnsPacketData; + #serialize_fields + Ok(()) + } } } } diff --git a/lib/dnsbox-derive/src/dns_text_data.rs b/lib/dnsbox-derive/src/dns_text_data.rs index 8212757..83933ac 100644 --- a/lib/dnsbox-derive/src/dns_text_data.rs +++ b/lib/dnsbox-derive/src/dns_text_data.rs @@ -38,15 +38,15 @@ pub fn build(ast: &syn::DeriveInput) -> quote::Tokens { quote!{ #[allow(unused_imports)] - impl ::dnsbox_base::ser::DnsTextData for #name { + impl ::dnsbox_base::ser::text::DnsTextData for #name { fn dns_parse(_context: &::dnsbox_base::ser::text::DnsTextContext, _data: &mut &str) -> ::dnsbox_base::errors::Result { use dnsbox_base::failure::ResultExt; - use dnsbox_base::ser::DnsTextData; + use dnsbox_base::ser::text::DnsTextData; Ok(#name{ #parse_fields }) } fn dns_format(&self, f: &mut ::dnsbox_base::ser::text::DnsTextFormatter) -> ::std::fmt::Result { - use dnsbox_base::ser::DnsTextData; + use dnsbox_base::ser::text::DnsTextData; use std::fmt::{self, Write}; #format_fields Ok(())