Rewritten DHTRoutingTableDeserializer using stdio instead of stream.

This commit is contained in:
Tatsuhiro Tsujikawa 2011-08-05 20:17:19 +09:00
parent 5eb338ad87
commit f141cd4228
4 changed files with 49 additions and 65 deletions

View File

@ -36,7 +36,7 @@
#include <cstring> #include <cstring>
#include <cassert> #include <cassert>
#include <istream> #include <cstdio>
#include <utility> #include <utility>
#include "DHTNode.h" #include "DHTNode.h"
@ -56,27 +56,27 @@ DHTRoutingTableDeserializer::DHTRoutingTableDeserializer(int family):
DHTRoutingTableDeserializer::~DHTRoutingTableDeserializer() {} DHTRoutingTableDeserializer::~DHTRoutingTableDeserializer() {}
namespace { #define FREAD_CHECK(ptr, count, fp) \
void readBytes(unsigned char* buf, size_t buflen, if(fread((ptr), 1, (count), (fp)) != (count)) { \
std::istream& in, size_t readlen)
{
assert(readlen <= buflen);
in.read(reinterpret_cast<char*>(buf), readlen);
}
} // namespace
#define CHECK_STREAM(in, length) \
if(in.gcount() != length) { \
throw DL_ABORT_EX \
(fmt("Failed to load DHT routing table. cause:%s", \
"Unexpected EOF")); \
} \
if(!in) { \
throw DL_ABORT_EX("Failed to load DHT routing table."); \ throw DL_ABORT_EX("Failed to load DHT routing table."); \
} }
void DHTRoutingTableDeserializer::deserialize(std::istream& in) namespace {
void readBytes(unsigned char* buf, size_t buflen,
FILE* fp, size_t readlen)
{ {
assert(readlen <= buflen);
FREAD_CHECK(buf, readlen, fp);
}
} // namespace
void DHTRoutingTableDeserializer::deserialize(const std::string& filename)
{
FILE* fp = a2fopen(utf8ToWChar(filename).c_str(), "rb");
if(!fp) {
throw DL_ABORT_EX("Failed to load DHT routing table.");
}
auto_delete_r<FILE*, int> deleter(fp, fclose);
char header[8]; char header[8];
memset(header, 0, sizeof(header)); memset(header, 0, sizeof(header));
// magic // magic
@ -109,8 +109,7 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
array_wrapper<unsigned char, 255> buf; array_wrapper<unsigned char, 255> buf;
// header // header
readBytes(buf, buf.size(), in, 8); readBytes(buf, buf.size(), fp, 8);
CHECK_STREAM(in, 8);
if(memcmp(header, buf, 8) == 0) { if(memcmp(header, buf, 8) == 0) {
version = 3; version = 3;
} else if(memcmp(headerCompat, buf, 8) == 0) { } else if(memcmp(headerCompat, buf, 8) == 0) {
@ -125,37 +124,29 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
uint64_t temp64; uint64_t temp64;
// time // time
if(version == 2) { if(version == 2) {
in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32)); FREAD_CHECK(&temp32, sizeof(temp32), fp);
CHECK_STREAM(in, sizeof(temp32));
serializedTime_.setTimeInSec(ntohl(temp32)); serializedTime_.setTimeInSec(ntohl(temp32));
// 4bytes reserved // 4bytes reserved
readBytes(buf, buf.size(), in, 4); readBytes(buf, buf.size(), fp, 4);
CHECK_STREAM(in, 4);
} else { } else {
in.read(reinterpret_cast<char*>(&temp64), sizeof(temp64)); FREAD_CHECK(&temp64, sizeof(temp64), fp);
CHECK_STREAM(in, sizeof(temp64));
serializedTime_.setTimeInSec(ntoh64(temp64)); serializedTime_.setTimeInSec(ntoh64(temp64));
} }
// localnode // localnode
// 8bytes reserved // 8bytes reserved
readBytes(buf, buf.size(), in, 8); readBytes(buf, buf.size(), fp, 8);
CHECK_STREAM(in, 8);
// localnode ID // localnode ID
readBytes(buf, buf.size(), in, DHT_ID_LENGTH); readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
CHECK_STREAM(in, DHT_ID_LENGTH);
SharedHandle<DHTNode> localNode(new DHTNode(buf)); SharedHandle<DHTNode> localNode(new DHTNode(buf));
// 4bytes reserved // 4bytes reserved
readBytes(buf, buf.size(), in, 4); readBytes(buf, buf.size(), fp, 4);
CHECK_STREAM(in, 4);
// number of nodes // number of nodes
in.read(reinterpret_cast<char*>(&temp32), sizeof(temp32)); FREAD_CHECK(&temp32, sizeof(temp32), fp);
CHECK_STREAM(in, sizeof(temp32));
uint32_t numNodes = ntohl(temp32); uint32_t numNodes = ntohl(temp32);
// 4bytes reserved // 4bytes reserved
readBytes(buf, buf.size(), in, 4); readBytes(buf, buf.size(), fp, 4);
CHECK_STREAM(in, 4);
std::vector<SharedHandle<DHTNode> > nodes; std::vector<SharedHandle<DHTNode> > nodes;
// nodes // nodes
@ -163,45 +154,38 @@ void DHTRoutingTableDeserializer::deserialize(std::istream& in)
for(size_t i = 0; i < numNodes; ++i) { for(size_t i = 0; i < numNodes; ++i) {
// 1byte compact peer info length // 1byte compact peer info length
uint8_t peerInfoLen; uint8_t peerInfoLen;
in >> peerInfoLen; FREAD_CHECK(&peerInfoLen, sizeof(peerInfoLen), fp);
if(peerInfoLen != compactlen) { if(peerInfoLen != compactlen) {
// skip this entry // skip this entry
readBytes(buf, buf.size(), in, 7+48); readBytes(buf, buf.size(), fp, 7+48);
CHECK_STREAM(in, 7+48);
continue; continue;
} }
// 7bytes reserved // 7bytes reserved
readBytes(buf, buf.size(), in, 7); readBytes(buf, buf.size(), fp, 7);
CHECK_STREAM(in, 7);
// compactlen bytes compact peer info // compactlen bytes compact peer info
readBytes(buf, buf.size(), in, compactlen); readBytes(buf, buf.size(), fp, compactlen);
CHECK_STREAM(in, compactlen);
if(memcmp(zero, buf, compactlen) == 0) { if(memcmp(zero, buf, compactlen) == 0) {
// skip this entry // skip this entry
readBytes(buf, buf.size(), in, 48-compactlen); readBytes(buf, buf.size(), fp, 48-compactlen);
CHECK_STREAM(in, 48-compactlen);
continue; continue;
} }
std::pair<std::string, uint16_t> peer = std::pair<std::string, uint16_t> peer =
bittorrent::unpackcompact(buf, family_); bittorrent::unpackcompact(buf, family_);
if(peer.first.empty()) { if(peer.first.empty()) {
// skip this entry // skip this entry
readBytes(buf, buf.size(), in, 48-compactlen); readBytes(buf, buf.size(), fp, 48-compactlen);
CHECK_STREAM(in, 48-compactlen);
continue; continue;
} }
// 24-compactlen bytes reserved // 24-compactlen bytes reserved
readBytes(buf, buf.size(), in, 24-compactlen); readBytes(buf, buf.size(), fp, 24-compactlen);
// node ID // node ID
readBytes(buf, buf.size(), in, DHT_ID_LENGTH); readBytes(buf, buf.size(), fp, DHT_ID_LENGTH);
CHECK_STREAM(in, DHT_ID_LENGTH);
SharedHandle<DHTNode> node(new DHTNode(buf)); SharedHandle<DHTNode> node(new DHTNode(buf));
node->setIPAddress(peer.first); node->setIPAddress(peer.first);
node->setPort(peer.second); node->setPort(peer.second);
// 4bytes reserved // 4bytes reserved
readBytes(buf, buf.size(), in, 4); readBytes(buf, buf.size(), fp, 4);
CHECK_STREAM(in, 4);
nodes.push_back(node); nodes.push_back(node);
} }

View File

@ -38,7 +38,7 @@
#include "common.h" #include "common.h"
#include <vector> #include <vector>
#include <iosfwd> #include <string>
#include "SharedHandle.h" #include "SharedHandle.h"
#include "TimeA2.h" #include "TimeA2.h"
@ -76,7 +76,7 @@ public:
return serializedTime_; return serializedTime_;
} }
void deserialize(std::istream& in); void deserialize(const std::string& filename);
}; };
} // namespace aria2 } // namespace aria2

View File

@ -99,11 +99,7 @@ void DHTSetup::setup
e->getOption()->get(family == AF_INET?PREF_DHT_FILE_PATH: e->getOption()->get(family == AF_INET?PREF_DHT_FILE_PATH:
PREF_DHT_FILE_PATH6); PREF_DHT_FILE_PATH6);
try { try {
std::ifstream in(dhtFile.c_str(), std::ios::binary); deserializer.deserialize(dhtFile);
if(!in) {
throw DL_ABORT_EX("Could not open file");
}
deserializer.deserialize(in);
localNode = deserializer.getLocalNode(); localNode = deserializer.getLocalNode();
} catch(RecoverableException& e) { } catch(RecoverableException& e) {
A2_LOG_ERROR_EX A2_LOG_ERROR_EX

View File

@ -54,11 +54,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize()
s.setLocalNode(localNode); s.setLocalNode(localNode);
s.setNodes(nodes); s.setNodes(nodes);
std::stringstream ss; std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize";
s.serialize(ss); std::ofstream outfile(filename.c_str(), std::ios::binary);
s.serialize(outfile);
outfile.close();
DHTRoutingTableDeserializer d(AF_INET); DHTRoutingTableDeserializer d(AF_INET);
d.deserialize(ss); d.deserialize(filename);
CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(), CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
DHT_ID_LENGTH) == 0); DHT_ID_LENGTH) == 0);
@ -93,11 +95,13 @@ void DHTRoutingTableDeserializerTest::testDeserialize6()
s.setLocalNode(localNode); s.setLocalNode(localNode);
s.setNodes(nodes); s.setNodes(nodes);
std::stringstream ss; std::string filename = A2_TEST_OUT_DIR"/aria2_DHTRoutingTableDeserializerTest_testDeserialize6";
s.serialize(ss); std::ofstream outfile(filename.c_str(), std::ios::binary);
s.serialize(outfile);
outfile.close();
DHTRoutingTableDeserializer d(AF_INET6); DHTRoutingTableDeserializer d(AF_INET6);
d.deserialize(ss); d.deserialize(filename);
CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(), CPPUNIT_ASSERT(memcmp(localNode->getID(), d.getLocalNode()->getID(),
DHT_ID_LENGTH) == 0); DHT_ID_LENGTH) == 0);