You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

233 lines
6.4 KiB

use byteorder::ByteOrder;
use bytes::{Bytes, Buf, BufMut, BigEndian};
use common_types::{Type, Class, DnsCompressedName};
use errors::*;
use ser::RRData;
use ser::packet::{DnsPacketData, DnsPacketWriteContext};
use std::io::Cursor;
use records::registry::deserialize_rr_data;
/// Value of the QR bit of a DNS header: whether the message is a query or a response.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum QueryResponse {
/// QR bit clear.
Query,
/// QR bit set.
Response,
}
impl Default for QueryResponse {
    /// A message is treated as a query unless stated otherwise (QR bit clear).
    fn default() -> QueryResponse {
        QueryResponse::Query
    }
}
/// Decoded form of the 16-bit flags word of a DNS message header.
///
/// The derived `Default` is a query (`QueryResponse::Query`) with every flag
/// cleared and `opcode`/`rcode` set to 0.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
pub struct DnsHeaderFlags {
/// QR: query or response (bit 0x8000 of the raw word).
pub qr: QueryResponse,
/// OPCODE (bits 0x7800); only the low 4 bits are written when serializing.
pub opcode: u8, // 0...15
/// AA: authoritative answer (0x0400).
pub authoritative_answer: bool,
/// TC: truncation (0x0200).
pub truncation: bool,
/// RD: recursion desired (0x0100).
pub recursion_desired: bool,
/// RA: recursion available (0x0080).
pub recursion_available: bool,
/// Reserved (Z) bit, 0x0040 -- bit 9 when the most significant bit is numbered 0.
pub reserved_bit9: bool,
/// AD: authentic data (0x0020).
pub authentic_data: bool,
/// CD: checking disabled (0x0010).
pub checking_disabled: bool,
/// RCODE (bits 0x000f); only the low 4 bits are written when serializing.
pub rcode: u8, // 0...15
}
impl DnsPacketData for DnsHeaderFlags {
    /// Decodes the 16-bit flags word of a DNS header.
    ///
    /// Bit layout of the raw word (most significant first):
    /// QR `0x8000`, OPCODE `0x7800`, AA `0x0400`, TC `0x0200`, RD `0x0100`,
    /// RA `0x0080`, reserved `0x0040`, AD `0x0020`, CD `0x0010`, RCODE `0x000f`.
    ///
    /// # Errors
    /// Propagates the error from reading the underlying `u16`.
    fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
        let raw = u16::deserialize(data)?;
        let qr = if 0 == raw & 0x8000 { QueryResponse::Query } else { QueryResponse::Response };
        let opcode = 0xf & (raw >> 11) as u8;
        let authoritative_answer = 0 != raw & 0x0400;
        let truncation = 0 != raw & 0x0200;
        let recursion_desired = 0 != raw & 0x0100;
        let recursion_available = 0 != raw & 0x0080;
        let reserved_bit9 = 0 != raw & 0x0040;
        let authentic_data = 0 != raw & 0x0020;
        let checking_disabled = 0 != raw & 0x0010;
        let rcode = 0xf & raw as u8;
        Ok(DnsHeaderFlags {
            qr,
            opcode,
            authoritative_answer,
            truncation,
            recursion_desired,
            recursion_available,
            reserved_bit9,
            authentic_data,
            checking_disabled,
            rcode,
        })
    }

    /// Encodes the flags into a 16-bit word using the same bit layout that
    /// `deserialize` reads, so the two round-trip.
    ///
    /// `opcode` and `rcode` are truncated to their low 4 bits.
    ///
    /// # Errors
    /// Propagates the error from writing the underlying `u16`.
    fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
        let flags: u16 = 0
            // QR is the top bit (0x8000). Writing 1 here would instead set the
            // low RCODE bit and never mark the message as a response.
            | match self.qr {
                QueryResponse::Query => 0,
                QueryResponse::Response => 0x8000,
            }
            | (((0xf & self.opcode) as u16) << 11)
            | if self.authoritative_answer { 0x0400 } else { 0 }
            | if self.truncation { 0x0200 } else { 0 }
            | if self.recursion_desired { 0x0100 } else { 0 }
            | if self.recursion_available { 0x0080 } else { 0 }
            | if self.reserved_bit9 { 0x0040 } else { 0 }
            | if self.authentic_data { 0x0020 } else { 0 }
            | if self.checking_disabled { 0x0010 } else { 0 }
            | (0xf & self.rcode) as u16
            ;
        flags.serialize(context, packet)
    }
}
/// The fixed DNS message header: ID, flags word and the four section counts.
///
/// NOTE(review): the `DnsPacketData` derive presumably (de)serializes fields in
/// declaration order, so this order must match the wire format -- confirm before
/// reordering.
#[derive(Clone, PartialEq, Eq, Debug, DnsPacketData)]
pub struct DnsHeader {
/// Message ID.
pub id: u16,
/// Decoded flags word.
pub flags: DnsHeaderFlags,
/// Number of entries in the question section.
pub qdcount: u16,
/// Number of records in the answer section.
pub ancount: u16,
/// Number of records in the authority section.
pub nscount: u16,
/// Number of records in the additional section.
pub arcount: u16,
}
/// One entry of the question section: the queried name, type and class.
///
/// NOTE(review): field order presumably defines the wire order via the
/// `DnsPacketData` derive -- keep it as is.
#[derive(Clone, PartialEq, Eq, Debug, DnsPacketData)]
pub struct Question {
pub qname: DnsCompressedName,
pub qtype: Type,
pub qclass: Class,
}
/// A resource record, as stored in the answer, authority and additional
/// sections of a `DnsPacket`.
///
/// The record TYPE is not stored separately: when serializing it is taken from
/// `data.rr_type()`.
#[derive(Clone, Debug)]
pub struct Resource {
/// Owner name of the record.
pub name: DnsCompressedName,
/// Record class.
pub class: Class,
/// TTL as read from / written to the wire.
pub ttl: u32,
/// Parsed RDATA (produced by `deserialize_rr_data` when reading).
pub data: Box<RRData>,
}
impl DnsPacketData for Resource {
    /// Reads one resource record: NAME, TYPE, CLASS, TTL, RDLENGTH, then exactly
    /// RDLENGTH bytes of RDATA.
    ///
    /// # Errors
    /// Fails if there are fewer than RDLENGTH bytes left, if the RDATA cannot be
    /// parsed, or if the RDATA parser does not consume all RDLENGTH bytes.
    fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
        // Fixed part of the record, read in wire order.
        let name = DnsCompressedName::deserialize(data)?;
        let rr_type = Type::deserialize(data)?;
        let class = Class::deserialize(data)?;
        let ttl = u32::deserialize(data)?;
        let rd_len = u16::deserialize(data)? as usize;
        check_enough_data!(data, rd_len, "RDATA");

        // The RDATA view starts at offset 0 of the whole buffer, is cut off at the
        // end of this record's RDATA, and is positioned at the RDATA start.
        // Keeping the preceding bytes presumably lets compressed names inside
        // RDATA follow pointers to earlier offsets.
        let rd_start = data.position() as usize;
        let mut rd_cursor = Cursor::new(data.get_ref().slice(0, rd_start + rd_len));
        rd_cursor.advance(rd_start);

        // The outer cursor skips the RDATA; it is consumed through `rd_cursor`.
        data.advance(rd_len);

        let parsed = deserialize_rr_data(ttl, class, rr_type, &mut rd_cursor)?;
        ensure!(!rd_cursor.has_remaining(), "data remaining: {} bytes", rd_cursor.remaining());

        Ok(Resource {
            name,
            class,
            ttl,
            data: parsed,
        })
    }

    /// Writes the record, back-patching RDLENGTH once the RDATA has been written.
    ///
    /// # Errors
    /// Fails if any component fails to serialize or the RDATA is 0x10000 bytes
    /// or longer.
    fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
        self.name.serialize(context, packet)?;
        self.data.rr_type().serialize(context, packet)?;
        self.class.serialize(context, packet)?;
        self.ttl.serialize(context, packet)?;

        // Placeholder for RDLENGTH; its real value is known only after RDATA is written.
        let len_offset = packet.len();
        packet.reserve(2);
        packet.put_u16::<BigEndian>(0);

        let data_offset = packet.len();
        self.data.serialize_rr_data(context, packet)?;
        let rd_size = packet.len() - data_offset;
        ensure!(rd_size < 0x1_0000, "RDATA too big");

        BigEndian::write_u16(&mut packet[len_offset..len_offset + 2], rd_size as u16);
        Ok(())
    }
}
/// A complete DNS message.
///
/// There is no stored header: the ID and flags are kept directly, and the
/// QDCOUNT/ANCOUNT/NSCOUNT/ARCOUNT values are computed from the lengths of the
/// section vectors when serializing.
#[derive(Clone, Debug)]
pub struct DnsPacket {
/// Message ID.
pub id: u16,
/// Header flags.
pub flags: DnsHeaderFlags,
/// Question section.
pub question: Vec<Question>,
/// Answer section.
pub answer: Vec<Resource>,
/// Authority section.
pub authority: Vec<Resource>,
/// Additional section.
pub additional: Vec<Resource>,
}
impl DnsPacket {
    /// Serializes the whole packet into a new byte vector, with name compression
    /// enabled on the write context.
    ///
    /// # Errors
    /// Propagates any error from serializing the packet.
    pub fn to_bytes(&self) -> Result<Vec<u8>> {
        let mut context = DnsPacketWriteContext::new();
        context.enable_compression();

        let mut out = Vec::new();
        self.serialize(&mut context, &mut out)?;
        Ok(out)
    }
}
impl Default for DnsPacket {
fn default() -> Self {
DnsPacket{
id: 0,
flags: DnsHeaderFlags::default(),
question: Vec::new(),
answer: Vec::new(),
authority: Vec::new(),
additional: Vec::new(),
}
}
}
impl DnsPacketData for DnsPacket {
fn deserialize(data: &mut Cursor<Bytes>) -> Result<Self> {
let header = DnsHeader::deserialize(data)?;
Ok(DnsPacket {
id: header.id,
flags: header.flags,
question: (0..header.qdcount).map(|_| Question::deserialize(data)).collect::<Result<Vec<_>>>()?,
answer: (0..header.ancount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?,
authority: (0..header.nscount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?,
additional: (0..header.arcount).map(|_| Resource::deserialize(data)).collect::<Result<Vec<_>>>()?,
})
}
fn serialize(&self, context: &mut DnsPacketWriteContext, packet: &mut Vec<u8>) -> Result<()> {
ensure!(self.question.len() < 0x1_0000, "too many question entries");
ensure!(self.answer.len() < 0x1_0000, "too many answer entries");
ensure!(self.authority.len() < 0x1_0000, "too many authority entries");
ensure!(self.additional.len() < 0x1_0000, "too many additional entries");
let header = DnsHeader{
id: self.id,
flags: self.flags,
qdcount: self.question.len() as u16,
ancount: self.answer.len() as u16,
nscount: self.authority.len() as u16,
arcount: self.additional.len() as u16,
};
header.serialize(context, packet)?;
for r in &self.question {
r.serialize(context, packet)?;
}
for r in &self.answer {
r.serialize(context, packet)?;
}
for r in &self.authority {
r.serialize(context, packet)?;
}
for r in &self.additional {
r.serialize(context, packet)?;
}
Ok(())
}
}