// rust-dnsbox/lib/dnsbox-base/src/ser/packet/write.rs
//
// (listing metadata: 184 lines, 4.9 KiB, Rust)

use bytes::{BufMut, BigEndian};
use errors::*;
use common_types::name::{DnsName, DnsCompressedName, DnsLabelRef};
/// Bookkeeping for one label that was already written to the packet; used to
/// find suffixes that later names can replace by a compression pointer.
///
/// Only uncompressed labels are recorded. Once a label of a name is recorded,
/// all labels following it in that name are recorded too, even if their
/// position is >= 0x4000 (too large to be a pointer target), so that every
/// recorded chain can be followed to its TLD.
///
/// Entries are kept ordered by `pos`.
#[derive(Debug, Default, Clone, Copy)]
struct LabelEntry {
    /// Position in the packet at which the label (its length octet) was written.
    pos: usize,
    /// Index in the labels vector of the entry for the next label of the
    /// same name; the entry of a TLD refers to itself.
    next_entry: usize,
}
impl LabelEntry {
    /// Returns the label stored at `self.pos` in `packet`.
    ///
    /// The entry points at a length octet followed by that many label bytes,
    /// exactly as `write_label` emitted them.
    fn label_ref<'a>(&self, packet: &'a [u8]) -> DnsLabelRef<'a> {
        let p = self.pos;
        let len = packet[p] as usize;
        DnsLabelRef::new(&packet[p + 1..][..len])
            .expect("remembered label was written by write_label and must be valid")
    }

    /// Returns the entry for the label following this one in its name, or
    /// `None` if this entry is the last (TLD) label of the chain.
    ///
    /// The end of a chain is marked by a `next_entry` referring to an entry
    /// at the same position, i.e. to itself.
    fn next(&self, labels: &[LabelEntry]) -> Option<Self> {
        let next = labels[self.next_entry];
        if next.pos == self.pos {
            None
        } else {
            Some(next)
        }
    }

    /// Checks whether the label chain starting at this entry spells the
    /// suffix of `name` that begins at some label index `i < min`.
    ///
    /// Returns the smallest such `i` (the longest matching suffix), or `None`.
    /// The stored chain must end (reach its TLD) exactly where `name` ends.
    fn matches(&self, packet: &[u8], labels: &[LabelEntry], name: &DnsName, min: u8) -> Option<u8> {
        // the first stored label is the same for every candidate start index
        let first = self.label_ref(packet);
        'outer: for i in 0..min {
            if name.label_ref(i) != first {
                continue;
            }
            let mut l = *self;
            for j in i + 1..name.label_count() {
                l = match l.next(labels) {
                    // stored chain is shorter than the remaining name
                    None => continue 'outer,
                    Some(l) => l,
                };
                if name.label_ref(j) != l.label_ref(packet) {
                    continue 'outer;
                }
            }
            // the stored chain must not continue past the end of the name
            if l.next(labels).is_none() {
                return Some(i);
            }
        }
        None
    }
}
/// Appends one label in wire format (a length octet followed by the raw
/// label bytes) without any compression.
fn write_label(packet: &mut Vec<u8>, label: DnsLabelRef) {
    let label_len = label.len();
    // the length must fit into 6 bits; the top two bits of the octet are
    // used to mark compression pointers (0xc0)
    debug_assert!(label_len < 64);
    let needed = 1 + label_len as usize;
    packet.reserve(needed);
    packet.put_u8(label_len);
    packet.put_slice(label.as_raw());
}
/// Writes `name` uncompressed: all of its labels in order, followed by the
/// terminating zero-length (root) label.
fn write_name(packet: &mut Vec<u8>, name: &DnsName) {
    name.into_iter().for_each(|label| write_label(packet, label));
    // terminating root label
    packet.reserve(1);
    packet.put_u8(0);
}
/// Records a `LabelEntry` for `label` at the current end of `packet`, linked
/// to `labels[next_entry]`, and then writes the label itself.
fn write_label_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, label: DnsLabelRef, next_entry: usize) {
    let entry = LabelEntry {
        next_entry: next_entry,
        pos: packet.len(),
    };
    labels.push(entry);
    write_label(packet, label);
}
/// State used while writing names into a packet.
///
/// Name compression is disabled by default (`labels` is `None`); once
/// enabled, the remembered labels hold positions in the packet being written,
/// so a context should be used with a single packet only.
#[derive(Debug, Default, Clone)]
pub struct DnsPacketWriteContext {
    /// Labels available as compression targets; `None` while compression is
    /// disabled.
    labels: Option<Vec<LabelEntry>>,
}
/// Compression pointers carry a 14-bit offset (the top two bits of the 16-bit
/// pointer are set to mark it as a pointer), so only labels starting below
/// this packet position can be pointed to.
const COMPRESSION_POINTER_LIMIT: usize = 0x4000;

/// Writes the first `count` labels of `name` (`count` must be > 0) and
/// remembers each of them in `labels` for later compression.
///
/// Every new entry links to the entry pushed directly after it. The entry of
/// the last written label links to `labels[last_next]` when `last_next` is
/// `Some` (an already known suffix that a pointer will reference), or to
/// itself when it is `None` (the label is the TLD of `name`).
fn write_labels_remember(packet: &mut Vec<u8>, labels: &mut Vec<LabelEntry>, name: &DnsName, count: u8, last_next: Option<usize>) {
    debug_assert!(count > 0);
    for i in 0..count - 1 {
        let n = labels.len() + 1; // next label follows directly
        write_label_remember(packet, labels, name.label_ref(i), n);
    }
    // `labels.len()` is the index the last entry is about to get: an entry
    // referring to itself marks the end of the chain
    let n = last_next.unwrap_or(labels.len());
    write_label_remember(packet, labels, name.label_ref(count - 1), n);
}

impl DnsPacketWriteContext {
    /// Creates a new context; name compression is disabled until
    /// `enable_compression` is called.
    pub fn new() -> Self {
        Default::default()
    }

    /// Enables name compression for subsequent `write_compressed_name` calls.
    ///
    /// Resets the list of remembered labels, so names written before this
    /// call are never used as compression targets.
    pub fn enable_compression(&mut self) {
        self.labels = Some(Vec::new());
    }

    /// Writes `name` without compression and without remembering its labels.
    ///
    /// Currently never returns an error.
    pub fn write_uncompressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsName) -> Result<()> {
        // for now we don't remember labels of these names.
        //
        // if we did: would we want to check whether a suffix is already
        // known before we store a new variant? the list could grow big
        // with duplicates...
        write_name(packet, name);
        Ok(())
    }

    /// Writes `name`, replacing its longest already-written suffix with a
    /// compression pointer when compression is enabled.
    ///
    /// The root name, and any name written while compression is disabled, is
    /// written uncompressed. If the newly written labels start below offset
    /// 0x4000 they are remembered, so later names can point to them.
    ///
    /// NOTE(review): pointer offsets are taken as positions in `packet`, which
    /// assumes `packet` starts at the beginning of the DNS message — confirm
    /// with callers.
    ///
    /// Currently never returns an error.
    pub fn write_compressed_name(&mut self, packet: &mut Vec<u8>, name: &DnsCompressedName) -> Result<()> {
        if name.is_root() {
            write_name(packet, name);
            return Ok(());
        }
        let labels = match self.labels {
            Some(ref mut labels) => labels,
            None => {
                // compression disabled
                write_name(packet, name);
                return Ok(());
            }
        };

        // Find the remembered entry matching the longest suffix of `name`.
        // `best_match_len` is the number of leading labels NOT covered by the
        // match, i.e. the labels that still need to be written.
        let mut best_match = None;
        let mut best_match_len = name.label_count();
        for (e_ndx, e) in labels.iter().enumerate() {
            if e.pos >= COMPRESSION_POINTER_LIMIT {
                // entries are ordered by pos: this and all following labels
                // can't be used for compression
                break;
            }
            if let Some(l) = e.matches(packet, labels, name, best_match_len) {
                debug_assert!(l < best_match_len);
                best_match_len = l;
                best_match = Some(e_ndx);
                if best_match_len == 0 {
                    // whole name already known; can't improve
                    break;
                }
            }
        }

        // new labels are only worth remembering if a pointer can reach them
        let remember = packet.len() < COMPRESSION_POINTER_LIMIT;

        match best_match {
            Some(e_ndx) => {
                // found a compressible suffix; write the leading labels (if
                // any), then a pointer to the suffix
                if best_match_len > 0 {
                    if remember {
                        // the last new label continues at the matched entry
                        write_labels_remember(packet, labels, name, best_match_len, Some(e_ndx));
                    } else {
                        // no need to remember, can't be used for compression
                        for i in 0..best_match_len {
                            write_label(packet, name.label_ref(i));
                        }
                    }
                }
                let p = labels[e_ndx].pos;
                debug_assert!(p < COMPRESSION_POINTER_LIMIT);
                packet.reserve(2);
                packet.put_u16::<BigEndian>(0xc000 | p as u16);
            }
            None => {
                // no suffix written already
                debug_assert!(best_match_len > 0);
                debug_assert!(best_match_len == name.label_count());
                if remember {
                    // remember all labels; the last one (TLD) points to itself
                    write_labels_remember(packet, labels, name, best_match_len, None);
                    // terminate name
                    packet.reserve(1);
                    packet.put_u8(0);
                } else {
                    // no need to remember, can't be used for compression
                    write_name(packet, name);
                }
            }
        }
        Ok(())
    }
}