「LOJ6073」「2017 山東一輪集訓 Day5」距離-主席樹+樹鏈剖分
阿新 • • 發佈:2018-12-31
Description
給你一棵
個點的樹和一個排列
,邊有邊權,記
表示
到
的距離,
表示
到
路徑上所有點組成的集合,現在有
次詢問,每次給出
,
,
,問:
,強制線上。時間限制
,空間限制
。
Solution
首先可以把路徑轉化為兩者到根再相減。
然後考慮維護 。首先可以把 轉化為兩者深度減去 深度的兩倍。而 的深度為兩者鏈交的長度。所以用樹剖+主席樹維護,每次新增一個點 時,把 到根的路徑加 。查詢時查詢點 到 的權值和即可。
其實就是一個維護一個點集與任意一個點的 的深度之和的套路。
#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
const int maxn = 200005;
int n, q, type, p[maxn];
struct edge
{
int to, next, w;
} e[maxn * 2];
int h[maxn], tot, top[maxn], fa[maxn], dfn[maxn], ord[maxn], Time, siz[maxn], dep[maxn], w[maxn], son[maxn];
lint dis[maxn], pre_d[maxn];
int rt[maxn], cnt, lch[maxn * 60], rch[maxn * 60];
lint lzy[maxn * 60], sum[maxn * 60], sum_c[maxn * 60];
inline int gi()
{
char c = getchar();
while (c < '0' || c > '9') c = getchar();
int sum = 0;
while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
return sum;
}
inline void add(int u, int v, int w)
{
e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
}
void dfs1(int u)
{
siz[u] = 1;
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
if (v != fa[u]) {
fa[v] = u; dep[v] = dep[u] + 1; w[v] = e[i].w; dis[v] = dis[u] + e[i].w;
dfs1(v);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs2(int u)
{
ord[dfn[u] = ++Time] = u;
if (son[u]) top[son[u]] = top[u], dfs2(son[u]);
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
if (v != fa[u] && v != son[u]) top[v] = v, dfs2(v);
}
int lca(int u, int v)
{
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
else v = fa[top[v]];
}
return dep[u] < dep[v] ? u : v;
}
#define mid ((l + r) >> 1)
void build(int &s, int l, int r)
{
s = ++cnt;
if (l == r) return sum_c[s] = w[ord[l]], void();
build(lch[s], l, mid);
build(rch[s], mid + 1, r);
sum_c[s] = sum_c[lch[s]] + sum_c[rch[s]];
}
void insert(int &s, int l, int r, int x, int y)
{
++cnt;
sum[cnt] = sum[s]; sum_c[cnt] = sum_c[s]; lzy[cnt] = lzy[s];
lch[cnt] = lch[s]; rch[cnt] = rch[s];
s = cnt;
if (x <= l && r <= y) return ++lzy[s], sum[s] += sum_c[s], void();
if (x <= mid) insert(lch[s], l, mid, x, y);
if (y >= mid + 1) insert(rch[s], mid + 1, r, x, y);
sum[s] = sum[lch[s]] + sum[rch[s]] + lzy[s] * sum_c[s];
}
pair<lint, lint> operator + (const pair<lint, lint> &a, const pair<lint, lint> &b)
{
return make_pair(a.first + b.first, a.second + b.second);
}
pair<lint, lint> query(int &s, int l, int r, int x, int y)
{
if (x <= l && r <= y) return make_pair(sum[s], sum_c[s]);
pair<lint, lint> res = make_pair(0, 0);
if (x <= mid) res = res + query(lch[s], l, mid, x, y);
if (y >= mid + 1) res = res + query(rch[s], mid + 1, r, x, y);
res.first += res.second * lzy[s];
return res;
}
void insert(int u)
{
int k = u;
rt[k] = rt[fa[u]]; u = p[u];
while (u) {
insert(rt[k], 1, n, dfn[top[u]], dfn[u]);
u = fa[top[u]];
}
}
lint query(int u, int k)
{
if (!k) return 0;
lint res = pre_d[k] + (dep[k] + 1) * dis[u];
while (u) {
res -= query(rt[k], 1, n, dfn[top[u]], dfn[u]).first << 1;
u = fa[top[u]];
}
return res;
}
void dfs3(int u)
{
pre_d[u] = pre_d[fa[u]] + dis[p[u]];
insert(u);
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
if (v != fa[u]) dfs3(v);
}
int main()
{
type = gi();
n = gi(); q = gi();
for (int i = 1, u, v, w; i < n; ++i) u = gi(), v = gi(), w = gi(), add(u, v, w);
for (int i = 1; i <= n; ++i) p[i] = gi(<