diff --git a/src/command_line.cc b/src/command_line.cc index dfe9f78c..a395f66f 100644 --- a/src/command_line.cc +++ b/src/command_line.cc @@ -336,6 +336,7 @@ void LaunchStdinLoop(Config* config, case IpcId::CqueryBase: case IpcId::CqueryDerived: case IpcId::CqueryIndexFile: + case IpcId::CqueryRandom: case IpcId::CqueryWait: { queue->for_querydb.PushBack(std::move(message)); break; diff --git a/src/ipc.cc b/src/ipc.cc index 495ad307..01029af8 100644 --- a/src/ipc.cc +++ b/src/ipc.cc @@ -86,6 +86,8 @@ const char* IpcIdToString(IpcId id) { return "$cquery/base"; case IpcId::CqueryDerived: return "$cquery/derived"; + case IpcId::CqueryRandom: + return "$cquery/random"; case IpcId::Unknown: return "$unknown"; diff --git a/src/ipc.h b/src/ipc.h index aaab0055..59d820ec 100644 --- a/src/ipc.h +++ b/src/ipc.h @@ -57,6 +57,7 @@ enum class IpcId : int { CqueryCallers, // Show all callers of a function. CqueryBase, // Show base types/method. CqueryDerived, // Show all derived types/methods. + CqueryRandom, // Show random definition. // Internal implementation detail. 
Unknown,
diff --git a/src/messages/cquery_random.cc b/src/messages/cquery_random.cc
new file mode 100644
index 00000000..029bfd1c
--- /dev/null
+++ b/src/messages/cquery_random.cc
@@ -0,0 +1,183 @@
+#include "message_handler.h"
+#include "query_utils.h"
+#include "queue_manager.h"
+
+#include <math.h>
+#include <stdlib.h>
+#include <numeric>
+
+namespace {
+// Request for `$cquery/random`; only the request `id` is reflected.
+struct Ipc_CqueryRandom : public RequestMessage<Ipc_CqueryRandom> {
+  const static IpcId kIpcId = IpcId::CqueryRandom;
+};
+MAKE_REFLECT_STRUCT(Ipc_CqueryRandom, id);
+REGISTER_IPC_MESSAGE(Ipc_CqueryRandom);
+
+// Edge weight added for each declaration of a function or variable.
+const double kDeclWeight = 3;
+// Rank handed to every symbol on each iteration regardless of its edges; the
+// outgoing edge weights of a symbol are scaled to sum to 1 - kDamping.
+const double kDamping = 0.1;
+// Multiplier applied to the standard deviation of the ranks before the ranks
+// are divided by it and exponentiated into selection weights.
+const double kAlpha = 0.3;
+
+// Maps a query entity type to the SymbolKind used in SymbolIdx keys.
+template <typename Q>
+struct Kind;
+template <>
+struct Kind<QueryFunc> { static constexpr SymbolKind value = SymbolKind::Func; };
+template <>
+struct Kind<QueryType> { static constexpr SymbolKind value = SymbolKind::Type; };
+template <>
+struct Kind<QueryVar> { static constexpr SymbolKind value = SymbolKind::Var; };
+
+// Adds weight |w| to the edge from every symbol in |ids| to node |n|. Ids
+// without an entry in |sym2id| are skipped.
+template <typename Q>
+void Add(const std::unordered_map<SymbolIdx, int>& sym2id,
+         std::vector<std::unordered_map<int, double>>& adj,
+         const std::vector<Id<Q>>& ids,
+         int n,
+         double w = 1) {
+  for (Id<Q> id : ids) {
+    auto it = sym2id.find(SymbolIdx{RawId(id), Kind<Q>::value});
+    if (it != sym2id.end())
+      adj[it->second][n] += w;
+  }
+}
+
+// Handles `$cquery/random`: ranks every symbol that has a definition over a
+// weighted graph built from declarations, uses, derived functions, type
+// instances and type members, then replies with the definition location of
+// one symbol drawn at random, weighted exponentially by rank.
+struct CqueryRandomHandler : BaseMessageHandler<Ipc_CqueryRandom> {
+  void Run(Ipc_CqueryRandom* request) override {
+    // Give every func/type/var that has a definition a dense node index.
+    std::unordered_map<SymbolIdx, int> sym2id;
+    std::vector<SymbolIdx> syms;
+    int n = 0;
+
+    for (RawId i = 0; i < db->funcs.size(); i++)
+      if (db->funcs[i].def) {
+        syms.push_back(SymbolIdx{i, SymbolKind::Func});
+        sym2id[syms.back()] = n++;
+      }
+    for (RawId i = 0; i < db->types.size(); i++)
+      if (db->types[i].def) {
+        syms.push_back(SymbolIdx{i, SymbolKind::Type});
+        sym2id[syms.back()] = n++;
+      }
+    for (RawId i = 0; i < db->vars.size(); i++)
+      if (db->vars[i].def) {
+        syms.push_back(SymbolIdx{i, SymbolKind::Var});
+        sym2id[syms.back()] = n++;
+      }
+
+    // adj[a][b] is the weight of the edge a -> b. |add| captures |n| by
+    // reference, so it targets whichever node the loops below are visiting.
+    std::vector<std::unordered_map<int, double>> adj(n);
+    auto add = [&](const std::vector<Use>& uses, double w) {
+      for (Use use : uses) {
+        auto it = sym2id.find(SymbolIdx{RawId(use.id), use.kind});
+        if (it != sym2id.end())
+          adj[it->second][n] += w;
+      }
+    };
+    // Revisit the symbols in the order used above so |n| equals the index
+    // each one was assigned.
+    n = 0;
+    for (QueryFunc& func : db->funcs)
+      if (func.def) {
+        add(func.declarations, kDeclWeight);
+        add(func.uses, 1);
+        Add(sym2id, adj, func.derived, n);
+        n++;
+      }
+    for (QueryType& type : db->types)
+      if (type.def) {
+        add(type.uses, 1);
+        Add(sym2id, adj, type.instances, n);
+        Add(sym2id, adj, type.def->funcs, n);
+        Add(sym2id, adj, type.def->types, n);
+        Add(sym2id, adj, type.def->vars, n);
+        n++;
+      }
+    for (QueryVar& var : db->vars)
+      if (var.def) {
+        add(var.declarations, kDeclWeight);
+        add(var.uses, 1);
+        n++;
+      }
+    // Add a self edge to every node (so no row sums to zero), then scale each
+    // node's outgoing weights to sum to 1 - kDamping.
+    for (int i = 0; i < n; i++) {
+      double sum = 0;
+      adj[i][i] += 1;
+      for (auto& it : adj[i])
+        sum += it.second;
+      for (auto& it : adj[i])
+        it.second = it.second / sum * (1 - kDamping);
+    }
+
+    // Propagate ranks along the edges until no rank changes by 1e-5 or more.
+    std::vector<double> x(n, 1), y;
+    while (1) {
+      y.assign(n, kDamping);
+      for (int i = 0; i < n; i++)
+        for (auto& it : adj[i])
+          y[it.first] += x[i] * it.second;
+      double d = 0;
+      for (int i = 0; i < n; i++)
+        d = std::max(d, fabs(x[i] - y[i]));
+      if (d < 1e-5) break;
+      x.swap(y);
+    }
+
+    // Turn ranks into selection weights exp(x / stdev - max(x / stdev)).
+    const double mean = n ? std::accumulate(x.begin(), x.end(), 0.) / n : 0;
+    double stdev = 0;
+    for (int i = 0; i < n; i++)
+      stdev += (x[i] - mean) * (x[i] - mean);
+    stdev = n ? sqrt(stdev / n) * kAlpha : 0;
+    double sum = 0;
+    if (stdev > 0) {
+      double offset = 0;
+      for (int i = 0; i < n; i++)
+        offset = std::max(offset, x[i] / stdev);
+      for (int i = 0; i < n; i++)
+        sum += x[i] = exp(x[i] / stdev - offset);
+    } else {
+      // Every rank is identical (e.g. a single symbol), so dividing by stdev
+      // would produce NaN weights and nothing could ever be selected. Fall
+      // back to a uniform draw.
+      x.assign(n, 1);
+      sum = n;
+    }
+
+    Out_LocationList out;
+    out.id = request->id;
+    // Roulette-wheel selection: pick the first symbol whose cumulative weight
+    // reaches a uniform draw from [0, sum).
+    double roulette = rand() / (RAND_MAX + 1.0) * sum;
+    sum = 0;
+    for (int i = 0; i < n; i++) {
+      sum += x[i];
+      if (sum >= roulette) {
+        Maybe<Use> use = GetDefinitionExtentOfSymbol(db, syms[i]);
+        // No definition extent: fall through to the next symbol, which also
+        // satisfies the cumulative-weight test.
+        if (!use)
+          continue;
+        optional<lsLocation> ls_loc = GetLsLocation(db, working_files, *use);
+        if (ls_loc)
+          out.result.push_back(*ls_loc);
+        break;
+      }
+    }
+    QueueManager::WriteStdout(IpcId::CqueryRandom, out);
+  }
+};
+REGISTER_MESSAGE_HANDLER(CqueryRandomHandler);
+}  // namespace
diff --git a/src/ntstring.h b/src/ntstring.h
index 3aa16687..0e6703de 100644
---
a/src/ntstring.h +++ b/src/ntstring.h @@ -35,13 +35,6 @@ class NTString { const char *p = strstr(c_str(), s); return p ? std::string::size_type(p - c_str()) : std::string::npos; } - size_type find(const char* s, size_type pos, size_type cnt) { - auto* p = (const char*)memmem(c_str() + pos, strlen(c_str()) - pos, s, cnt); - return p ? std::string::size_type(p - c_str()) : std::string::npos; - } - size_type find(std::string_view sv) { - return find(sv.data(), 0, sv.size()); - } void operator=(std::string_view sv) { str = std::unique_ptr(new char[sv.size() + 1]);