root/trunk/pdns/pdns/dnsparser.hh @ 498

Revision 498, 5.0 KB (checked in by ahu, 8 years ago)

add shims so the packetreader is symmetrical with the dnspacketwriter
make dnsparser do slightly more bounds checking

Line 
1#ifndef DNSPARSER_HH
2#define DNSPARSER_HH
3
4#include <map>
5#include <sstream>
6#include <stdexcept>
7#include <pcap.h>
8#include <iostream>
9#include <vector>
10#include <errno.h>
11#include <netinet/in.h>
12#include <arpa/nameser.h>
13#include "misc.hh"
14#include <boost/shared_ptr.hpp>
15#include <boost/lexical_cast.hpp>
16#include <boost/tuple/tuple.hpp>
17#include <boost/tuple/tuple_comparison.hpp>
18
19/** DNS records have three representations:
20    1) in the packet
21    2) parsed in a class, ready for use
22    3) in the zone
23
24    We should implement bidirectional transitions between 1&2 and 2&3.
25    Currently we have: 1 -> 2
26                       2 -> 3
27
28    We can add:        2 -> 1  easily by reversing the packetwriter
29    And we might be able to reverse 2 -> 3 as well
30*/
31   
32
// Shorthand used throughout the parser for the DNS message header type.
// NOTE(review): HEADER is not declared in this file; presumably it is the
// resolver header struct supplied by <arpa/nameser.h> - confirm.
namespace
{
typedef HEADER dnsheader;
} // unnamed namespace
36
37using namespace std;
38using namespace boost;
39typedef runtime_error MOADNSException;
40
// Fixed-size per-record fields. The packed attribute removes all padding, so
// the four members sit back to back and the struct occupies exactly 10 bytes.
// NOTE(review): presumably mirrors the type/class/TTL/length fields that
// precede record data in a DNS packet - confirm against getDnsrecordheader().
struct __attribute__((packed)) dnsrecordheader
{
  uint16_t d_type;   // record type code
  uint16_t d_class;  // record class code
  uint32_t d_ttl;    // time-to-live value
  uint16_t d_clen;   // NOTE(review): looks like the content length - confirm
};
48
49
class MOADNSParser;

/** Cursor over a byte vector holding packet data.

    The reader keeps only a reference to the vector it is given, so that
    vector must stay alive (and must not be a temporary) for as long as the
    reader is used. The current offset is the public member d_pos.

    Only the inline shims are defined in this header; every other member is
    declared here and defined elsewhere.
*/
class PacketReader
{
public:
  /** Attaches the reader to content and places the cursor at offset 0.
      No copy of content is made. */
  PacketReader(const std::vector<uint8_t>& content)
    : d_pos(0), d_content(content)
  {}

  uint32_t get32BitInt();
  uint16_t get16BitInt();
  uint8_t get8BitInt();

  // xfr* shims: each stores the result of the matching get* call into its
  // argument. Per the check-in note they exist so that reading code can be
  // written against the same call names the packet writer offers.
  void xfr32BitInt(uint32_t& val)
  {
    val=get32BitInt();
  }

  void xfr16BitInt(uint16_t& val)
  {
    val=get16BitInt();
  }

  // 8 bit counterpart, so that all three integer getters have a shim.
  void xfr8BitInt(uint8_t& val)
  {
    val=get8BitInt();
  }

  void xfrLabel(std::string &label)
  {
    label=getLabel();
  }

  void xfrText(std::string &text)
  {
    text=getText();
  }

  // Static helpers that work on a caller-supplied buffer and position
  // instead of the reader's own state.
  static uint16_t get16BitInt(const std::vector<unsigned char>&content, uint16_t& pos);
  static void getLabelFromContent(const std::vector<uint8_t>& content, uint16_t& frompos, std::string& ret, int recurs);

  void getDnsrecordheader(struct dnsrecordheader &ah);
  void copyRecord(std::vector<unsigned char>& dest, uint16_t len);
  void copyRecord(unsigned char* dest, uint16_t len);

  std::string getLabel(unsigned int recurs=0);
  std::string getText();

  uint16_t d_pos;  // current offset into d_content; public, callers may reposition it

private:
  const std::vector<uint8_t>& d_content;  // borrowed, never owned
};
98
99class DNSRecord;
100
101class DNSRecordContent
102{
103public:
104  static DNSRecordContent* mastermake(const DNSRecord &dr, PacketReader& pr);
105
106  virtual std::string getZoneRepresentation() const = 0;
107  virtual ~DNSRecordContent() {}
108
109  std::string label;
110  struct dnsrecordheader header;
111
112  typedef DNSRecordContent* makerfunc_t(const struct DNSRecord& dr, PacketReader& pr); 
113  static void regist(uint16_t cl, uint16_t ty, makerfunc_t* f, const char* name)
114  {
115    typemap[make_pair(cl,ty)]=f;
116    namemap[make_pair(cl,ty)]=name;
117  }
118
119  static uint16_t TypeToNumber(const string& name)
120  {
121    for(namemap_t::const_iterator i=namemap.begin(); i!=namemap.end();++i)
122      if(!strcasecmp(i->second.c_str(), name.c_str()))
123        return i->first.second;
124
125    throw runtime_error("Unknown DNS type '"+name+"'");
126  }
127
128  static const string NumberToType(uint16_t num)
129  {
130    if(!namemap.count(make_pair(1,num)))
131      return "#" + lexical_cast<string>(num);
132      //      throw runtime_error("Unknown DNS type with numerical id "+lexical_cast<string>(num));
133    return namemap[make_pair(1,num)];
134  }
135
136protected:
137
138  typedef std::map<std::pair<uint16_t, uint16_t>, makerfunc_t* > typemap_t;
139  static typemap_t typemap;
140  typedef std::map<std::pair<uint16_t, uint16_t>, string > namemap_t;
141  static namemap_t namemap;
142};
143
144struct DNSRecord
145{
146  std::string d_label;
147  uint16_t d_type;
148  uint16_t d_class;
149  uint32_t d_ttl;
150  uint16_t d_clen;
151  enum {Answer, Nameserver, Additional} d_place;
152  boost::shared_ptr<DNSRecordContent> d_content;
153
154  bool operator<(const DNSRecord& rhs) const
155  {
156    string lzrp, rzrp;
157    if(d_content)
158      lzrp=toLower(d_content->getZoneRepresentation());
159    if(rhs.d_content)
160      rzrp=toLower(rhs.d_content->getZoneRepresentation());
161   
162    string llabel=toLower(d_label);
163    string rlabel=toLower(rhs.d_label);
164
165    return 
166      tie(llabel,     d_type,     d_class, lzrp) <
167      tie(rlabel, rhs.d_type, rhs.d_class, rzrp);
168  }
169
170  bool operator==(const DNSRecord& rhs) const
171  {
172    string lzrp, rzrp;
173    if(d_content)
174      lzrp=toLower(d_content->getZoneRepresentation());
175    if(rhs.d_content)
176      rzrp=toLower(rhs.d_content->getZoneRepresentation());
177   
178    string llabel=toLower(d_label);
179    string rlabel=toLower(rhs.d_label);
180   
181    return 
182      tie(llabel,     d_type,     d_class, lzrp) ==
183      tie(rlabel, rhs.d_type, rhs.d_class, rzrp);
184  }
185};
186
187
188class MOADNSParser
189{
190public:
191  MOADNSParser(const string& buffer) 
192  {
193    init(buffer.c_str(), buffer.size());
194  }
195
196  MOADNSParser(const char *packet, unsigned int len)
197  {
198    init(packet, len);
199  }
200  dnsheader d_header;
201  string d_qname;
202  uint16_t d_qclass, d_qtype;
203  uint8_t d_rcode;
204
205  typedef vector<pair<DNSRecord, uint16_t > > answers_t;
206  answers_t d_answers;
207
208  shared_ptr<PacketReader> getPacketReader(uint16_t offset)
209  {
210    shared_ptr<PacketReader> pr(new PacketReader(d_content));
211    pr->d_pos=offset;
212    return pr;
213  }
214private:
215  void getDnsrecordheader(struct dnsrecordheader &ah);
216  void init(const char *packet, unsigned int len);
217  vector<uint8_t> d_content;
218};
219
220
221
222
223#endif
Note: See TracBrowser for help on using the browser.