working windows IPC

This commit is contained in:
Jacob Dufault 2017-03-01 00:36:11 -08:00
parent 2fc419faa3
commit 0738b8f57a
8 changed files with 537 additions and 102 deletions

View File

@ -1,19 +1,23 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LocalDebuggerCommandArguments>--project C:\Users\jacob\Desktop\superindex\indexer\full_tests\simple_cross_reference</LocalDebuggerCommandArguments>
<LocalDebuggerCommandArguments>
</LocalDebuggerCommandArguments>
<DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LocalDebuggerCommandArguments>--project C:\Users\jacob\Desktop\superindex\indexer\full_tests\simple_cross_reference</LocalDebuggerCommandArguments>
<LocalDebuggerCommandArguments>
</LocalDebuggerCommandArguments>
<DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LocalDebuggerCommandArguments>--project C:\Users\jacob\Desktop\superindex\indexer\full_tests\simple_cross_reference</LocalDebuggerCommandArguments>
<LocalDebuggerCommandArguments>
</LocalDebuggerCommandArguments>
<DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LocalDebuggerCommandArguments>--project C:\Users\jacob\Desktop\superindex\indexer\full_tests\simple_cross_reference</LocalDebuggerCommandArguments>
<LocalDebuggerCommandArguments>
</LocalDebuggerCommandArguments>
<DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
</PropertyGroup>
</Project>

View File

@ -1,7 +1,5 @@
#include "serializer.h"
#include "indexer.h"
@ -29,50 +27,6 @@ void Serialize(Writer& writer, const char* key, const std::vector<Location>& loc
writer.EndArray();
}
template<typename T>
void Serialize(Writer& writer, const char* key, Id<T> id) {
if (key) writer.Key(key);
writer.Uint64(id.id);
}
template<typename T>
void Serialize(Writer& writer, const char* key, optional<Id<T>> id) {
if (id) {
Serialize(writer, key, id.value());
}
}
template<typename T>
void Serialize(Writer& writer, const char* key, const std::vector<Id<T>>& ids) {
if (ids.size() == 0)
return;
if (key) writer.Key(key);
writer.StartArray();
for (Id<T> id : ids)
Serialize(writer, nullptr, id);
writer.EndArray();
}
template<typename T>
void Serialize(Writer& writer, const char* key, Ref<T> ref) {
if (key) writer.Key(key);
std::string s = std::to_string(ref.id.id) + "@" + ref.loc.ToString();
writer.String(s.c_str());
}
template<typename T>
void Serialize(Writer& writer, const char* key, const std::vector<Ref<T>>& refs) {
if (refs.size() == 0)
return;
if (key) writer.Key(key);
writer.StartArray();
for (Ref<T> ref : refs)
Serialize(writer, nullptr, ref);
writer.EndArray();
}
void Serialize(Writer& writer, const char* key, const std::string& value) {
if (value.size() == 0)
return;
@ -81,6 +35,18 @@ void Serialize(Writer& writer, const char* key, const std::string& value) {
writer.String(value.c_str());
}
void Serialize(Writer& writer, const char* key, const std::vector<std::string>& value) {
if (value.size() == 0)
return;
if (key) writer.Key(key);
writer.StartArray();
for (const std::string& s : value)
writer.String(s.c_str());
writer.EndArray();
}
void Serialize(Writer& writer, const char* key, uint64_t value) {
if (key) writer.Key(key);
writer.Uint64(value);
@ -174,7 +140,15 @@ void Serialize(Writer& writer, IndexedFile* file) {
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::string& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd())
output = document[name].GetString();
output = it->value.GetString();
}
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<std::string>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd()) {
for (auto& entry : it->value.GetArray())
output.push_back(entry.GetString());
}
}
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, optional<Location>& output) {
@ -191,35 +165,6 @@ void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const cha
}
}
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, optional<Id<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd())
output = Id<T>(it->value.GetUint64());
}
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<Id<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd()) {
for (auto& array_value : it->value.GetArray())
output.push_back(Id<T>(array_value.GetUint64()));
}
}
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<Ref<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd()) {
for (auto& array_value : it->value.GetArray()) {
const char* str_value = array_value.GetString();
uint64_t id = atoi(str_value);
const char* loc_string = strchr(str_value, '@') + 1;
output.push_back(Ref<T>(Id<T>(id), Location(loc_string)));
}
}
}
void Deserialize(Reader& reader, IndexedFile* file) {
#define DESERIALIZE(json_name, member_name) Deserialize(entry, json_name, def.##member_name)

View File

@ -1,11 +1,99 @@
#include <rapidjson/document.h>
#include <rapidjson/prettywriter.h>
#include "indexer.h"
struct IndexedFile;
using Writer = rapidjson::PrettyWriter<rapidjson::StringBuffer>;
using Reader = rapidjson::Document;
template<typename T>
void Serialize(Writer& writer, const char* key, Id<T> id) {
if (key) writer.Key(key);
writer.Uint64(id.id);
}
template<typename T>
void Serialize(Writer& writer, const char* key, optional<Id<T>> id) {
if (id) {
Serialize(writer, key, id.value());
}
}
template<typename T>
void Serialize(Writer& writer, const char* key, const std::vector<Id<T>>& ids) {
if (ids.size() == 0)
return;
if (key) writer.Key(key);
writer.StartArray();
for (Id<T> id : ids)
Serialize(writer, nullptr, id);
writer.EndArray();
}
template<typename T>
void Serialize(Writer& writer, const char* key, Ref<T> ref) {
if (key) writer.Key(key);
std::string s = std::to_string(ref.id.id) + "@" + ref.loc.ToString();
writer.String(s.c_str());
}
template<typename T>
void Serialize(Writer& writer, const char* key, const std::vector<Ref<T>>& refs) {
if (refs.size() == 0)
return;
if (key) writer.Key(key);
writer.StartArray();
for (Ref<T> ref : refs)
Serialize(writer, nullptr, ref);
writer.EndArray();
}
void Serialize(Writer& writer, const char* key, Location location);
void Serialize(Writer& writer, const char* key, optional<Location> location);
void Serialize(Writer& writer, const char* key, const std::vector<Location>& locs);
void Serialize(Writer& writer, const char* key, const std::string& value);
void Serialize(Writer& writer, const char* key, const std::vector<std::string>& value);
void Serialize(Writer& writer, const char* key, uint64_t value);
void Serialize(Writer& writer, IndexedFile* file);
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, optional<Id<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd())
output = Id<T>(it->value.GetUint64());
}
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<Id<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd()) {
for (auto& array_value : it->value.GetArray())
output.push_back(Id<T>(array_value.GetUint64()));
}
}
template<typename T>
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<Ref<T>>& output) {
auto it = document.FindMember(name);
if (it != document.MemberEnd()) {
for (auto& array_value : it->value.GetArray()) {
const char* str_value = array_value.GetString();
uint64_t id = atoi(str_value);
const char* loc_string = strchr(str_value, '@') + 1;
output.push_back(Ref<T>(Id<T>(id), Location(loc_string)));
}
}
}
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::string& output);
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<std::string>& output);
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, optional<Location>& output);
void Deserialize(rapidjson::GenericValue<rapidjson::UTF8<>>& document, const char* name, std::vector<Location>& output);
void Deserialize(Reader& reader, IndexedFile* file);
std::string Serialize(IndexedFile* file);

View File

@ -1,8 +1,285 @@
#include <iostream>
#include <vector>
#include <memory>
#include <iostream>
#include <chrono>
#include <thread>
#include <Windows.h>
const int shmem_size = 16; // 16byte
#include <rapidjson/document.h>
#include <rapidjson/prettywriter.h>
#include "serializer.h"
using Writer = rapidjson::PrettyWriter<rapidjson::StringBuffer>;
using Reader = rapidjson::Document;
struct ProcessMutex {
HANDLE mutex_ = INVALID_HANDLE_VALUE;
ProcessMutex() {
mutex_ = ::CreateMutex(nullptr, false /*initial_owner*/, "indexer_shmem_mutex");
assert(GetLastError() != ERROR_INVALID_HANDLE);
}
~ProcessMutex() {
::ReleaseMutex(mutex_);
mutex_ = INVALID_HANDLE_VALUE;
}
};
struct ScopedProcessLock {
HANDLE mutex_;
ScopedProcessLock(ProcessMutex* mutex) : mutex_(mutex->mutex_) {
WaitForSingleObject(mutex_, INFINITE);
}
~ScopedProcessLock() {
::ReleaseMutex(mutex_);
}
};
// Messages are funky objects. They contain potentially variable amounts of
// data and are passed between processes. This means that they need to be
// fully relocatable, ie, it is possible to memmove them in memory to a
// completely different address.
// TODO: Let's just pipe JSON.
struct JsonMessage {
enum class Kind {
Invalid,
CreateIndex,
ImportIndex
};
Kind kind;
size_t payload_size;
const char* payload() {
return reinterpret_cast<const char*>(this) + sizeof(JsonMessage);
}
void set_payload(size_t payload_size, const char* payload) {
char* payload_dest = reinterpret_cast<char*>(this) + sizeof(JsonMessage);
this->payload_size = payload_size;
memcpy(payload_dest, payload, payload_size);
}
};
struct BaseMessage {
JsonMessage::Kind kind;
virtual void Serialize(Writer& writer) = 0;
virtual void Deserialize(Reader& reader) = 0;
};
struct Message_ImportIndex : public BaseMessage {
std::string path;
Message_ImportIndex() {
kind = JsonMessage::Kind::ImportIndex;
}
// BaseMessage:
void Serialize(Writer& writer) override {
writer.StartObject();
::Serialize(writer, "path", path);
writer.EndObject();
}
void Deserialize(Reader& reader) override {
::Deserialize(reader, "path", path);
}
};
struct Message_CreateIndex : public BaseMessage {
std::string path;
std::vector<std::string> args;
Message_CreateIndex() {
kind = JsonMessage::Kind::CreateIndex;
}
// BaseMessage:
void Serialize(Writer& writer) override {
writer.StartObject();
::Serialize(writer, "path", path);
::Serialize(writer, "args", args);
writer.EndObject();
}
void Deserialize(Reader& reader) override {
::Deserialize(reader, "path", path);
::Deserialize(reader, "args", args);
}
};
const int shmem_size = 1024; // number of chars/bytes (256kb)
struct PlatformSharedMemory {
HANDLE shmem_;
void* shared_start_real_;
size_t* shared_bytes_used;
char* shared_start;
PlatformSharedMemory() {
shmem_ = ::CreateFileMapping(
INVALID_HANDLE_VALUE,
NULL,
PAGE_READWRITE,
0,
shmem_size,
"shared_memory_name"
);
shared_start_real_ = MapViewOfFile(shmem_, FILE_MAP_ALL_ACCESS, 0, 0, shmem_size);
shared_bytes_used = reinterpret_cast<size_t*>(shared_start_real_);
*shared_bytes_used = 0;
shared_start = reinterpret_cast<char*>(shared_bytes_used + 1);
}
~PlatformSharedMemory() {
::UnmapViewOfFile(shared_start_real_);
}
};
struct MessageMemoryBlock {
JsonMessage* ToMessage(char* ptr) {
return reinterpret_cast<JsonMessage*>(ptr);
}
JsonMessage* get_free_message() {
return reinterpret_cast<JsonMessage*>(shared.shared_start + *shared.shared_bytes_used);
}
// NOTE: We keep all pointers in terms of char* so pointer arithmetic is
// always relative to bytes.
// Pointers to shared memory.
PlatformSharedMemory shared;
ProcessMutex mutex;
char* local_block;
MessageMemoryBlock() {
local_block = new char[shmem_size];
}
~MessageMemoryBlock() {
delete[] local_block;
}
void PushMessage(BaseMessage* message) {
rapidjson::StringBuffer output;
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(output);
writer.SetFormatOptions(
rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
writer.SetIndent(' ', 2);
message->Serialize(writer);
size_t payload_size = strlen(output.GetString());
assert(payload_size < shmem_size && "Increase shared memory size, payload will never fit");
bool first = true;
bool did_log = false;
while (true) {
using namespace std::chrono_literals;
if (!first) {
if (!did_log) {
std::cout << "[info]: shmem full, waiting" << std::endl; // TODO: remove
did_log = true;
}
std::this_thread::sleep_for(16ms);
}
first = false;
ScopedProcessLock lock(&mutex);
// Try again later when there is room in shared memory.
if ((*shared.shared_bytes_used + sizeof(JsonMessage) + payload_size) >= shmem_size)
continue;
get_free_message()->kind = message->kind;
get_free_message()->set_payload(payload_size, output.GetString());
*shared.shared_bytes_used += sizeof(JsonMessage) + get_free_message()->payload_size;
assert(*shared.shared_bytes_used < shmem_size);
get_free_message()->kind = JsonMessage::Kind::Invalid;
break;
}
}
std::vector<std::unique_ptr<BaseMessage>> PopMessage() {
size_t remaining_bytes = 0;
// Move data from shared memory into a local buffer. Do this
// before parsing the blocks so that other processes can begin
// posting data as soon as possible.
{
ScopedProcessLock lock(&mutex);
remaining_bytes = *shared.shared_bytes_used;
memcpy(local_block, shared.shared_start, *shared.shared_bytes_used);
*shared.shared_bytes_used = 0;
get_free_message()->kind = JsonMessage::Kind::Invalid;
}
std::vector<std::unique_ptr<BaseMessage>> result;
char* message = local_block;
while (remaining_bytes > 0) {
std::unique_ptr<BaseMessage> base_message;
switch (ToMessage(message)->kind) {
case JsonMessage::Kind::CreateIndex:
base_message = std::make_unique<Message_CreateIndex>();
break;
case JsonMessage::Kind::ImportIndex:
base_message = std::make_unique<Message_ImportIndex>();
break;
default:
assert(false);
}
rapidjson::Document document;
document.Parse(ToMessage(message)->payload(), ToMessage(message)->payload_size);
bool has_error = document.HasParseError();
auto error = document.GetParseError();
base_message->Deserialize(document);
result.emplace_back(std::move(base_message));
remaining_bytes -= sizeof(JsonMessage) + ToMessage(message)->payload_size;
message = message + sizeof(JsonMessage) + ToMessage(message)->payload_size;
}
return result;
}
};
void reader() {
HANDLE shmem = INVALID_HANDLE_VALUE;
@ -85,5 +362,49 @@ int main52525252(int argc, char** argv) {
else
reader();
return 0;
}
int main(int argc, char** argv) {
if (argc == 2) {
MessageMemoryBlock block;
int i = 0;
while (true) {
Message_ImportIndex m;
m.path = "foo #" + std::to_string(i);
block.PushMessage(&m);
std::cout << "Sent " << i << std::endl;;
using namespace std::chrono_literals;
std::this_thread::sleep_for(10ms);
++i;
}
}
else {
MessageMemoryBlock block;
while (true) {
std::vector<std::unique_ptr<BaseMessage>> messages = block.PopMessage();
std::cout << "Got " << messages.size() << " messages" << std::endl;
for (auto& message : messages) {
rapidjson::StringBuffer output;
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(output);
writer.SetFormatOptions(
rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
writer.SetIndent(' ', 2);
message->Serialize(writer);
std::cout << " kind=" << static_cast<int>(message->kind) << ", json=" << output.GetString() << std::endl;
}
using namespace std::chrono_literals;
std::this_thread::sleep_for(5s);
}
}
return 0;
}

101
task.cc
View File

@ -1,13 +1,17 @@
#include <cassert>>
#include <condition_variable>
#include <iostream>
#include <thread>
#include <vector>
#include "compilation_database_loader.h"
#include "indexer.h"
#include "query.h"
#include "optional.h"
#include "utils.h"
#include "third_party/tiny-process-library/process.hpp"
#include <algorithm>
#include <queue>
#include <mutex>
#include <condition_variable>
@ -60,11 +64,44 @@ private:
struct Task {
int priority = 0;
bool writes_to_index = false;
bool should_exit = false;
enum class Kind {
CreateIndex,
IndexImport,
Exit
};
Kind kind;
struct CreateIndexState {
CompilationEntry data;
};
struct IndexImportState {
std::string path;
};
struct ExitState {};
// TODO: Move into a union?
CreateIndexState create_index;
IndexImportState index_import;
ExitState exit;
static Task MakeExit() {
Task task;
task.should_exit = true;
task.kind = Kind::Exit;
return task;
}
static Task MakeCreateIndexTask(CompilationEntry compilation_entry) {
Task task;
task.kind = Kind::CreateIndex;
task.create_index.data = compilation_entry;
return task;
}
static Task MakeIndexImportTask(std::string filename) {
Task task;
task.kind = Kind::IndexImport;
task.index_import.path = filename;
return task;
}
@ -91,9 +128,14 @@ struct Task {
// TODO: QueryTask
// Task running a query against the global database. Run in main process,
// separate thread.
Command query;
Location location;
std::string argument;
//Command query;
//Location location;
//std::string argument;
};
struct Config {
// Cache directory. Always ends with /
std::string cache_directory;
};
// NOTE: When something enters a value into master db, it will have to have a
@ -105,13 +147,42 @@ struct TaskManager {
// Available threads.
std::vector<std::thread> threads;
TaskManager(int num_threads);
TaskManager(int num_threads, Config* config);
};
static void ThreadMain(int id, TaskManager* tm) {
void PostTaskToIndexer(TaskManager* tm, Task task) {
tm->queued_tasks.enqueue(task);
}
void RunIndexTask(Config* config, TaskManager* tm, CompilationEntry entry) {
IndexedFile file = Parse(entry.filename, entry.args);
std::string cleaned_file_path = entry.directory + "/" + entry.filename;
std::replace(cleaned_file_path.begin(), cleaned_file_path.end(), '/', '_');
std::replace(cleaned_file_path.begin(), cleaned_file_path.end(), '\\', '_');
std::string filename = config->cache_directory + cleaned_file_path;
WriteToFile(filename, file.ToString());
PostTaskToIndexer(tm, Task::MakeIndexImportTask(filename));
}
void LoadProject(Config* config, TaskManager* tm, std::vector<CompilationEntry> entries) {
for (CompilationEntry entry : entries) {
tm->queued_tasks.enqueue(Task::MakeCreateIndexTask(entry));
}
}
static void ThreadMain(int id, Config* config, TaskManager* tm) {
while (true) {
Task task = tm->queued_tasks.dequeue();
if (task.should_exit) {
switch (task.kind) {
case Task::Kind::CreateIndex:
RunIndexTask(config, tm, task.create_index.data);
break;
case Task::Kind::IndexImport:
assert(false);
break;
case Task::Kind::Exit:
std::cout << id << ": Exiting" << std::endl;
return;
}
@ -121,9 +192,9 @@ static void ThreadMain(int id, TaskManager* tm) {
}
TaskManager::TaskManager(int num_threads) {
TaskManager::TaskManager(int num_threads, Config* config) {
for (int i = 0; i < num_threads; ++i) {
threads.push_back(std::thread(&ThreadMain, i, this));
threads.push_back(std::thread(&ThreadMain, i, config, this));
}
}
@ -131,8 +202,10 @@ void Pump(TaskManager* tm) {
//tm->threads[0].
}
int main5555555555(int argc, char** argv) {
TaskManager tm(5);
int main252525225(int argc, char** argv) {
Config config;
TaskManager tm(5, &config);
LoadProject(&config, &tm, LoadCompilationEntriesFromDirectory("full_tests/simple_cross_reference"));
// TODO: looks like we will have to write shared memory support.
@ -140,8 +213,8 @@ int main5555555555(int argc, char** argv) {
// Repeat until we encounter a writer, wait for all threads to signal
// they are done.
// TODO: Let's use a thread safe queue/vector/etc instead.
for (int i = 0; i < 10; ++i)
tm.queued_tasks.enqueue(Task::MakeExit());
//for (int i = 0; i < 10; ++i)
// tm.queued_tasks.enqueue(Task::MakeExit());
for (std::thread& thread : tm.threads)
thread.join();

View File

@ -1,5 +1,6 @@
#include "indexer.h"
#include "serializer.h"
#include "utils.h"
void Write(const std::vector<std::string>& strs) {
for (const std::string& str : strs)
@ -71,11 +72,6 @@ void DiffDocuments(rapidjson::Document& expected, rapidjson::Document& actual) {
}
}
void WriteToFile(const std::string& filename, const std::string& content) {
std::ofstream file(filename);
file << content;
}
void VerifySerializeToFrom(IndexedFile* file) {
std::string expected = file->ToString();
std::string actual = Deserialize("foo.cc", Serialize(file)).ToString();
@ -85,7 +81,7 @@ void VerifySerializeToFrom(IndexedFile* file) {
}
}
int main(int argc, char** argv) {
int main333(int argc, char** argv) {
// TODO: Assert that we need to be on clang >= 3.9.1
/*

View File

@ -72,4 +72,9 @@ void ParseTestExpectation(std::string filename, std::string* expected_output) {
void Fail(const std::string& message) {
std::cerr << "Fatal error: " << message << std::endl;
std::exit(1);
}
void WriteToFile(const std::string& filename, const std::string& content) {
std::ofstream file(filename);
file << content;
}

View File

@ -7,4 +7,7 @@ std::vector<std::string> GetFilesInFolder(std::string folder);
std::vector<std::string> ReadLines(std::string filename);
void ParseTestExpectation(std::string filename, std::string* expected_output);
void Fail(const std::string& message);
void Fail(const std::string& message);
void WriteToFile(const std::string& filename, const std::string& content);