1. 程式人生 > >「LOJ6073」「2017 山東一輪集訓 Day5」距離-主席樹+樹鏈剖分

「LOJ6073」「2017 山東一輪集訓 Day5」距離-主席樹+樹鏈剖分

Description

給你一棵 n n 個點的樹和一個排列 p p ,邊有邊權,記 d i

s t ( u , v ) dist(u, v) 表示 u
u
v v 的距離, p a t h (
u , v ) path(u, v)
表示 u u v v 路徑上所有點組成的集合,現在有 q q 次詢問,每次給出 u i u_i , v i v_i , k i k_i ,問:
i p a t h d i s t ( p i , k ) \sum_{i\in path} dist(p_i,k)
n , q 2 × 1 0 5 n, q \leq 2 × 10^5 ,強制線上。時間限制 4 s 4s ,空間限制 1 G B 1GB

Solution

首先可以把路徑轉化為兩者到根再相減。

然後考慮維護 i p a t h ( u , r o o t ) d i s ( i , k ) \sum_{i \in path(u,root)}dis(i,k) 。首先可以把 d i s dis 轉化為兩者深度減去 l c a lca 深度的兩倍。而 l c a lca 的深度為兩者鏈交的長度。所以用樹剖+主席樹維護,每次新增一個點 u u 時,把 p u p_u 到根的路徑加 1 1 。查詢時查詢點 k k r o o t root 的權值和即可。

其實就是一個維護一個點集與任意一個點的 l c a lca 的深度之和的套路。

#include <bits/stdc++.h>
using namespace std;

typedef long long lint;
const int maxn = 200005;

int n, q, type, p[maxn];

struct edge
{
	int to, next, w;
} e[maxn * 2];
int h[maxn], tot, top[maxn], fa[maxn], dfn[maxn], ord[maxn], Time, siz[maxn], dep[maxn], w[maxn], son[maxn];
lint dis[maxn], pre_d[maxn];

int rt[maxn], cnt, lch[maxn * 60], rch[maxn * 60];
lint lzy[maxn * 60], sum[maxn * 60], sum_c[maxn * 60];

inline int gi()
{
	char c = getchar();
	while (c < '0' || c > '9') c = getchar();
	int sum = 0;
	while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
	return sum;
}

inline void add(int u, int v, int w)
{
	e[++tot] = (edge) {v, h[u], w}; h[u] = tot;
	e[++tot] = (edge) {u, h[v], w}; h[v] = tot;
}

void dfs1(int u)
{
	siz[u] = 1;
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u]) {
			fa[v] = u; dep[v] = dep[u] + 1; w[v] = e[i].w; dis[v] = dis[u] + e[i].w;
			dfs1(v);
			siz[u] += siz[v];
			if (siz[v] > siz[son[u]]) son[u] = v;
		}
}

void dfs2(int u)
{
	ord[dfn[u] = ++Time] = u;
	if (son[u]) top[son[u]] = top[u], dfs2(son[u]);
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u] && v != son[u]) top[v] = v, dfs2(v);
}

int lca(int u, int v)
{
	while (top[u] != top[v]) {
		if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
		else v = fa[top[v]]; 
	}
	return dep[u] < dep[v] ? u : v;
}

#define mid ((l + r) >> 1)

void build(int &s, int l, int r)
{
	s = ++cnt;
	if (l == r) return sum_c[s] = w[ord[l]], void();
	build(lch[s], l, mid);
	build(rch[s], mid + 1, r);
	sum_c[s] = sum_c[lch[s]] + sum_c[rch[s]];
}

void insert(int &s, int l, int r, int x, int y)
{
	++cnt;
	sum[cnt] = sum[s]; sum_c[cnt] = sum_c[s]; lzy[cnt] = lzy[s];
	lch[cnt] = lch[s]; rch[cnt] = rch[s];
	s = cnt;

	if (x <= l && r <= y) return ++lzy[s], sum[s] += sum_c[s], void();
	if (x <= mid) insert(lch[s], l, mid, x, y);
	if (y >= mid + 1) insert(rch[s], mid + 1, r, x, y);

	sum[s] = sum[lch[s]] + sum[rch[s]] + lzy[s] * sum_c[s];
}

pair<lint, lint> operator + (const pair<lint, lint> &a, const pair<lint, lint> &b)
{
	return make_pair(a.first + b.first, a.second + b.second);
}

pair<lint, lint> query(int &s, int l, int r, int x, int y)
{
	if (x <= l && r <= y) return make_pair(sum[s], sum_c[s]);
	pair<lint, lint> res = make_pair(0, 0);
	if (x <= mid) res = res + query(lch[s], l, mid, x, y);
	if (y >= mid + 1) res = res + query(rch[s], mid + 1, r, x, y);
	res.first += res.second * lzy[s];
	return res;
}

void insert(int u)
{
	int k = u;
	rt[k] = rt[fa[u]]; u = p[u];
	while (u) {
		insert(rt[k], 1, n, dfn[top[u]], dfn[u]);
		u = fa[top[u]];
	}
}

lint query(int u, int k)
{
	if (!k) return 0;
	lint res = pre_d[k] + (dep[k] + 1) * dis[u];
	while (u) {
		res -= query(rt[k], 1, n, dfn[top[u]], dfn[u]).first << 1;
		u = fa[top[u]];
	}
	return res;
}

void dfs3(int u)
{
	pre_d[u] = pre_d[fa[u]] + dis[p[u]];
	insert(u);
	for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
		if (v != fa[u]) dfs3(v);
}

int main()
{
	type = gi();
	n = gi(); q = gi();
	for (int i = 1, u, v, w; i < n; ++i) u = gi(), v = gi(), w = gi(), add(u, v, w);
	for (int i = 1; i <= n; ++i) p[i] = gi(<