5877 Weak Pair (離散化+dfs+樹狀陣列)
Weak PairTime Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)Total Submission(s): 5327 Accepted Submission(s): 1543 Problem Description You are given a rooted tree of N nodes, labeled from 1 to N . To the i th node a non-negative value ai is assigned.An ordered pair of nodes (u,v) is said to be weak if Input There are multiple cases in the data set. Output For each test case, print a single integer on a single line denoting the number of weak pairs in the tree. Sample Input 1 2 3 1 2 1 2 Sample Output 1 |
大體題意就是找到多找個(u,v)得組合使value[u]*value[v]<=k;其中u是v得祖先;
首先我們可以想到,求這種組合方式我們可以通過dfs,查詢到某點時,看看對於這個v點而言能讓組合成立的(u,v)到底有多少種。這樣的話,我們可以想到用樹狀陣列,查詢價值小於一個數(k/value[v])的祖先點到底有幾個。
這樣的話又出現了一個問題,就是資料範圍太大了,於是我們可以先把每個點v的value以及能使他們相乘<=k的值need[v]預處理出來,然後進行離散化就可以了。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int maxn = 200005;
const int inf = 0x3f3f3f3f;
#define ll long long int
int n, in[maxn], c[maxn];
ll k, ans, v[maxn], need[maxn], sot[maxn];
vector<int>G[maxn];
void init() {
for (int s = 1; s <= n; s++)
G[s].clear();
memset(c, 0, sizeof(c));
memset(in, 0, sizeof(in));
ans = 0;
}
int lowbit(int x) {
return x&(-x);
}
int sum(int x) {
int ans = 0;
while (x) {
ans += c[x];
x -= lowbit(x);
}
return ans;
}
void add(int x, int d) {
while (x <= 2 * n) {
c[x] += d;
x += lowbit(x);
}
}
void dfs(int x) {
ans += sum(need[x]);
add(v[x], 1);
int sz = G[x].size();
for (int s = 0; s < sz; s++)
dfs(G[x][s]);
add(v[x], -1);
}
int main() {
int te;
scanf("%d", &te);
while (te--) {
scanf("%d%lld", &n, &k);
init();
for (int s = 1; s <= n; s++) {
scanf("%lld", &v[s]);
if (v[s] == 0)
need[s] = 1e18;
else
need[s] = k / v[s];
sot[s - 1] = v[s];
sot[n + s - 2] = need[s];
}
sort(sot, sot + 2 * n);
for (int s = 1; s <= n; s++) {
v[s] = lower_bound(sot, sot + 2 * n, v[s]) - sot + 1;
need[s]= lower_bound(sot, sot + 2 * n, need[s]) - sot + 1;
}
for (int s = 1; s < n; s++) {
int a, b;
scanf("%d%d", &a, &b);
G[a].push_back(b);
in[b] = 1;
}
for (int s = 1; s <= n; s++) {
if (!in[s]) {
dfs(s);
break;
}
}
printf("%lld\n", ans);
}
}