diff --git a/src/DHTInteractionCommand.cc b/src/DHTInteractionCommand.cc index 2eb87fae..ca6554c5 100644 --- a/src/DHTInteractionCommand.cc +++ b/src/DHTInteractionCommand.cc @@ -54,6 +54,7 @@ #include "UDPTrackerRequest.h" #include "fmt.h" #include "wallclock.h" +#include "TrackerWatcherCommand.h" namespace aria2 { @@ -66,6 +67,7 @@ DHTInteractionCommand::DHTInteractionCommand(cuid_t cuid, DownloadEngine* e) receiver_{nullptr}, taskQueue_{nullptr} { + setStatusRealtime(); } DHTInteractionCommand::~DHTInteractionCommand() @@ -96,10 +98,12 @@ bool DHTInteractionCommand::execute() // needs this. if (e_->getRequestGroupMan()->downloadFinished() || (e_->isHaltRequested() && udpTrackerClient_->getNumWatchers() == 0)) { + A2_LOG_DEBUG("DHTInteractionCommand exiting"); return true; } else if (e_->isForceHaltRequested()) { udpTrackerClient_->failAll(); + A2_LOG_DEBUG("DHTInteractionCommand exiting"); return true; } @@ -122,8 +126,18 @@ bool DHTInteractionCommand::execute() } else { // this may be udp tracker response. nothrow. - udpTrackerClient_->receiveReply(data.data(), length, remoteAddr, - remotePort, global::wallclock()); + std::shared_ptr req; + if (udpTrackerClient_->receiveReply(req, data.data(), length, + remoteAddr, remotePort, + global::wallclock()) == 0) { + if (req->action == UDPT_ACT_ANNOUNCE) { + auto c = static_cast(req->user_data); + if (c) { + c->setStatus(Command::STATUS_ONESHOT_REALTIME); + e_->setNoWait(true); + } + } + } } } } @@ -150,7 +164,7 @@ bool DHTInteractionCommand::execute() udpTrackerClient_->requestFail(UDPT_ERR_NETWORK); } } - e_->addCommand(std::unique_ptr(this)); + e_->addRoutineCommand(std::unique_ptr(this)); return false; } diff --git a/src/DHTSetup.cc b/src/DHTSetup.cc index e9f3c34a..45ec95a0 100644 --- a/src/DHTSetup.cc +++ b/src/DHTSetup.cc @@ -80,14 +80,16 @@ DHTSetup::DHTSetup() {} DHTSetup::~DHTSetup() {} -std::vector> DHTSetup::setup(DownloadEngine* e, - int family) +std::pair>, + std::vector>> +DHTSetup::setup(DownloadEngine* e, int family) { std::vector> tempCommands; + std::vector> tempRoutineCommands; if ((family != AF_INET && family != AF_INET6) || (family == AF_INET && DHTRegistry::isInitialized()) || (family == AF_INET6 && DHTRegistry::isInitialized6())) { - return tempCommands; + return {}; } try { // load routing table and localnode id here @@ -212,7 +214,7 @@ std::vector> DHTSetup::setup(DownloadEngine* e, command->setReadCheckSocket(connection->getSocket()); command->setConnection(std::move(connection)); command->setUDPTrackerClient(udpTrackerClient); - tempCommands.push_back(std::move(command)); + tempRoutineCommands.push_back(std::move(command)); } { auto command = make_unique( @@ -290,6 +292,7 @@ std::vector> DHTSetup::setup(DownloadEngine* e, " DHT is disabled."), ex); tempCommands.clear(); + tempRoutineCommands.clear(); if (family == AF_INET) { DHTRegistry::clearData(); e->getBtRegistry()->setUDPTrackerClient( @@ -299,7 +302,8 @@ std::vector> DHTSetup::setup(DownloadEngine* e, DHTRegistry::clearData6(); } } - return tempCommands; + return std::make_pair(std::move(tempCommands), + std::move(tempRoutineCommands)); } } // namespace aria2 diff --git a/src/DHTSetup.h b/src/DHTSetup.h index 93384590..30eb8a42 100644 --- a/src/DHTSetup.h +++ b/src/DHTSetup.h @@ -51,7 +51,12 @@ public: ~DHTSetup(); - std::vector> setup(DownloadEngine* e, int family); + // Returns two vector of Commands. First one contains regular + // commands. Secod one contains so called routine commands, which + // executed once per event poll returns. + std::pair>, + std::vector>> + setup(DownloadEngine* e, int family); }; } // namespace aria2 diff --git a/src/RequestGroup.cc b/src/RequestGroup.cc index 41d05545..a1f7080d 100644 --- a/src/RequestGroup.cc +++ b/src/RequestGroup.cc @@ -349,12 +349,24 @@ void RequestGroup::createInitialCommand( option_->getAsBool(PREF_ENABLE_DHT6))) { if (option_->getAsBool(PREF_ENABLE_DHT)) { - e->addCommand(DHTSetup().setup(e, AF_INET)); + std::vector> c, rc; + std::tie(c, rc) = DHTSetup().setup(e, AF_INET); + + e->addCommand(std::move(c)); + for (auto& a : rc) { + e->addRoutineCommand(std::move(a)); + } } if (!e->getOption()->getAsBool(PREF_DISABLE_IPV6) && option_->getAsBool(PREF_ENABLE_DHT6)) { - e->addCommand(DHTSetup().setup(e, AF_INET6)); + std::vector> c, rc; + std::tie(c, rc) = DHTSetup().setup(e, AF_INET6); + + e->addCommand(std::move(c)); + for (auto& a : rc) { + e->addRoutineCommand(std::move(a)); + } } const auto& nodes = torrentAttrs->nodes; // TODO Are nodes in torrent IPv4 only? diff --git a/src/TrackerWatcherCommand.cc b/src/TrackerWatcherCommand.cc index 30c27edd..2e6f4f31 100644 --- a/src/TrackerWatcherCommand.cc +++ b/src/TrackerWatcherCommand.cc @@ -229,6 +229,7 @@ bool TrackerWatcherCommand::execute() trackerRequest_ = createAnnounce(e_); if (trackerRequest_) { trackerRequest_->issue(e_); + A2_LOG_DEBUG("tracker request created"); } } else if (trackerRequest_->stopped()) { @@ -259,6 +260,12 @@ bool TrackerWatcherCommand::execute() } } } + + if (!trackerRequest_ && btAnnounce_->noMoreAnnounce()) { + A2_LOG_DEBUG("no more announce"); + return true; + } + e_->addCommand(std::unique_ptr(this)); return false; } @@ -325,8 +332,10 @@ std::unique_ptr TrackerWatcherCommand::createUDPAnnRequest(const std::string& host, uint16_t port, uint16_t localPort) { - return make_unique( - btAnnounce_->createUDPTrackerRequest(host, port, localPort)); + auto req = btAnnounce_->createUDPTrackerRequest(host, port, localPort); + req->user_data = this; + + return make_unique(std::move(req)); } namespace { diff --git a/src/UDPTrackerClient.cc b/src/UDPTrackerClient.cc index 21e5c48f..cffbf1c9 100644 --- a/src/UDPTrackerClient.cc +++ b/src/UDPTrackerClient.cc @@ -132,7 +132,8 @@ struct CollectAddrPortMatch { }; } // namespace -int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length, +int UDPTrackerClient::receiveReply(std::shared_ptr& recvReq, + const unsigned char* data, size_t length, const std::string& remoteAddr, uint16_t remotePort, const Timer& now) { @@ -167,6 +168,9 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length, CollectAddrPortMatch(reqs, remoteAddr, remotePort)), connectRequests_.end()); pendingRequests_.insert(pendingRequests_.begin(), reqs.begin(), reqs.end()); + + recvReq = std::move(req); + break; } case UDPT_ACT_ANNOUNCE: { @@ -209,6 +213,9 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length, getUDPTrackerEventStr(req->event), util::toHex(req->infohash).c_str(), req->reply->interval, req->reply->leechers, req->reply->seeders, numPeers)); + + recvReq = std::move(req); + break; } case UDPT_ACT_ERROR: { @@ -236,6 +243,9 @@ int UDPTrackerClient::receiveReply(const unsigned char* data, size_t length, if (req->action == UDPT_ACT_CONNECT) { failConnect(req->remoteAddr, req->remotePort, UDPT_ERR_TRACKER); } + + recvReq = std::move(req); + break; } case UDPT_ACT_SCRAPE: diff --git a/src/UDPTrackerClient.h b/src/UDPTrackerClient.h index 1fdade1e..2a387bb3 100644 --- a/src/UDPTrackerClient.h +++ b/src/UDPTrackerClient.h @@ -74,7 +74,8 @@ public: UDPTrackerClient(); ~UDPTrackerClient(); - int receiveReply(const unsigned char* data, size_t length, + int receiveReply(std::shared_ptr& req, + const unsigned char* data, size_t length, const std::string& remoteAddr, uint16_t remotePort, const Timer& now); diff --git a/src/UDPTrackerRequest.cc b/src/UDPTrackerRequest.cc index f7c6e2b5..bc64cbad 100644 --- a/src/UDPTrackerRequest.cc +++ b/src/UDPTrackerRequest.cc @@ -58,7 +58,8 @@ UDPTrackerRequest::UDPTrackerRequest() state(UDPT_STA_PENDING), error(UDPT_ERR_SUCCESS), dispatched(Timer::zero()), - failCount(0) + failCount(0), + user_data(nullptr) { } diff --git a/src/UDPTrackerRequest.h b/src/UDPTrackerRequest.h index 7ec291fe..f61f228c 100644 --- a/src/UDPTrackerRequest.h +++ b/src/UDPTrackerRequest.h @@ -101,6 +101,7 @@ struct UDPTrackerRequest { Timer dispatched; int failCount; std::shared_ptr reply; + void *user_data; UDPTrackerRequest(); }; diff --git a/test/UDPTrackerClientTest.cc b/test/UDPTrackerClientTest.cc index 948ed2d4..fc2f34e6 100644 --- a/test/UDPTrackerClientTest.cc +++ b/test/UDPTrackerClientTest.cc @@ -156,6 +156,7 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce() std::string remoteAddr; uint16_t remotePort; Timer now; + std::shared_ptr recvReq; std::shared_ptr req1( createAnnounce("192.168.0.1", 6991, 0)); @@ -190,8 +191,12 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce() uint64_t connectionId = 12345; rv = createConnectReply(data, sizeof(data), connectionId, transactionId); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); CPPUNIT_ASSERT_EQUAL(0, (int)rv); + if (rv == 0) { + CPPUNIT_ASSERT_EQUAL((int32_t)UDPT_ACT_CONNECT, recvReq->action); + } // Now 2 requests get back to pending CPPUNIT_ASSERT_EQUAL((size_t)2, tr.getPendingRequests().size()); @@ -229,15 +234,23 @@ void UDPTrackerClientTest::testConnectFollowedByAnnounce() // Reply for req2 rv = createAnnounceReply(data, sizeof(data), transactionId2); - rv = tr.receiveReply(data, rv, req2->remoteAddr, req2->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req2->remoteAddr, req2->remotePort, + now); CPPUNIT_ASSERT_EQUAL(0, (int)rv); + if (rv == 0) { + CPPUNIT_ASSERT_EQUAL((int32_t)UDPT_ACT_ANNOUNCE, recvReq->action); + } CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state); CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_SUCCESS, req2->error); // Reply for req1 rv = createAnnounceReply(data, sizeof(data), transactionId1, 2); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); CPPUNIT_ASSERT_EQUAL(0, (int)rv); + if (rv == 0) { + CPPUNIT_ASSERT_EQUAL((int32_t)UDPT_ACT_ANNOUNCE, recvReq->action); + } CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state); CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_SUCCESS, req1->error); CPPUNIT_ASSERT_EQUAL((size_t)2, req1->reply->peers.size()); @@ -280,6 +293,8 @@ void UDPTrackerClientTest::testRequestFailure() std::string remoteAddr; uint16_t remotePort; Timer now; + std::shared_ptr recvReq; + { std::shared_ptr req1( createAnnounce("192.168.0.1", 6991, 0)); @@ -315,7 +330,12 @@ void UDPTrackerClientTest::testRequestFailure() tr.requestSent(now); rv = createErrorReply(data, sizeof(data), transactionId, "error"); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); + CPPUNIT_ASSERT_EQUAL((ssize_t)0, rv); + if (rv == 0) { + CPPUNIT_ASSERT_EQUAL((int32_t)UDPT_ACT_CONNECT, recvReq->action); + } CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state); CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TRACKER, req1->error); CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req2->state); @@ -338,7 +358,8 @@ void UDPTrackerClientTest::testRequestFailure() uint64_t connectionId = 12345; rv = createConnectReply(data, sizeof(data), connectionId, transactionId); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); CPPUNIT_ASSERT_EQUAL(0, (int)rv); rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now); @@ -348,7 +369,8 @@ void UDPTrackerClientTest::testRequestFailure() tr.requestSent(now); rv = createErrorReply(data, sizeof(data), transactionId, "announce error"); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); CPPUNIT_ASSERT_EQUAL((int)UDPT_STA_COMPLETE, req1->state); CPPUNIT_ASSERT_EQUAL((int)UDPT_ERR_TRACKER, req1->error); CPPUNIT_ASSERT(tr.getConnectRequests().empty()); @@ -365,6 +387,8 @@ void UDPTrackerClientTest::testTimeout() uint16_t remotePort; Timer now; UDPTrackerClient tr; + std::shared_ptr recvReq; + { std::shared_ptr req1( createAnnounce("192.168.0.1", 6991, 0)); @@ -414,7 +438,8 @@ void UDPTrackerClientTest::testTimeout() uint64_t connectionId = 12345; rv = createConnectReply(data, sizeof(data), connectionId, transactionId); - rv = tr.receiveReply(data, rv, req1->remoteAddr, req1->remotePort, now); + rv = tr.receiveReply(recvReq, data, rv, req1->remoteAddr, req1->remotePort, + now); CPPUNIT_ASSERT_EQUAL(0, (int)rv); rv = tr.createRequest(data, sizeof(data), remoteAddr, remotePort, now);