use bytes::{BufMut, BigEndian}; use errors::*; use common_types::name::{DnsName, DnsCompressedName, DnsLabelRef}; // only points to uncompressed labels; if a label of a name is stored, // all following labels must be stored too, even if their pos >= 0x4000. // // the entries are ordered by pos. #[derive(Clone, Copy, Debug, Default)] struct LabelEntry { pos: usize, // serialized at position in packet next_entry: usize, // offset in labels vector; points to itself for TLD } impl LabelEntry { fn label_ref<'a>(&self, packet: &'a Vec) -> DnsLabelRef<'a> { let p = self.pos as usize; let len = packet[p] as usize; DnsLabelRef::new(&packet[p+1..][..len]).unwrap() } fn next(&self, labels: &Vec) -> Option { let next = labels[self.next_entry as usize]; if next.pos == self.pos { None } else { Some(next) } } fn matches(&self, packet: &Vec, labels: &Vec, name: &DnsName, min: u8) -> Option { 'outer: for i in 0..min { if name.label_ref(i) != self.label_ref(packet) { continue; } let mut l = *self; for j in i+1..name.label_count() { l = match l.next(labels) { None => continue 'outer, Some(l) => l, }; if name.label_ref(j) != l.label_ref(packet) { continue 'outer; } } match l.next(labels) { None => return Some(i), Some(_) => (), }; } None } } fn write_label(packet: &mut Vec, label: DnsLabelRef) { let l = label.len(); debug_assert!(l < 64); packet.reserve(l as usize + 1); packet.put_u8(l); packet.put_slice(label.as_raw()); } fn write_name(packet: &mut Vec, name: &DnsName) { for label in name { write_label(packet, label); } packet.reserve(1); packet.put_u8(0); } fn write_label_remember(packet: &mut Vec, labels: &mut Vec, label: DnsLabelRef, next_entry: usize) { labels.push(LabelEntry { pos: packet.len(), next_entry: next_entry, }); write_label(packet, label); } #[derive(Clone, Debug, Default)] pub struct DnsPacketWriteContext { labels: Option>, } impl DnsPacketWriteContext { pub fn new() -> Self { Default::default() } pub fn enable_compression(&mut self) { self.labels = Some(Vec::new()); } pub fn write_uncompressed_name(&mut self, packet: &mut Vec, name: &DnsName) -> Result<()> { // for now we don't remember labels of these names. // // if we did: would we want to check whether a suffix is already // known before we store a new variant? the list could grow big // with duplicates... write_name(packet, name); Ok(()) } pub fn write_compressed_name(&mut self, packet: &mut Vec, name: &DnsCompressedName) -> Result<()> { if name.is_root() { write_name(packet, name); return Ok(()); } let labels = match self.labels { Some(ref mut labels) => labels, None => { // compression disabled write_name(packet, name); return Ok(()); } }; let mut best_match = None; let mut best_match_len = name.label_count(); for (e_ndx, e) in (labels as &Vec).into_iter().enumerate() { if e.pos >= 0x4000 { break; } // this and following labels can't be used for compression if let Some(l) = e.matches(packet, labels, name, best_match_len) { debug_assert!(l < best_match_len); best_match_len = l; best_match = Some(e_ndx); if best_match_len == 0 { // can't improve break; } } } match best_match { Some(e_ndx) => { // found compressable suffix if best_match_len > 0 { // but not for complete name, need to write some labels if packet.len() < 0x4000 { // remember labels for i in 0..best_match_len - 1 { let n = labels.len() + 1; // next label follows directly write_label_remember(packet, labels, name.label_ref(i), n); } // the next label following is at e_ndx write_label_remember(packet, labels, name.label_ref(best_match_len-1), e_ndx); } else { // no need to remember, can't be used for compression for i in 0..best_match_len { write_label(packet, name.label_ref(i)); } } } let p = labels[e_ndx].pos; debug_assert!(p < 0x4000); packet.reserve(2); packet.put_u16::(0xc000 | p as u16); }, None => { // no suffix written already debug_assert!(best_match_len > 0); debug_assert!(best_match_len == name.label_count()); if packet.len() < 0x4000 { // remember all labels for the name for i in 0..best_match_len - 1 { let n = labels.len() + 1; // next label follows directly write_label_remember(packet, labels, name.label_ref(i), n); } // the next label is the TLD let n = labels.len(); // point to itself write_label_remember(packet, labels, name.label_ref(best_match_len-1), n); // terminate name packet.reserve(1); packet.put_u8(0); } else { // no need to remember, can't be used for compression write_name(packet, name); } } } Ok(()) } }