use byteorder::ByteOrder; use bytes::{Bytes, Buf, BufMut, BigEndian}; use common_types::{Type, Class, DnsCompressedName, types}; use errors::*; use ser::RRData; use ser::packet::{DnsPacketData, DnsPacketWriteContext}; use std::io::Cursor; use records::registry::deserialize_rr_data; pub mod opt; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub enum QueryResponse { Query, Response, } impl Default for QueryResponse { fn default() -> Self { QueryResponse::Query } } #[derive(Clone, Copy, PartialEq, Eq, Debug, Default)] pub struct DnsHeaderFlags { pub qr: QueryResponse, pub opcode: u8, // 0...15 pub authoritative_answer: bool, pub truncation: bool, pub recursion_desired: bool, pub recursion_available: bool, pub reserved_bit9: bool, pub authentic_data: bool, pub checking_disabled: bool, pub rcode: u8, // 0...15 } impl DnsPacketData for DnsHeaderFlags { fn deserialize(data: &mut Cursor) -> Result { let raw = u16::deserialize(data)?; let qr = if 0 == raw & 0x8000 { QueryResponse::Query } else { QueryResponse::Response }; let opcode = 0xf & (raw >> 11) as u8; let authoritative_answer = 0 != raw & 0x0400; let truncation = 0 != raw & 0x0200; let recursion_desired = 0 != raw & 0x0100; let recursion_available = 0 != raw & 0x0080; let reserved_bit9 = 0 != raw & 0x0040; let authentic_data = 0 != raw & 0x0020; let checking_disabled = 0 != raw & 0x0010; let rcode = 0xf & raw as u8; Ok(DnsHeaderFlags{ qr, opcode, authoritative_answer, truncation, recursion_desired, recursion_available, reserved_bit9, authentic_data, checking_disabled, rcode, }) } fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { let flags: u16 = 0 | match self.qr { QueryResponse::Query => 0, QueryResponse::Response => 1, } | (((0xf & self.opcode) as u16) << 11) | if self.authoritative_answer { 0x0400 } else { 0 } | if self.truncation { 0x0200 } else { 0 } | if self.recursion_desired { 0x0100 } else { 0 } | if self.recursion_available { 0x0080 } else { 0 } | if self.reserved_bit9 { 0x0040 } else { 0 } | if self.authentic_data { 0x0020 } else { 0 } | if self.checking_disabled { 0x0010 } else { 0 } | (0xf & self.rcode) as u16 ; flags.serialize(context, packet) } } #[derive(Clone, PartialEq, Eq, Debug, DnsPacketData)] pub struct DnsHeader { pub id: u16, pub flags: DnsHeaderFlags, pub qdcount: u16, pub ancount: u16, pub nscount: u16, pub arcount: u16, } #[derive(Clone, PartialEq, Eq, Debug, DnsPacketData)] pub struct Question { pub qname: DnsCompressedName, pub qtype: Type, pub qclass: Class, } #[derive(Clone, Debug)] pub struct Resource { pub name: DnsCompressedName, pub class: Class, pub ttl: u32, pub data: Box, } impl DnsPacketData for Resource { fn deserialize(data: &mut Cursor) -> Result { let name = DnsCompressedName::deserialize(data)?; let rr_type = Type::deserialize(data)?; let class = Class::deserialize(data)?; let ttl = u32::deserialize(data)?; let rdlength = u16::deserialize(data)? as usize; check_enough_data!(data, rdlength, "RDATA"); let pos = data.position() as usize; let rrdata_from0 = data.get_ref().slice(0, pos + rdlength); data.advance(rdlength); let mut rrdata = Cursor::new(rrdata_from0); rrdata.advance(pos); let rd = deserialize_rr_data(ttl, class, rr_type, &mut rrdata)?; ensure!(!rrdata.has_remaining(), "data remaining: {} bytes", rrdata.remaining()); Ok(Resource{ name, class, ttl, data: rd, }) } fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { self.name.serialize(context, packet)?; let rrtype = self.data.rr_type(); rrtype.serialize(context, packet)?; self.class.serialize(context, packet)?; self.ttl.serialize(context, packet)?; let rdlen_pos = packet.len(); packet.reserve(2); packet.put_u16::(0); // stub let rd_start = packet.len(); self.data.serialize_rr_data(context, packet)?; let rd_end = packet.len(); let rdlen = rd_end - rd_start; ensure!(rdlen < 0x1_0000, "RDATA too big"); // now patch length BigEndian::write_u16(&mut packet[rdlen_pos..][..2], rdlen as u16); Ok(()) } } #[derive(Clone, Debug)] pub struct DnsPacket { pub id: u16, pub flags: DnsHeaderFlags, pub question: Vec, pub answer: Vec, pub authority: Vec, pub additional: Vec, pub opt: Option>, } impl DnsPacket { /// also serializes OPT before conversion if present pub fn to_bytes(&mut self) -> Result> { if self.opt.is_some() { // delete other OPTs, so only call it if there is a "new" OPT self.serialize_opt()?; } let mut buf = Vec::new(); let mut ctx = DnsPacketWriteContext::new(); ctx.enable_compression(); self.serialize(&mut ctx, &mut buf)?; Ok(buf) } /// overwrites existing OPT records, and serializes a new one (if /// self.opt is not None) pub fn serialize_opt(&mut self) -> Result<()> { // delete all additional OPT records self.additional.retain(|r| r.data.rr_type() != types::OPT); match self.opt.take() { Some(Err(e)) => bail!("can't serialize broken OPT: {:?}", e), Some(Ok(opt)) => { self.additional.push(opt.serialize()?); }, None => (), } Ok(()) } } impl Default for DnsPacket { fn default() -> Self { DnsPacket{ id: 0, flags: DnsHeaderFlags::default(), question: Vec::new(), answer: Vec::new(), authority: Vec::new(), additional: Vec::new(), opt: None, } } } impl DnsPacketData for DnsPacket { fn deserialize(data: &mut Cursor) -> Result { let header = DnsHeader::deserialize(data)?; let mut p = DnsPacket { id: header.id, flags: header.flags, question: (0..header.qdcount).map(|_| Question::deserialize(data)).collect::>>()?, answer: (0..header.ancount).map(|_| Resource::deserialize(data)).collect::>>()?, authority: (0..header.nscount).map(|_| Resource::deserialize(data)).collect::>>()?, additional: (0..header.arcount).map(|_| Resource::deserialize(data)).collect::>>()?, opt: None, }; let mut opt_resource_ndx = None; for (i, r) in p.additional.iter().enumerate() { if r.data.rr_type() == types::OPT { ensure!(opt_resource_ndx.is_none(), "multiple OPT resource records"); opt_resource_ndx = Some(i); } } if let Some(ndx) = opt_resource_ndx { let opt_rr = p.additional.remove(ndx); p.opt = Some(opt::Opt::deserialize(&opt_rr)?); } Ok(p) } fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec) -> Result<()> { ensure!(self.question.len() < 0x1_0000, "too many question entries"); ensure!(self.answer.len() < 0x1_0000, "too many answer entries"); ensure!(self.authority.len() < 0x1_0000, "too many authority entries"); ensure!(self.additional.len() < 0x1_0000, "too many additional entries"); let header = DnsHeader{ id: self.id, flags: self.flags, qdcount: self.question.len() as u16, ancount: self.answer.len() as u16, nscount: self.authority.len() as u16, arcount: self.additional.len() as u16, }; header.serialize(context, packet)?; for r in &self.question { r.serialize(context, packet)?; } for r in &self.answer { r.serialize(context, packet)?; } for r in &self.authority { r.serialize(context, packet)?; } for r in &self.additional { r.serialize(context, packet)?; } Ok(()) } }