diff --git a/src/command_line.cc b/src/command_line.cc index 0dc0b43f..5f35d42e 100644 --- a/src/command_line.cc +++ b/src/command_line.cc @@ -135,6 +135,7 @@ std::unique_ptr BuildIpcMessageQueue(const std::string& name, s RegisterId(ipc.get()); RegisterId(ipc.get()); RegisterId(ipc.get()); + RegisterId(ipc.get()); RegisterId(ipc.get()); RegisterId(ipc.get()); RegisterId(ipc.get()); @@ -154,6 +155,7 @@ void RegisterMessageTypes() { MessageRegistry::instance()->Register(); MessageRegistry::instance()->Register(); MessageRegistry::instance()->Register(); + MessageRegistry::instance()->Register(); MessageRegistry::instance()->Register(); MessageRegistry::instance()->Register(); MessageRegistry::instance()->Register(); @@ -344,6 +346,36 @@ void AddCodeLens(std::vector* result, AddCodeLens(result, loc, uses0, exclude_loc, only_interesting, singular, plural); } +optional GetDefinitionOfUsr(QueryableDatabase* db, const Usr& usr) { + SymbolIdx symbol = db->usr_to_symbol[usr]; + switch (symbol.kind) { + case SymbolKind::Type: { + QueryableTypeDef* def = &db->types[symbol.idx]; + if (def->def.definition) + return def->def.definition.value(); + break; + } + case SymbolKind::Func: { + QueryableFuncDef* def = &db->funcs[symbol.idx]; + if (def->def.definition) + return def->def.definition.value(); + break; + } + case SymbolKind::Var: { + QueryableVarDef* def = &db->vars[symbol.idx]; + if (def->def.definition) + return def->def.definition.value(); + break; + } + case SymbolKind::File: + case SymbolKind::Invalid: { + assert(false && "unexpected"); + break; + } + } + return nullopt; +} + void AddCodeLens(std::vector* result, QueryableDatabase* db, QueryableLocation loc, @@ -355,32 +387,9 @@ void AddCodeLens(std::vector* result, std::vector uses0; uses0.reserve(usrs.size()); for (const Usr& usr : usrs) { - SymbolIdx symbol = db->usr_to_symbol[usr]; - switch (symbol.kind) { - case SymbolKind::Type: { - QueryableTypeDef* def = &db->types[symbol.idx]; - if (def->def.definition) - uses0.push_back(def->def.definition.value()); - break; - } - case SymbolKind::Func: { - QueryableFuncDef* def = &db->funcs[symbol.idx]; - if (def->def.definition) - uses0.push_back(def->def.definition.value()); - break; - } - case SymbolKind::Var: { - QueryableVarDef* def = &db->vars[symbol.idx]; - if (def->def.definition) - uses0.push_back(def->def.definition.value()); - break; - } - case SymbolKind::File: - case SymbolKind::Invalid: { - assert(false && "unexpected"); - break; - } - } + optional loc = GetDefinitionOfUsr(db, usr); + if (loc) + uses0.push_back(loc.value()); } AddCodeLens(result, loc, uses0, exclude_loc, only_interesting, singular, plural); } @@ -456,6 +465,7 @@ void QueryDbMainLoop( } case IpcId::TextDocumentCompletion: { + // TODO: better performance auto msg = static_cast(message.get()); Out_TextDocumentComplete response; response.id = msg->id; @@ -469,6 +479,48 @@ void QueryDbMainLoop( break; } + case IpcId::TextDocumentDefinition: { + auto msg = static_cast(message.get()); + + QueryableFile* file = FindFile(db, msg->params.textDocument.uri.GetPath()); + if (!file) { + std::cerr << "Unable to find file " << msg->params.textDocument.uri.GetPath() << std::endl; + break; + } + + Out_TextDocumentDefinition response; + response.id = msg->id; + + // TODO: Edge cases (whitespace, etc) will work a lot better + // if we store range information instead of hacking it. + int target_line = msg->params.position.line + 1; + int target_column = msg->params.position.character + 1; + int best_dist = INT_MAX; + for (const UsrRef& ref : file->all_symbols) { + if (ref.loc.line == target_line) { + if (ref.loc.column > target_column) + continue; + + int dist = target_column - ref.loc.column; + if (dist < best_dist) { + optional location = GetDefinitionOfUsr(db, ref.usr); + + if (location) { + best_dist = dist; + response.result.clear(); + response.result.push_back(GetLsLocation(location.value())); + } + + if (dist == 0) + break; + } + } + } + + SendOutMessageToClient(language_client, response); + break; + } + case IpcId::TextDocumentDocumentSymbol: { auto msg = static_cast(message.get()); @@ -575,9 +627,9 @@ void QueryDbMainLoop( // false /*exclude_loc*/, false /*only_interesting*/, "reference", // "references"); AddCodeLens(&response.result, ref.loc.OffsetColumn(1), def.callers, - false /*exclude_loc*/, false /*only_interesting*/, "caller", "callers"); - AddCodeLens(&response.result, ref.loc.OffsetColumn(2), def.def.callees, - false /*exclude_loc*/, false /*only_interesting*/, "callee", "callees"); + true /*exclude_loc*/, false /*only_interesting*/, "caller", "callers"); + //AddCodeLens(&response.result, ref.loc.OffsetColumn(2), def.def.callees, + // false /*exclude_loc*/, false /*only_interesting*/, "callee", "callees"); AddCodeLens(&response.result, db, ref.loc.OffsetColumn(3), def.derived, false /*exclude_loc*/, false /*only_interesting*/, "derived", "derived"); break; @@ -793,6 +845,8 @@ void LanguageServerStdinLoop(IpcMessageQueue* ipc) { response.result.capabilities.codeLensProvider = lsCodeLensOptions(); response.result.capabilities.codeLensProvider->resolveProvider = false; + response.result.capabilities.definitionProvider = true; + response.result.capabilities.documentSymbolProvider = true; response.result.capabilities.workspaceSymbolProvider = true; @@ -816,6 +870,7 @@ void LanguageServerStdinLoop(IpcMessageQueue* ipc) { case IpcId::TextDocumentDidChange: case IpcId::TextDocumentDidClose: { case IpcId::TextDocumentCompletion: + case IpcId::TextDocumentDefinition: case IpcId::TextDocumentDocumentSymbol: case IpcId::TextDocumentCodeLens: case IpcId::WorkspaceSymbol: @@ -903,7 +958,7 @@ int main(int argc, char** argv) { //bool loop = true; //while (loop) // std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::this_thread::sleep_for(std::chrono::seconds(3)); + //std::this_thread::sleep_for(std::chrono::seconds(3)); PlatformInit(); RegisterMessageTypes(); diff --git a/src/ipc.cc b/src/ipc.cc index 9ff09b7a..1d72d8d8 100644 --- a/src/ipc.cc +++ b/src/ipc.cc @@ -18,6 +18,8 @@ const char* IpcIdToString(IpcId id) { return "textDocument/didClose"; case IpcId::TextDocumentCompletion: return "textDocument/completion"; + case IpcId::TextDocumentDefinition: + return "textDocument/definition"; case IpcId::TextDocumentDocumentSymbol: return "textDocument/documentSymbol"; case IpcId::TextDocumentCodeLens: diff --git a/src/ipc.h b/src/ipc.h index 44d66cbe..838a098c 100644 --- a/src/ipc.h +++ b/src/ipc.h @@ -14,6 +14,7 @@ enum class IpcId : int { TextDocumentDidChange, TextDocumentDidClose, TextDocumentCompletion, + TextDocumentDefinition, TextDocumentDocumentSymbol, TextDocumentCodeLens, CodeLensResolve, diff --git a/src/language_server_api.h b/src/language_server_api.h index 7ed2e813..e4a509cb 100644 --- a/src/language_server_api.h +++ b/src/language_server_api.h @@ -1103,6 +1103,19 @@ struct Out_TextDocumentComplete : public lsOutMessage }; MAKE_REFLECT_STRUCT(Out_TextDocumentComplete, jsonrpc, id, result); +// Goto definition +struct Ipc_TextDocumentDefinition : public IpcMessage { + const static IpcId kIpcId = IpcId::TextDocumentDefinition; + + lsRequestId id; + lsTextDocumentPositionParams params; +}; +MAKE_REFLECT_STRUCT(Ipc_TextDocumentDefinition, id, params); +struct Out_TextDocumentDefinition : public lsOutMessage { + lsRequestId id; + NonElidedVector result; +}; +MAKE_REFLECT_STRUCT(Out_TextDocumentDefinition, jsonrpc, id, result); // List symbols in a document. struct lsDocumentSymbolParams { diff --git a/src/query.cc b/src/query.cc index dc1fc78e..c487ca05 100644 --- a/src/query.cc +++ b/src/query.cc @@ -151,25 +151,46 @@ QueryableFile::QueryableFile(const IndexedFile& indexed) if (location.file_id() == local_file_id) outline.push_back(UsrRef(usr, MapIdToUsr(indexed.id_cache, location))); }; + auto add_all_symbols = [this, &indexed, local_file_id](Usr usr, Location location) { + if (location.file_id() == local_file_id) + all_symbols.push_back(UsrRef(usr, MapIdToUsr(indexed.id_cache, location))); + }; for (const IndexedTypeDef& def : indexed.types) { - if (def.def.definition.has_value()) + if (def.def.definition.has_value()) { add_outline(def.def.usr, def.def.definition.value()); + add_all_symbols(def.def.usr, def.def.definition.value()); + } + for (const Location& use : def.uses) + add_all_symbols(def.def.usr, use); } for (const IndexedFuncDef& def : indexed.funcs) { - for (Location decl : def.declarations) - add_outline(def.def.usr, decl); - if (def.def.definition.has_value()) + if (def.def.definition.has_value()) { add_outline(def.def.usr, def.def.definition.value()); + add_all_symbols(def.def.usr, def.def.definition.value()); + } + for (Location decl : def.declarations) { + add_outline(def.def.usr, decl); + add_all_symbols(def.def.usr, decl); + } + for (const Location& use : def.uses) + add_all_symbols(def.def.usr, use); } for (const IndexedVarDef& def : indexed.vars) { - if (def.def.definition.has_value()) + if (def.def.definition.has_value()) { add_outline(def.def.usr, def.def.definition.value()); + add_all_symbols(def.def.usr, def.def.definition.value()); + } + for (const Location& use : def.uses) + add_all_symbols(def.def.usr, use); } std::sort(outline.begin(), outline.end(), [](const UsrRef& a, const UsrRef& b) { return a.loc < b.loc; }); + std::sort(all_symbols.begin(), all_symbols.end(), [](const UsrRef& a, const UsrRef& b) { + return a.loc < b.loc; + }); } QueryableTypeDef::QueryableTypeDef(IdCache& id_cache, const IndexedTypeDef& indexed) @@ -488,6 +509,7 @@ void IndexUpdate::Merge(const IndexUpdate& update) { INDEX_UPDATE_MERGE(files_removed); INDEX_UPDATE_MERGE(files_added); INDEX_UPDATE_MERGE(files_outline); + INDEX_UPDATE_MERGE(files_all_symbols); INDEX_UPDATE_MERGE(types_removed); INDEX_UPDATE_MERGE(types_added); diff --git a/src/query.h b/src/query.h index 517d3576..879de1b2 100644 --- a/src/query.h +++ b/src/query.h @@ -95,15 +95,18 @@ void Reflect(TVisitor& visitor, MergeableUpdate& value) { struct QueryableFile { using OutlineUpdate = MergeableUpdate; + using AllSymboslUpdate = MergeableUpdate; Usr file_id; - // Outline of the file (ie, all symbols). + // Outline of the file (ie, for code lens). std::vector outline; + // Every symbol found in the file (ie, for goto definition) + std::vector all_symbols; QueryableFile() {} // For serialization. QueryableFile(const IndexedFile& indexed); }; -MAKE_REFLECT_STRUCT(QueryableFile, file_id, outline); +MAKE_REFLECT_STRUCT(QueryableFile, file_id, outline, all_symbols); struct QueryableTypeDef { using DefUpdate = TypeDefDefinitionData; @@ -186,6 +189,7 @@ struct IndexUpdate { std::vector files_removed; std::vector files_added; std::vector files_outline; + std::vector files_all_symbols; // Type updates. std::vector types_removed; @@ -227,6 +231,7 @@ MAKE_REFLECT_STRUCT(IndexUpdate, files_removed, files_added, files_outline, + files_all_symbols, types_removed, types_added, types_def_changed,