ccls/src/messages/ccls_random.cc

149 lines
4.1 KiB
C++
Raw Normal View History

2018-02-11 09:13:32 +00:00
#include "message_handler.h"
#include "query_utils.h"
#include "queue_manager.h"
#include <math.h>
#include <stdlib.h>
#include <numeric>
// Makes SymbolIdx hashable via the project's MAKE_HASHABLE macro so it can key
// the std::unordered_map<SymbolIdx, int> built in the handler below.
// NOTE(review): presumably combines the listed fields (kind, id) into the
// hash — confirm against the macro definition.
MAKE_HASHABLE(SymbolIdx, t.kind, t.id);
2018-02-11 09:13:32 +00:00
namespace {
2018-03-31 03:16:33 +00:00
MethodType kMethodType = "$ccls/random";
2018-03-31 03:16:33 +00:00
struct In_CclsRandom : public RequestInMessage {
MethodType GetMethodType() const override { return kMethodType; }
2018-02-11 09:13:32 +00:00
};
2018-03-31 03:16:33 +00:00
MAKE_REFLECT_STRUCT(In_CclsRandom, id);
REGISTER_IN_MESSAGE(In_CclsRandom);
2018-02-11 09:13:32 +00:00
// Edge weight contributed by a declaration reference; ordinary uses weigh 1.
constexpr double kDeclWeight = 3;
// Damping term of the ranking: each node receives this base score every
// iteration, and normalized outgoing edge weights are scaled by
// (1 - kDamping).
constexpr double kDamping = 0.1;
template <typename Q>
struct Kind;
template <>
2018-02-22 07:34:32 +00:00
struct Kind<QueryFunc> {
static constexpr SymbolKind value = SymbolKind::Func;
};
2018-02-11 09:13:32 +00:00
template <>
2018-02-22 07:34:32 +00:00
struct Kind<QueryType> {
static constexpr SymbolKind value = SymbolKind::Type;
};
2018-02-11 09:13:32 +00:00
template <>
2018-02-22 07:34:32 +00:00
struct Kind<QueryVar> {
static constexpr SymbolKind value = SymbolKind::Var;
};
2018-02-11 09:13:32 +00:00
// For every id in `ids` that was assigned a dense index in `sym2id`, adds
// weight `w` to the edge from that symbol to node `n` in `adj`. Ids without
// an entry in `sym2id` are skipped.
template <typename Q>
void Add(const std::unordered_map<SymbolIdx, int>& sym2id,
         std::vector<std::unordered_map<int, double>>& adj,
         const std::vector<Id<Q>>& ids,
         int n,
         double w = 1) {
  for (const Id<Q>& target : ids) {
    const auto found = sym2id.find(SymbolIdx{target, Kind<Q>::value});
    if (found == sym2id.end())
      continue;
    adj[found->second][n] += w;
  }
}
2018-03-31 03:16:33 +00:00
struct Handler_CclsRandom : BaseMessageHandler<In_CclsRandom> {
MethodType GetMethodType() const override { return kMethodType; }
2018-03-31 03:16:33 +00:00
void Run(In_CclsRandom* request) override {
2018-02-11 09:13:32 +00:00
std::unordered_map<SymbolIdx, int> sym2id;
std::vector<SymbolIdx> syms;
int n = 0;
for (RawId i = 0; i < db->funcs.size(); i++)
if (db->funcs[i].AnyDef()) {
syms.push_back(SymbolIdx{Id<void>(i), SymbolKind::Func});
2018-02-11 09:13:32 +00:00
sym2id[syms.back()] = n++;
}
for (RawId i = 0; i < db->types.size(); i++)
if (db->types[i].AnyDef()) {
syms.push_back(SymbolIdx{Id<void>(i), SymbolKind::Type});
2018-02-11 09:13:32 +00:00
sym2id[syms.back()] = n++;
}
for (RawId i = 0; i < db->vars.size(); i++)
if (db->vars[i].AnyDef()) {
syms.push_back(SymbolIdx{Id<void>(i), SymbolKind::Var});
2018-02-11 09:13:32 +00:00
sym2id[syms.back()] = n++;
}
std::vector<std::unordered_map<int, double>> adj(n);
auto add = [&](const std::vector<Use>& uses, double w) {
for (Use use : uses) {
auto it = sym2id.find(use);
2018-02-11 09:13:32 +00:00
if (it != sym2id.end())
adj[it->second][n] += w;
}
};
n = 0;
for (QueryFunc& func : db->funcs)
if (func.AnyDef()) {
2018-02-11 09:13:32 +00:00
add(func.declarations, kDeclWeight);
add(func.uses, 1);
Add(sym2id, adj, func.derived, n);
n++;
}
for (QueryType& type : db->types)
if (const auto* def = type.AnyDef()) {
2018-02-11 09:13:32 +00:00
add(type.uses, 1);
Add(sym2id, adj, type.instances, n);
Add(sym2id, adj, def->funcs, n);
Add(sym2id, adj, def->types, n);
Add(sym2id, adj, def->vars, n);
2018-02-11 09:13:32 +00:00
n++;
}
for (QueryVar& var : db->vars)
if (var.AnyDef()) {
2018-02-11 09:13:32 +00:00
add(var.declarations, kDeclWeight);
add(var.uses, 1);
n++;
}
for (int i = 0; i < n; i++) {
double sum = 0;
adj[i][i] += 1;
for (auto& it : adj[i])
sum += it.second;
for (auto& it : adj[i])
it.second = it.second / sum * (1 - kDamping);
}
std::vector<double> x(n, 1), y;
for (int j = 0; j < 8; j++) {
2018-02-11 09:13:32 +00:00
y.assign(n, kDamping);
for (int i = 0; i < n; i++)
for (auto& it : adj[i])
y[it.first] += x[i] * it.second;
double d = 0;
for (int i = 0; i < n; i++)
d = std::max(d, fabs(x[i] - y[i]));
2018-02-22 07:34:32 +00:00
if (d < 1e-5)
break;
2018-02-11 09:13:32 +00:00
x.swap(y);
}
double sum = std::accumulate(x.begin(), x.end(), 0.);
2018-02-11 09:13:32 +00:00
Out_LocationList out;
out.id = request->id;
double roulette = rand() / (RAND_MAX + 1.0) * sum;
sum = 0;
for (int i = 0; i < n; i++) {
sum += x[i];
if (sum >= roulette) {
2018-04-04 06:05:41 +00:00
if (Maybe<Use> use = GetDefinitionExtent(db, syms[i]))
if (auto ls_loc = GetLsLocationEx(db, working_files, *use,
g_config->xref.container))
out.result.push_back(*ls_loc);
2018-02-11 09:13:32 +00:00
break;
}
}
QueueManager::WriteStdout(kMethodType, out);
2018-02-11 09:13:32 +00:00
}
};
2018-03-31 03:16:33 +00:00
// Registers Handler_CclsRandom with the project's message-handler registry
// (REGISTER_MESSAGE_HANDLER macro) so "$ccls/random" requests reach it.
REGISTER_MESSAGE_HANDLER(Handler_CclsRandom);
2018-02-11 09:13:32 +00:00
} // namespace