diff --git a/src/query_utils.cc b/src/query_utils.cc index 0423f03f..d75fde0d 100644 --- a/src/query_utils.cc +++ b/src/query_utils.cc @@ -283,6 +283,39 @@ optional GetBaseDefinitionOrDeclarationSpelling(QueryDatabase* db return def; } +bool HasCallersOnSelfOrBaseOrDerived(QueryDatabase* db, QueryFunc& root) { + // Check self. + if (!root.callers.empty()) + return true; + + // Check for base calls. + optional func_id = root.def.base; + while (func_id) { + optional& func = db->funcs[func_id->id]; + if (!func) + break; + if (!func->callers.empty()) + return true; + func_id = func->def.base; + } + + // Check for derived calls. + std::queue queue; + PushRange(&queue, root.derived); + while (!queue.empty()) { + optional& func = db->funcs[queue.front().id]; + queue.pop(); + if (!func) + continue; + + if (!func->derived.empty()) + return true; + PushRange(&queue, func->derived); + } + + return false; +} + std::vector GetCallersForAllBaseFunctions(QueryDatabase* db, QueryFunc& root) { std::vector callers; @@ -654,7 +687,7 @@ NonElidedVector BuildInitialCallTree(QueryDatabas entry.name = root_func->def.short_name; entry.usr = root_func->def.usr; entry.location = *def_loc; - entry.hasCallers = !root_func->callers.empty(); + entry.hasCallers = HasCallersOnSelfOrBaseOrDerived(db, *root_func); NonElidedVector result; result.push_back(entry); return result; @@ -709,7 +742,7 @@ NonElidedVector BuildExpandCallTree(QueryDatabase call_entry.name = prefix + call_func->def.short_name + " (" + format_location(*call_location) + ")"; call_entry.usr = call_func->def.usr; call_entry.location = *call_location; - call_entry.hasCallers = !call_func->callers.empty(); + call_entry.hasCallers = HasCallersOnSelfOrBaseOrDerived(db, *call_func); result.push_back(call_entry); } else { diff --git a/src/query_utils.h b/src/query_utils.h index 124bcffb..4feade09 100644 --- a/src/query_utils.h +++ b/src/query_utils.h @@ -21,6 +21,7 @@ std::vector ToQueryLocation(QueryDatabase* db, const std::vector< std::vector GetUsesOfSymbol(QueryDatabase* db, const SymbolIdx& symbol); std::vector GetDeclarationsOfSymbolForGotoDefinition(QueryDatabase* db, const SymbolIdx& symbol); optional GetBaseDefinitionOrDeclarationSpelling(QueryDatabase* db, QueryFunc& func); +bool HasCallersOnSelfOrBaseOrDerived(QueryDatabase* db, QueryFunc& root); std::vector GetCallersForAllBaseFunctions(QueryDatabase* db, QueryFunc& root); std::vector GetCallersForAllDerivedFunctions(QueryDatabase* db, QueryFunc& root); optional GetLsPosition(WorkingFile* working_file, const Position& position);