This commit is contained in:
Jacob Dufault 2017-03-11 16:36:00 -08:00
parent cc4d49794d
commit fdd798f995
4 changed files with 202 additions and 274 deletions

View File

@ -134,38 +134,44 @@ std::unique_ptr<InMessage> ParseMessage() {
template<typename T>
struct BaseIpcMessage : public IpcMessage {
BaseIpcMessage() : IpcMessage(T::kIpcId) {}
// IpcMessage:
void Serialize(Writer& writer) override {
T& value = *static_cast<T*>(this);
Reflect(writer, value);
}
void Deserialize(Reader& reader) override {
T& value = *static_cast<T*>(this);
Reflect(reader, value);
}
};
struct IpcMessage_Quit : public BaseIpcMessage<IpcMessage_Quit> { struct IpcMessage_Quit : public BaseIpcMessage<IpcMessage_Quit> {
static IpcMessageId kId; static constexpr IpcId kIpcId = IpcId::Quit;
}; };
IpcMessageId IpcMessage_Quit::kId = "Quit"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_Quit& value) {}
struct IpcMessage_IsAlive : public BaseIpcMessage<IpcMessage_IsAlive> { struct IpcMessage_IsAlive : public BaseIpcMessage<IpcMessage_IsAlive> {
static IpcMessageId kId; static constexpr IpcId kIpcId = IpcId::IsAlive;
}; };
IpcMessageId IpcMessage_IsAlive::kId = "IsAlive"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_IsAlive& value) {}
struct IpcMessage_OpenProject : public BaseIpcMessage<IpcMessage_OpenProject> { struct IpcMessage_OpenProject : public BaseIpcMessage<IpcMessage_OpenProject> {
static IpcMessageId kId; static constexpr IpcId kIpcId = IpcId::OpenProject;
std::string project_path; std::string project_path;
// BaseIpcMessage:
void Serialize(Writer& writer) override {
writer.String(project_path.c_str(), project_path.size());
}
void Deserialize(Reader& reader) override {
project_path = reader.GetString();
}
}; };
IpcMessageId IpcMessage_OpenProject::kId = "OpenProject"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_OpenProject& value) {
Reflect(visitor, value.project_path);
}
@ -173,108 +179,72 @@ IpcMessageId IpcMessage_OpenProject::kId = "OpenProject";
struct IpcMessage_LanguageServerRequest : public BaseIpcMessage<IpcMessage_LanguageServerRequest> {
static constexpr IpcId kIpcId = IpcId::LanguageServerRequest;
// TODO: provide a way to get the request state.
lsMethodId method_id;
};
template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_LanguageServerRequest& value) {
REFLECT_MEMBER_START();
REFLECT_MEMBER(method_id);
REFLECT_MEMBER_END();
}
struct IpcMessage_DocumentSymbolsRequest : public BaseIpcMessage<IpcMessage_DocumentSymbolsRequest> { struct IpcMessage_DocumentSymbolsRequest : public BaseIpcMessage<IpcMessage_DocumentSymbolsRequest> {
RequestId id; static constexpr IpcId kIpcId = IpcId::DocumentSymbolsRequest;
RequestId request_id;
std::string document; std::string document;
// BaseIpcMessage:
static IpcMessageId kId;
void Serialize(Writer& visitor) override {
// TODO: dedup
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(document);
REFLECT_MEMBER_END();
}
void Deserialize(Reader& visitor) override {
// TODO: dedup
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(document);
REFLECT_MEMBER_END();
}
}; };
IpcMessageId IpcMessage_DocumentSymbolsRequest::kId = "IpcMessage_DocumentSymbolsRequest"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_DocumentSymbolsRequest& value) {
REFLECT_MEMBER_START();
REFLECT_MEMBER(request_id);
REFLECT_MEMBER(document);
REFLECT_MEMBER_END();
}
struct IpcMessage_DocumentSymbolsResponse : public BaseIpcMessage<IpcMessage_DocumentSymbolsResponse> { struct IpcMessage_DocumentSymbolsResponse : public BaseIpcMessage<IpcMessage_DocumentSymbolsResponse> {
RequestId id; static constexpr IpcId kIpcId = IpcId::DocumentSymbolsResponse;
RequestId request_id;
std::vector<lsSymbolInformation> symbols; std::vector<lsSymbolInformation> symbols;
// BaseIpcMessage:
static IpcMessageId kId;
void Serialize(Writer& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
void Deserialize(Reader& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
}; };
IpcMessageId IpcMessage_DocumentSymbolsResponse::kId = "IpcMessage_DocumentSymbolsResponse"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_DocumentSymbolsResponse& value) {
REFLECT_MEMBER_START();
REFLECT_MEMBER(request_id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
struct IpcMessage_WorkspaceSymbolsRequest : public BaseIpcMessage<IpcMessage_WorkspaceSymbolsRequest> { struct IpcMessage_WorkspaceSymbolsRequest : public BaseIpcMessage<IpcMessage_WorkspaceSymbolsRequest> {
RequestId id; static constexpr IpcId kIpcId = IpcId::WorkspaceSymbolsRequest;
RequestId request_id;
std::string query; std::string query;
// BaseIpcMessage:
static IpcMessageId kId;
void Serialize(Writer& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(query);
REFLECT_MEMBER_END();
}
void Deserialize(Reader& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(query);
REFLECT_MEMBER_END();
}
}; };
IpcMessageId IpcMessage_WorkspaceSymbolsRequest::kId = "IpcMessage_WorkspaceSymbolsRequest"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_WorkspaceSymbolsRequest& value) {
REFLECT_MEMBER_START();
REFLECT_MEMBER(request_id);
REFLECT_MEMBER(query);
REFLECT_MEMBER_END();
}
struct IpcMessage_WorkspaceSymbolsResponse : public BaseIpcMessage<IpcMessage_WorkspaceSymbolsResponse> { struct IpcMessage_WorkspaceSymbolsResponse : public BaseIpcMessage<IpcMessage_WorkspaceSymbolsResponse> {
RequestId id; static constexpr IpcId kIpcId = IpcId::WorkspaceSymbolsResponse;
RequestId request_id;
std::vector<lsSymbolInformation> symbols; std::vector<lsSymbolInformation> symbols;
// BaseIpcMessage:
static IpcMessageId kId;
void Serialize(Writer& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
void Deserialize(Reader& visitor) override {
auto& value = *this;
REFLECT_MEMBER_START();
REFLECT_MEMBER(id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
}; };
IpcMessageId IpcMessage_WorkspaceSymbolsResponse::kId = "IpcMessage_WorkspaceSymbolsResponse"; template<typename TVisitor>
void Reflect(TVisitor& visitor, IpcMessage_WorkspaceSymbolsResponse& value) {
REFLECT_MEMBER_START();
REFLECT_MEMBER(request_id);
REFLECT_MEMBER(symbols);
REFLECT_MEMBER_END();
}
@ -301,21 +271,23 @@ IpcMessageId IpcMessage_WorkspaceSymbolsResponse::kId = "IpcMessage_WorkspaceSym
void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) { void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) {
std::vector<std::unique_ptr<BaseIpcMessageElided>> messages = ipc->TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> messages = ipc->TakeMessages();
for (auto& message : messages) { for (auto& message : messages) {
std::cerr << "Processing message " << message->runtime_id() << " (hash " << message->hashed_runtime_id() << ")" << std::endl; std::cerr << "Processing message " << static_cast<int>(message->ipc_id) << std::endl;
if (IpcMessage_Quit::kId == message->runtime_id()) { switch (message->ipc_id) {
case IpcId::Quit: {
break; break;
} }
else if (IpcMessage_IsAlive::kId == message->runtime_id()) { case IpcId::IsAlive: {
IpcMessage_IsAlive response; IpcMessage_IsAlive response;
ipc->SendToClient(0, &response); // todo: make non-blocking ipc->SendToClient(0, &response); // todo: make non-blocking
break;
} }
else if (IpcMessage_OpenProject::kId == message->runtime_id()) { case IpcId::OpenProject: {
IpcMessage_OpenProject* msg = static_cast<IpcMessage_OpenProject*>(message.get()); IpcMessage_OpenProject* msg = static_cast<IpcMessage_OpenProject*>(message.get());
std::string path = msg->project_path; std::string path = msg->project_path;
@ -330,14 +302,14 @@ void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) {
db->ApplyIndexUpdate(&update); db->ApplyIndexUpdate(&update);
} }
std::cerr << "Done" << std::endl; std::cerr << "Done" << std::endl;
break;
} }
case IpcId::DocumentSymbolsRequest: {
else if (IpcMessage_DocumentSymbolsRequest::kId == message->runtime_id()) {
auto msg = static_cast<IpcMessage_DocumentSymbolsRequest*>(message.get()); auto msg = static_cast<IpcMessage_DocumentSymbolsRequest*>(message.get());
IpcMessage_DocumentSymbolsResponse response; IpcMessage_DocumentSymbolsResponse response;
response.id = msg->id; response.request_id = msg->request_id;
std::cerr << "Wanted file " << msg->document << std::endl; std::cerr << "Wanted file " << msg->document << std::endl;
for (auto& file : db->files) { for (auto& file : db->files) {
@ -355,7 +327,7 @@ void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) {
lsSymbolInformation info; lsSymbolInformation info;
info.location.range.start.line = ref.loc.line - 1; // TODO: cleanup indexer to negate by 1. info.location.range.start.line = ref.loc.line - 1; // TODO: cleanup indexer to negate by 1.
info.location.range.start.character = ref.loc.column - 1; // TODO: cleanup indexer to negate by 1. info.location.range.start.character = ref.loc.column - 1; // TODO: cleanup indexer to negate by 1.
// TODO: store range information. // TODO: store range information.
info.location.range.end.line = info.location.range.start.line; info.location.range.end.line = info.location.range.start.line;
info.location.range.end.character = info.location.range.start.character; info.location.range.end.character = info.location.range.start.character;
@ -405,13 +377,15 @@ void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) {
ipc->SendToClient(0, &response); ipc->SendToClient(0, &response);
break;
} }
else if (IpcMessage_WorkspaceSymbolsRequest::kId == message->runtime_id()) { case IpcId::WorkspaceSymbolsRequest: {
auto msg = static_cast<IpcMessage_WorkspaceSymbolsRequest*>(message.get()); auto msg = static_cast<IpcMessage_WorkspaceSymbolsRequest*>(message.get());
IpcMessage_WorkspaceSymbolsResponse response; IpcMessage_WorkspaceSymbolsResponse response;
response.id = msg->id; response.request_id = msg->request_id;
std::cerr << "- Considering " << db->qualified_names.size() << " candidates " << std::endl; std::cerr << "- Considering " << db->qualified_names.size() << " candidates " << std::endl;
@ -493,12 +467,14 @@ void QueryDbMainLoop(IpcServer* ipc, QueryableDatabase* db) {
ipc->SendToClient(0, &response); ipc->SendToClient(0, &response);
break;
} }
else { default: {
std::cerr << "Unhandled IPC message with kind " << message->runtime_id() << " (hash " << message->hashed_runtime_id() << ")" << std::endl; std::cerr << "Unhandled IPC message with kind " << static_cast<int>(message->ipc_id) << std::endl;
exit(1); exit(1);
} }
}
} }
} }
@ -575,7 +551,7 @@ void LanguageServerStdinLoop(IpcClient* ipc) {
auto request = static_cast<In_DocumentSymbolRequest*>(message.get()); auto request = static_cast<In_DocumentSymbolRequest*>(message.get());
IpcMessage_DocumentSymbolsRequest ipc_request; IpcMessage_DocumentSymbolsRequest ipc_request;
ipc_request.id = request->id.value(); ipc_request.request_id = request->id.value();
ipc_request.document = request->params.textDocument.uri.GetPath(); ipc_request.document = request->params.textDocument.uri.GetPath();
std::cerr << "Request textDocument=" << ipc_request.document << std::endl; std::cerr << "Request textDocument=" << ipc_request.document << std::endl;
ipc->SendToServer(&ipc_request); ipc->SendToServer(&ipc_request);
@ -586,7 +562,7 @@ void LanguageServerStdinLoop(IpcClient* ipc) {
{ {
auto request = static_cast<In_WorkspaceSymbolRequest*>(message.get()); auto request = static_cast<In_WorkspaceSymbolRequest*>(message.get());
IpcMessage_WorkspaceSymbolsRequest ipc_request; IpcMessage_WorkspaceSymbolsRequest ipc_request;
ipc_request.id = request->id.value(); ipc_request.request_id = request->id.value();
ipc_request.query = request->params.query; ipc_request.query = request->params.query;
std::cerr << "Request query=" << ipc_request.query << std::endl; std::cerr << "Request query=" << ipc_request.query << std::endl;
ipc->SendToServer(&ipc_request); ipc->SendToServer(&ipc_request);
@ -597,37 +573,41 @@ void LanguageServerStdinLoop(IpcClient* ipc) {
} }
void LanguageServerMainLoop(IpcClient* ipc) { void LanguageServerMainLoop(IpcClient* ipc) {
std::vector<std::unique_ptr<BaseIpcMessageElided>> messages = ipc->TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> messages = ipc->TakeMessages();
for (auto& message : messages) { for (auto& message : messages) {
if (IpcMessage_Quit::kId == message->runtime_id()) { switch (message->ipc_id) {
case IpcId::Quit: {
exit(0); exit(0);
break;
} }
case IpcId::DocumentSymbolsResponse: {
else if (IpcMessage_DocumentSymbolsResponse::kId == message->runtime_id()) {
auto msg = static_cast<IpcMessage_DocumentSymbolsResponse*>(message.get()); auto msg = static_cast<IpcMessage_DocumentSymbolsResponse*>(message.get());
auto response = Out_DocumentSymbolResponse(); auto response = Out_DocumentSymbolResponse();
response.id = msg->id; response.id = msg->request_id;
response.result = msg->symbols; response.result = msg->symbols;
response.Send(); response.Send();
std::cerr << "Send symbol response to client (" << response.result.size() << " symbols)" << std::endl; std::cerr << "Send symbol response to client (" << response.result.size() << " symbols)" << std::endl;
break;
} }
else if (IpcMessage_WorkspaceSymbolsResponse::kId == message->runtime_id()) { case IpcId::WorkspaceSymbolsResponse: {
auto msg = static_cast<IpcMessage_WorkspaceSymbolsResponse*>(message.get()); auto msg = static_cast<IpcMessage_WorkspaceSymbolsResponse*>(message.get());
auto response = Out_WorkspaceSymbolResponse(); auto response = Out_WorkspaceSymbolResponse();
response.id = msg->id; response.id = msg->request_id;
response.result = msg->symbols; response.result = msg->symbols;
response.Send(); response.Send();
std::cerr << "Send symbol response to client (" << response.result.size() << " symbols)" << std::endl; std::cerr << "Send symbol response to client (" << response.result.size() << " symbols)" << std::endl;
break;
} }
else { default: {
std::cerr << "Unhandled IPC message with kind " << message->runtime_id() << " (hash " << message->hashed_runtime_id() << ")" << std::endl; std::cerr << "Unhandled IPC message with kind " << static_cast<int>(message->ipc_id) << std::endl;
exit(1); exit(1);
} }
}
} }
} }
@ -645,10 +625,10 @@ void LanguageServerMain(std::string process_name) {
std::this_thread::sleep_for(std::chrono::milliseconds(20)); std::this_thread::sleep_for(std::chrono::milliseconds(20));
// Check if we got an IsAlive message back. // Check if we got an IsAlive message back.
std::vector<std::unique_ptr<BaseIpcMessageElided>> messages = client_ipc.TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> messages = client_ipc.TakeMessages();
bool has_server = false; bool has_server = false;
for (auto& message : messages) { for (auto& message : messages) {
if (message->runtime_id() == IpcMessage_IsAlive::kId) { if (IpcId::IsAlive == message->ipc_id) {
has_server = true; has_server = true;
break; break;
} }
@ -668,13 +648,13 @@ void LanguageServerMain(std::string process_name) {
/*stderr*/[](const char* bytes, size_t n) { /*stderr*/[](const char* bytes, size_t n) {
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)
std::cerr << bytes[i]; std::cerr << bytes[i];
}, },
/*open_stdin*/false); /*open_stdin*/false);
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(1));
// Pass empty process name so we only try to start the querydb once. // Pass empty process name so we only try to start the querydb once.
LanguageServerMain(""); LanguageServerMain("");
return; return;
} }
#endif #endif
// for debugging attach // for debugging attach
@ -767,16 +747,14 @@ int main(int argc, char** argv) {
_setmode(_fileno(stdin), O_BINARY); _setmode(_fileno(stdin), O_BINARY);
#endif #endif
IpcRegistry::instance()->Register<IpcMessage_Quit>(); IpcRegistry::instance()->Register<IpcMessage_Quit>(IpcId::Quit);
IpcRegistry::instance()->Register<IpcMessage_IsAlive>(IpcId::IsAlive);
IpcRegistry::instance()->Register<IpcMessage_OpenProject>(IpcId::OpenProject);
IpcRegistry::instance()->Register<IpcMessage_IsAlive>(); IpcRegistry::instance()->Register<IpcMessage_DocumentSymbolsRequest>(IpcId::DocumentSymbolsRequest);
IpcRegistry::instance()->Register<IpcMessage_OpenProject>(); IpcRegistry::instance()->Register<IpcMessage_DocumentSymbolsResponse>(IpcId::DocumentSymbolsResponse);
IpcRegistry::instance()->Register<IpcMessage_WorkspaceSymbolsRequest>(IpcId::WorkspaceSymbolsRequest);
IpcRegistry::instance()->Register<IpcMessage_DocumentSymbolsRequest>(); IpcRegistry::instance()->Register<IpcMessage_WorkspaceSymbolsResponse>(IpcId::WorkspaceSymbolsResponse);
IpcRegistry::instance()->Register<IpcMessage_DocumentSymbolsResponse>();
IpcRegistry::instance()->Register<IpcMessage_WorkspaceSymbolsRequest>();
IpcRegistry::instance()->Register<IpcMessage_WorkspaceSymbolsResponse>();
MessageRegistry::instance()->Register<In_CancelRequest>(); MessageRegistry::instance()->Register<In_CancelRequest>();
MessageRegistry::instance()->Register<In_InitializeRequest>(); MessageRegistry::instance()->Register<In_InitializeRequest>();

58
ipc.cc
View File

@ -1,15 +1,19 @@
#include "ipc.h" #include "ipc.h"
#include "serializer.h"
namespace { namespace {
struct JsonMessage { struct JsonMessage {
int message_id; IpcId ipc_id;
size_t payload_size; size_t payload_size;
const char* payload(); const char* payload() {
void SetPayload(size_t payload_size, const char* payload); return reinterpret_cast<const char*>(this) + sizeof(JsonMessage);
}
void SetPayload(size_t payload_size, const char* payload) {
char* payload_dest = reinterpret_cast<char*>(this) + sizeof(JsonMessage);
this->payload_size = payload_size;
memcpy(payload_dest, payload, payload_size);
}
}; };
JsonMessage* get_free_message(IpcDirectionalChannel* channel) { JsonMessage* get_free_message(IpcDirectionalChannel* channel) {
@ -34,24 +38,10 @@ namespace {
} }
} }
const char* JsonMessage::payload() {
return reinterpret_cast<const char*>(this) + sizeof(JsonMessage);
}
void JsonMessage::SetPayload(size_t payload_size, const char* payload) {
char* payload_dest = reinterpret_cast<char*>(this) + sizeof(JsonMessage);
this->payload_size = payload_size;
memcpy(payload_dest, payload, payload_size);
}
void BaseIpcMessageElided::Serialize(Writer& writer) {}
void BaseIpcMessageElided::Deserialize(Reader& reader) {}
IpcRegistry* IpcRegistry::instance_ = nullptr; IpcRegistry* IpcRegistry::instance_ = nullptr;
std::unique_ptr<BaseIpcMessageElided> IpcRegistry::Allocate(int id) { std::unique_ptr<IpcMessage> IpcRegistry::Allocate(IpcId id) {
return std::unique_ptr<BaseIpcMessageElided>((*allocators)[id]()); return std::unique_ptr<IpcMessage>((*allocators)[id]());
} }
IpcDirectionalChannel::IpcDirectionalChannel(const std::string& name) { IpcDirectionalChannel::IpcDirectionalChannel(const std::string& name) {
@ -64,7 +54,9 @@ IpcDirectionalChannel::~IpcDirectionalChannel() {
delete[] local_block; delete[] local_block;
} }
void IpcDirectionalChannel::PushMessage(BaseIpcMessageElided* message) { void IpcDirectionalChannel::PushMessage(IpcMessage* message) {
assert(message->ipc_id != IpcId::Invalid);
rapidjson::StringBuffer output; rapidjson::StringBuffer output;
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(output); rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(output);
writer.SetFormatOptions( writer.SetFormatOptions(
@ -95,18 +87,18 @@ void IpcDirectionalChannel::PushMessage(BaseIpcMessageElided* message) {
if ((*shared->shared_bytes_used + sizeof(JsonMessage) + payload_size) >= shmem_size) if ((*shared->shared_bytes_used + sizeof(JsonMessage) + payload_size) >= shmem_size)
continue; continue;
get_free_message(this)->message_id = message->hashed_runtime_id(); get_free_message(this)->ipc_id = message->ipc_id;
get_free_message(this)->SetPayload(payload_size, output.GetString()); get_free_message(this)->SetPayload(payload_size, output.GetString());
*shared->shared_bytes_used += sizeof(JsonMessage) + get_free_message(this)->payload_size; *shared->shared_bytes_used += sizeof(JsonMessage) + get_free_message(this)->payload_size;
assert(*shared->shared_bytes_used < shmem_size); assert(*shared->shared_bytes_used < shmem_size);
get_free_message(this)->message_id = -1; get_free_message(this)->ipc_id = IpcId::Invalid;
break; break;
} }
} }
std::vector<std::unique_ptr<BaseIpcMessageElided>> IpcDirectionalChannel::TakeMessages() { std::vector<std::unique_ptr<IpcMessage>> IpcDirectionalChannel::TakeMessages() {
size_t remaining_bytes = 0; size_t remaining_bytes = 0;
// Move data from shared memory into a local buffer. Do this // Move data from shared memory into a local buffer. Do this
// before parsing the blocks so that other processes can begin // before parsing the blocks so that other processes can begin
@ -117,14 +109,14 @@ std::vector<std::unique_ptr<BaseIpcMessageElided>> IpcDirectionalChannel::TakeMe
memcpy(local_block, shared->shared_start, *shared->shared_bytes_used); memcpy(local_block, shared->shared_start, *shared->shared_bytes_used);
*shared->shared_bytes_used = 0; *shared->shared_bytes_used = 0;
get_free_message(this)->message_id = -1; get_free_message(this)->ipc_id = IpcId::Invalid;
} }
std::vector<std::unique_ptr<BaseIpcMessageElided>> result; std::vector<std::unique_ptr<IpcMessage>> result;
char* message = local_block; char* message = local_block;
while (remaining_bytes > 0) { while (remaining_bytes > 0) {
std::unique_ptr<BaseIpcMessageElided> base_message = IpcRegistry::instance()->Allocate(as_message(message)->message_id); std::unique_ptr<IpcMessage> base_message = IpcRegistry::instance()->Allocate(as_message(message)->ipc_id);
rapidjson::Document document; rapidjson::Document document;
document.Parse(as_message(message)->payload(), as_message(message)->payload_size); document.Parse(as_message(message)->payload(), as_message(message)->payload_size);
@ -147,7 +139,7 @@ std::vector<std::unique_ptr<BaseIpcMessageElided>> IpcDirectionalChannel::TakeMe
IpcServer::IpcServer(const std::string& name) IpcServer::IpcServer(const std::string& name)
: name_(name), server_(NameToServerName(name)) {} : name_(name), server_(NameToServerName(name)) {}
void IpcServer::SendToClient(int client_id, BaseIpcMessageElided* message) { void IpcServer::SendToClient(int client_id, IpcMessage* message) {
// Find or create the client. // Find or create the client.
auto it = clients_.find(client_id); auto it = clients_.find(client_id);
if (it == clients_.end()) if (it == clients_.end())
@ -156,17 +148,17 @@ void IpcServer::SendToClient(int client_id, BaseIpcMessageElided* message) {
clients_[client_id]->PushMessage(message); clients_[client_id]->PushMessage(message);
} }
std::vector<std::unique_ptr<BaseIpcMessageElided>> IpcServer::TakeMessages() { std::vector<std::unique_ptr<IpcMessage>> IpcServer::TakeMessages() {
return server_.TakeMessages(); return server_.TakeMessages();
} }
IpcClient::IpcClient(const std::string& name, int client_id) IpcClient::IpcClient(const std::string& name, int client_id)
: server_(NameToServerName(name)), client_(NameToClientName(name, client_id)) {} : server_(NameToServerName(name)), client_(NameToClientName(name, client_id)) {}
void IpcClient::SendToServer(BaseIpcMessageElided* message) { void IpcClient::SendToServer(IpcMessage* message) {
server_.PushMessage(message); server_.PushMessage(message);
} }
std::vector<std::unique_ptr<BaseIpcMessageElided>> IpcClient::TakeMessages() { std::vector<std::unique_ptr<IpcMessage>> IpcClient::TakeMessages() {
return client_.TakeMessages(); return client_.TakeMessages();
} }

126
ipc.h
View File

@ -15,63 +15,46 @@
// TODO: We need to add support for payloads larger than the maximum shared memory buffer size. // TODO: We need to add support for payloads larger than the maximum shared memory buffer size.
using IpcMessageId = std::string; enum class IpcId : int {
// Invalid request id.
Invalid = 0,
struct BaseIpcMessageElided { Quit = 1,
virtual IpcMessageId runtime_id() const = 0; IsAlive,
virtual int hashed_runtime_id() const = 0; OpenProject,
virtual void Serialize(Writer& writer); // This is a language server request. The actual request method
virtual void Deserialize(Reader& reader); // id is embedded within the request state.
LanguageServerRequest,
// TODO: remove
DocumentSymbolsRequest,
DocumentSymbolsResponse,
WorkspaceSymbolsRequest,
WorkspaceSymbolsResponse
}; };
// Usage:
//
// class IpcMessage_Foo : public BaseIpcMessage<IpcMessage_Foo> {
// static IpcMessageId kId;
//
// // BaseIpcMessage:
// ...
// }
// IpcMessageId IpcMessage_Foo::kId = "Foo";
//
//
// main() {
// IpcRegistry::instance()->Register<IpcMessage_Foo>();
// }
//
// Note: This is a template so that the statics are stored separately
// per type.
template<typename T>
struct BaseIpcMessage : BaseIpcMessageElided {
BaseIpcMessage();
virtual ~BaseIpcMessage();
// Populated by IpcRegistry::RegisterAllocator. struct IpcMessage {
static IpcMessageId runtime_id_; IpcMessage(IpcId ipc_id) : ipc_id(ipc_id) {}
static int hashed_runtime_id_; virtual ~IpcMessage() {}
// BaseIpcMessageElided: const IpcId ipc_id;
IpcMessageId runtime_id() const override {
return runtime_id_; virtual void Serialize(Writer& writer) = 0;
} virtual void Deserialize(Reader& reader) = 0;
int hashed_runtime_id() const override {
return hashed_runtime_id_;
}
}; };
struct IpcRegistry { struct IpcRegistry {
using Allocator = std::function<BaseIpcMessageElided*()>; using Allocator = std::function<IpcMessage*()>;
// Use unique_ptrs so we can initialize on first use // Use unique_ptrs so we can initialize on first use
// (static init order might not be right). // (static init order might not be right).
std::unique_ptr<std::unordered_map<int, Allocator>> allocators; std::unique_ptr<std::unordered_map<IpcId, Allocator>> allocators;
std::unique_ptr<std::unordered_map<int, std::string>> hash_to_id;
template<typename T> template<typename T>
void Register(); void Register(IpcId id);
std::unique_ptr<BaseIpcMessageElided> Allocate(int id); std::unique_ptr<IpcMessage> Allocate(IpcId id);
static IpcRegistry* instance() { static IpcRegistry* instance() {
if (!instance_) if (!instance_)
@ -82,25 +65,16 @@ struct IpcRegistry {
}; };
template<typename T> template<typename T>
void IpcRegistry::Register() { void IpcRegistry::Register(IpcId id) {
if (!allocators) { if (!allocators)
allocators = MakeUnique<std::unordered_map<int, Allocator>>(); allocators = MakeUnique<std::unordered_map<IpcId, Allocator>>();
hash_to_id = MakeUnique<std::unordered_map<int, std::string>>();
}
IpcMessageId id = T::kId; assert(allocators->find(id) == allocators->end() &&
"There is already an IPC message with the given id");
int hash = std::hash<IpcMessageId>()(id); (*allocators)[id] = [id]() {
auto it = allocators->find(hash);
assert(allocators->find(hash) == allocators->end() && "There is already an IPC message with the given id");
(*hash_to_id)[hash] = id;
(*allocators)[hash] = []() {
return new T(); return new T();
}; };
T::runtime_id_ = id;
T::hashed_runtime_id_ = hash;
} }
@ -115,8 +89,8 @@ struct IpcDirectionalChannel {
explicit IpcDirectionalChannel(const std::string& name); explicit IpcDirectionalChannel(const std::string& name);
~IpcDirectionalChannel(); ~IpcDirectionalChannel();
void PushMessage(BaseIpcMessageElided* message); void PushMessage(IpcMessage* message);
std::vector<std::unique_ptr<BaseIpcMessageElided>> TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> TakeMessages();
// Pointer to process shared memory and process shared mutex. // Pointer to process shared memory and process shared mutex.
std::unique_ptr<PlatformSharedMemory> shared; std::unique_ptr<PlatformSharedMemory> shared;
@ -129,8 +103,8 @@ struct IpcDirectionalChannel {
struct IpcServer { struct IpcServer {
IpcServer(const std::string& name); IpcServer(const std::string& name);
void SendToClient(int client_id, BaseIpcMessageElided* message); void SendToClient(int client_id, IpcMessage* message);
std::vector<std::unique_ptr<BaseIpcMessageElided>> TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> TakeMessages();
private: private:
std::string name_; std::string name_;
@ -141,36 +115,12 @@ private:
struct IpcClient { struct IpcClient {
IpcClient(const std::string& name, int client_id); IpcClient(const std::string& name, int client_id);
void SendToServer(BaseIpcMessageElided* message); void SendToServer(IpcMessage* message);
std::vector<std::unique_ptr<BaseIpcMessageElided>> TakeMessages(); std::vector<std::unique_ptr<IpcMessage>> TakeMessages();
IpcDirectionalChannel* client() { return &client_; } IpcDirectionalChannel* client() { return &client_; }
private: private:
IpcDirectionalChannel server_; IpcDirectionalChannel server_;
IpcDirectionalChannel client_; IpcDirectionalChannel client_;
}; };
template<typename T>
BaseIpcMessage<T>::BaseIpcMessage() {
assert(!runtime_id_.empty() && "Message is not registered using IpcRegistry::RegisterAllocator");
}
template<typename T>
BaseIpcMessage<T>::~BaseIpcMessage() {}
template<typename T>
IpcMessageId BaseIpcMessage<T>::runtime_id_;
template<typename T>
int BaseIpcMessage<T>::hashed_runtime_id_ = -1;

View File

@ -18,6 +18,21 @@ using std::experimental::nullopt;
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
enum class lsMethodId : int {
// Language server specific requests.
CancelRequest = 0,
Initialize,
Initialized,
TextDocumentDocumentSymbol,
WorkspaceSymbol,
};
template<typename TVisitor>
void Reflect(TVisitor& visitor, lsMethodId& value) {
int value0 = static_cast<int>(value);
Reflect(visitor, value0);
value = static_cast<lsMethodId>(value0);
}
struct RequestId { struct RequestId {
optional<int> id0; optional<int> id0;
@ -25,11 +40,12 @@ struct RequestId {
}; };
void Reflect(Writer& visitor, RequestId& value) { void Reflect(Writer& visitor, RequestId& value) {
assert(value.id0.has_value() || value.id1.has_value());
if (value.id0) { if (value.id0) {
Reflect(visitor, value.id0.value()); Reflect(visitor, value.id0.value());
} }
else { else {
assert(value.id1.has_value());
Reflect(visitor, value.id1.value()); Reflect(visitor, value.id1.value());
} }
} }
@ -214,14 +230,6 @@ struct OutNotificationMessage : public OutMessage {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
enum class lsMethodId {
CancelRequest,
Initialize,
Initialized,
TextDocumentDocumentSymbol,
WorkspaceSymbol,
};
const char* MethodIdToString(lsMethodId id) { const char* MethodIdToString(lsMethodId id) {
switch (id) { switch (id) {
case lsMethodId::CancelRequest: case lsMethodId::CancelRequest:
@ -303,18 +311,18 @@ struct InMessage {
const lsMethodId method_id; const lsMethodId method_id;
optional<RequestId> id; optional<RequestId> id;
InMessage(lsMethodId method_id, optional<RequestId> id, Reader& reader) InMessage(lsMethodId method_id, optional<RequestId> id, Reader& reader)
// We verify there are no duplicate hashes inside of MessageRegistry. // We verify there are no duplicate hashes inside of MessageRegistry.
: method_id(method_id), id(id) {} : method_id(method_id), id(id) {}
}; };
struct InRequestMessage : public InMessage { struct InRequestMessage : public InMessage {
InRequestMessage(lsMethodId method, optional<RequestId> id, Reader& reader) InRequestMessage(lsMethodId method, optional<RequestId> id, Reader& reader)
: InMessage(method, id, reader) {} : InMessage(method, id, reader) {}
}; };
struct InNotificationMessage : public InMessage { struct InNotificationMessage : public InMessage {
InNotificationMessage(lsMethodId method, optional<RequestId> id, Reader& reader) InNotificationMessage(lsMethodId method, optional<RequestId> id, Reader& reader)
: InMessage(method, id, reader) {} : InMessage(method, id, reader) {}
}; };