diff --git a/src/language_server_api.cc b/src/language_server_api.cc index dc9c4357..00c88a23 100644 --- a/src/language_server_api.cc +++ b/src/language_server_api.cc @@ -155,7 +155,13 @@ std::unique_ptr MessageRegistry::Parse(Reader& visitor) { } Allocator& allocator = allocators[method]; - return allocator(visitor); + // FIXME Print error message for deserialization error + try { + return allocator(visitor); + } catch (std::invalid_argument& e) { + LOG_S(ERROR) << "Unable to deserialize request '" << method << "'"; + return nullptr; + } } MessageRegistry* MessageRegistry::instance() { diff --git a/src/messages/initialize.cc b/src/messages/initialize.cc index 611d5ec6..e829d527 100644 --- a/src/messages/initialize.cc +++ b/src/messages/initialize.cc @@ -10,7 +10,7 @@ #include -#include +#include // TODO Cleanup global variables extern std::string g_init_options; @@ -71,7 +71,18 @@ struct InitializeHandler : BaseMessageHandler { reader.Parse(g_init_options.c_str()); if (!reader.HasParseError()) { JsonReader json_reader{&reader}; - Reflect(json_reader, *config); + try { + Reflect(json_reader, *config); + } catch (std::invalid_argument& ex) { + // FIXME This is not triggered. Need to pass error from + // MessageRegistry::Parse in language_server_api.cc + Out_ShowLogMessage out; + out.display_type = Out_ShowLogMessage::DisplayType::Show; + out.params.type = lsMessageType::Error; + out.params.message = "Failed to deserialize " + + json_reader.GetPath() + " " + ex.what(); + out.Write(std::cout); + } } } g_enable_comments = config->enableComments; diff --git a/src/serializer.cc b/src/serializer.cc index 25b75154..99ec54b6 100644 --- a/src/serializer.cc +++ b/src/serializer.cc @@ -8,6 +8,8 @@ #include #include +#include + namespace { bool gTestOutputMode = false; } // namespace @@ -15,6 +17,7 @@ bool gTestOutputMode = false; //// Elementary types void Reflect(Reader& visitor, uint8_t& value) { + if (!visitor.IsInt()) throw std::invalid_argument("uint8_t"); value = (uint8_t)visitor.GetInt(); } void Reflect(Writer& visitor, uint8_t& value) { @@ -22,6 +25,7 @@ void Reflect(Writer& visitor, uint8_t& value) { } void Reflect(Reader& visitor, int16_t& value) { + if (!visitor.IsInt()) throw std::invalid_argument("int16_t"); value = (int16_t)visitor.GetInt(); } void Reflect(Writer& visitor, int16_t& value) { @@ -29,6 +33,7 @@ void Reflect(Writer& visitor, int16_t& value) { } void Reflect(Reader& visitor, int& value) { + if (!visitor.IsInt()) throw std::invalid_argument("int"); value = visitor.GetInt(); } void Reflect(Writer& visitor, int& value) { @@ -36,6 +41,7 @@ void Reflect(Writer& visitor, int& value) { } void Reflect(Reader& visitor, unsigned& value) { + if (!visitor.IsInt()) throw std::invalid_argument("unsigned"); value = visitor.GetUint32(); } void Reflect(Writer& visitor, unsigned& value) { @@ -43,6 +49,7 @@ void Reflect(Writer& visitor, unsigned& value) { } void Reflect(Reader& visitor, long& value) { + if (!visitor.IsInt64()) throw std::invalid_argument("long"); value = long(visitor.GetInt64()); } void Reflect(Writer& visitor, long& value) { @@ -50,6 +57,7 @@ void Reflect(Writer& visitor, long& value) { } void Reflect(Reader& visitor, unsigned long& value) { + if (!visitor.IsUint64()) throw std::invalid_argument("unsigned long"); value = (unsigned long)visitor.GetUint64(); } void Reflect(Writer& visitor, unsigned long& value) { @@ -57,6 +65,7 @@ void Reflect(Writer& visitor, unsigned long& value) { } void Reflect(Reader& visitor, long long& value) { + if (!visitor.IsInt64()) throw std::invalid_argument("long long"); value = visitor.GetInt64(); } void Reflect(Writer& visitor, long long& value) { @@ -64,6 +73,7 @@ void Reflect(Writer& visitor, long long& value) { } void Reflect(Reader& visitor, unsigned long long& value) { + if (!visitor.IsUint64()) throw std::invalid_argument("unsigned long long"); value = visitor.GetUint64(); } void Reflect(Writer& visitor, unsigned long long& value) { @@ -71,6 +81,7 @@ void Reflect(Writer& visitor, unsigned long long& value) { } void Reflect(Reader& visitor, double& value) { + if (!visitor.IsDouble()) throw std::invalid_argument("double"); value = visitor.GetDouble(); } void Reflect(Writer& visitor, double& value) { @@ -78,6 +89,7 @@ void Reflect(Writer& visitor, double& value) { } void Reflect(Reader& visitor, bool& value) { + if (!visitor.IsBool()) throw std::invalid_argument("bool"); value = visitor.GetBool(); } void Reflect(Writer& visitor, bool& value) { @@ -86,8 +98,8 @@ void Reflect(Writer& visitor, bool& value) { // std::string void Reflect(Reader& visitor, std::string& value) { - if (visitor.IsString()) - value = visitor.GetString(); + if (!visitor.IsString()) throw std::invalid_argument("std::string"); + value = visitor.GetString(); } void Reflect(Writer& visitor, std::string& value) { visitor.String(value.c_str(), (rapidjson::SizeType)value.size()); @@ -289,7 +301,14 @@ std::unique_ptr Deserialize(SerializeFormat format, file = MakeUnique(path, nullopt); JsonReader json_reader{&reader}; - Reflect(json_reader, *file); + try { + Reflect(json_reader, *file); + } catch (std::invalid_argument& e) { + LOG_S(ERROR) << "'" << path << "': failed to deserialize " + << json_reader.GetPath() << "." + << e.what(); + return nullptr; + } break; } diff --git a/src/serializer.h b/src/serializer.h index 10aedea5..c71c6ae4 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -19,12 +19,13 @@ class Reader { virtual ~Reader() {} virtual SerializeFormat Format() const = 0; - // virtual bool IsBool() = 0; + virtual bool IsBool() = 0; virtual bool IsNull() = 0; virtual bool IsArray() = 0; virtual bool IsInt() = 0; virtual bool IsInt64() = 0; - // virtual bool IsUint64() = 0; + virtual bool IsUint64() = 0; + virtual bool IsDouble() = 0; virtual bool IsString() = 0; virtual void GetNull() = 0; diff --git a/src/serializers/json.h b/src/serializers/json.h index 67cb53b0..7e766a15 100644 --- a/src/serializers/json.h +++ b/src/serializers/json.h @@ -7,17 +7,19 @@ class JsonReader : public Reader { rapidjson::GenericValue>* m_; + std::vector path_; public: JsonReader(rapidjson::GenericValue>* m) : m_(m) {} SerializeFormat Format() const override { return SerializeFormat::Json; } - // bool IsBool() override { return m_->IsBool(); } + bool IsBool() override { return m_->IsBool(); } bool IsNull() override { return m_->IsNull(); } bool IsArray() override { return m_->IsArray(); } bool IsInt() override { return m_->IsInt(); } bool IsInt64() override { return m_->IsInt64(); } - // bool IsUint64() override { return m_->IsUint64(); } + bool IsUint64() override { return m_->IsUint64(); } + bool IsDouble() override { return m_->IsDouble(); } bool IsString() override { return m_->IsString(); } void GetNull() override {} @@ -36,22 +38,35 @@ class JsonReader : public Reader { } void IterArray(std::function fn) override { + if (!m_->IsArray()) + throw std::invalid_argument("array"); + // Use "0" to indicate any element for now. + path_.push_back("0"); for (auto& entry : m_->GetArray()) { JsonReader sub(&entry); fn(sub); } + path_.pop_back(); } void DoMember(const char* name, std::function fn) override { - if (m_->GetType() != rapidjson::Type::kObjectType) - return; // FIXME: signal an error that object was not deserialized - // correctly? - + path_.push_back(name); auto it = m_->FindMember(name); if (it != m_->MemberEnd()) { JsonReader sub(&it->value); fn(sub); } + path_.pop_back(); + } + + std::string GetPath() const { + std::string ret; + for (auto& t : path_) { + ret += '/'; + ret += t; + } + ret.pop_back(); + return ret; } }; diff --git a/src/serializers/msgpack.h b/src/serializers/msgpack.h index 91b08667..f9b1666b 100644 --- a/src/serializers/msgpack.h +++ b/src/serializers/msgpack.h @@ -21,6 +21,7 @@ class MessagePackReader : public Reader { return SerializeFormat::MessagePack; } + bool IsBool() override { return oh_.get().type == msgpack::type::BOOLEAN; } bool IsNull() override { return oh_.get().is_nil(); } bool IsArray() override { return oh_.get().type == msgpack::type::ARRAY; } bool IsInt() override { @@ -28,6 +29,8 @@ class MessagePackReader : public Reader { oh_.get().type == msgpack::type::NEGATIVE_INTEGER; } bool IsInt64() override { return IsInt(); } + bool IsUint64() override { return IsInt(); } + bool IsDouble() override { return oh_.get().type == msgpack::type::FLOAT64; }; bool IsString() override { return oh_.get().type == msgpack::type::STR; } void GetNull() override { pk_->next(oh_); }