diff --git a/src/task.cc b/src/task.cc new file mode 100644 index 00000000..83143235 --- /dev/null +++ b/src/task.cc @@ -0,0 +1,160 @@ +#include "utils.h" + +#include +#include + +#include +#include +#include +#include +#include + +using std::experimental::optional; +using std::experimental::nullopt; + +enum class TaskTargetThread { + Indexer, + QueryDb, +}; + +// TODO: IdleTask returns a bool indicating if it did work. +// TODO: Hookup IdleTask +// TODO: Move target_thread out of task and into PostTask. + +struct Task { + // The thread the task will execute on. + TaskTargetThread target_thread; + + // The action the task will perform. + using TAction = std::function; + TAction action; + + Task(TaskTargetThread target, const TAction& action); +}; + +Task::Task(TaskTargetThread target, const TAction& action) + : target_thread(target), action(action) {} + + +struct TaskQueue { + optional idle_task; + std::vector tasks; + std::mutex tasks_mutex; +}; + +struct TaskManager { + TaskManager(); + + // Run |task| at some point in the future. This will run the task as soon as possible. + void PostTask(const Task& task); + + // Run |task| whenever there is nothing else to run. + void SetIdleTask(const Task& task); + + // Run pending tasks for |thread|. Stop running tasks after |max_time| has elapsed. + void RunTasks(TaskTargetThread thread, optional> max_time); + + std::unordered_map> pending_tasks_; +}; + +TaskManager::TaskManager() { + pending_tasks_[TaskTargetThread::Indexer] = MakeUnique(); + pending_tasks_[TaskTargetThread::QueryDb] = MakeUnique(); +} + +void TaskManager::PostTask(const Task& task) { + TaskQueue* queue = pending_tasks_[task.target_thread].get(); + std::lock_guard lock_guard(queue->tasks_mutex); + queue->tasks.push_back(task); +} + +void TaskManager::SetIdleTask(const Task& task) { + TaskQueue* queue = pending_tasks_[task.target_thread].get(); + std::lock_guard lock_guard(queue->tasks_mutex); + assert(!queue->idle_task && "There is already an idle task"); + queue->idle_task = task; +} + +void TaskManager::RunTasks(TaskTargetThread thread, optional> max_time) { + auto start = std::chrono::high_resolution_clock::now(); + TaskQueue* queue = pending_tasks_[thread].get(); + + while (true) { + optional task; + + // Get a task. + { + std::lock_guard lock_guard(queue->tasks_mutex); + if (queue->tasks.empty()) + return; + task = queue->tasks[queue->tasks.size() - 1]; + queue->tasks.pop_back(); + } + + // Execute task. + assert(task); + task->action(); + + // If we've run past our max time stop. + auto elapsed = std::chrono::high_resolution_clock::now() - start; + if (max_time && elapsed > *max_time) + return; + } +} + +TEST_SUITE("Task"); + +TEST_CASE("tasks are run as soon as they are posted") { + TaskManager tm; + + // Post three tasks. + int next = 1; + int a = 0, b = 0, c = 0; + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] { + a = next++; + })); + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] { + b = next++; + })); + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] { + c = next++; + })); + + // Execute all tasks. + tm.RunTasks(TaskTargetThread::QueryDb, nullopt); + + // Tasks are executed in reverse order. + REQUIRE(a == 3); + REQUIRE(b == 2); + REQUIRE(c == 1); +} + +TEST_CASE("post from inside task manager") { + TaskManager tm; + + // Post three tasks. + int next = 1; + int a = 0, b = 0, c = 0; + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] () { + a = next++; + + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] { + b = next++; + + tm.PostTask(Task(TaskTargetThread::QueryDb, [&] { + c = next++; + })); + })); + })); + + // Execute all tasks. + tm.RunTasks(TaskTargetThread::QueryDb, nullopt); + + // Tasks are executed in normal order because the next task is not posted + // until the previous one is executed. + REQUIRE(a == 1); + REQUIRE(b == 2); + REQUIRE(c == 3); +} + +TEST_SUITE_END(); \ No newline at end of file