1. 程式人生 > >5877 Weak Pair (離散化+dfs+樹狀陣列)

5877 Weak Pair (離散化+dfs+樹狀陣列)

Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 5327    Accepted Submission(s): 1543


 

Problem Description

You are given a rooted tree of N nodes, labeled from 1 to N . To the i th node a non-negative value ai is assigned.An ordered pair of nodes (u,v) is said to be weak if
  (1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
  (2) au×av≤k .

Can you find the number of weak pairs in the tree?

 

Input

There are multiple cases in the data set.
  The first line of input contains an integer T denoting number of test cases.
  For each case, the first line contains two space-separated integers, N and k , respectively.
  The second line contains N space-separated integers, denoting a1 to aN .
  Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes u and v , where node u is the parent of node v .

  Constrains:
  
  1≤N≤105
  
  0≤ai≤109
  
  0≤k≤1018

 

Output

For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.

 

Sample Input

1 2 3 1 2 1 2

 

Sample Output

1

        大體題意就是找到多找個(u,v)得組合使value[u]*value[v]<=k;其中u是v得祖先; 

        首先我們可以想到,求這種組合方式我們可以通過dfs,查詢到某點時,看看對於這個v點而言能讓組合成立的(u,v)到底有多少種。這樣的話,我們可以想到用樹狀陣列,查詢價值小於一個數(k/value[v])的祖先點到底有幾個。

        這樣的話又出現了一個問題,就是資料範圍太大了,於是我們可以先把每個點v的value以及能使他們相乘<=k的值need[v]預處理出來,然後進行離散化就可以了。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int maxn = 200005;
const int inf = 0x3f3f3f3f;
#define ll long long int 
int n, in[maxn], c[maxn];
ll k, ans, v[maxn], need[maxn], sot[maxn];
vector<int>G[maxn];
void init() {
	for (int s = 1; s <= n; s++)
		G[s].clear();
	memset(c, 0, sizeof(c));
	memset(in, 0, sizeof(in));
	ans = 0;
}
int lowbit(int x) {
	return x&(-x);
}
int sum(int x) {
	int ans = 0;
	while (x) {
		ans += c[x];
		x -= lowbit(x);
	}
	return ans;
}
void add(int x, int d) {
	while (x <= 2 * n) {
		c[x] += d;
		x += lowbit(x);
	}
}
void dfs(int x) {
	ans += sum(need[x]);
	add(v[x], 1);
	int sz = G[x].size();
	for (int s = 0; s < sz; s++)
		dfs(G[x][s]);
	add(v[x], -1);
}
int main() {
	int te;
	scanf("%d", &te);
	while (te--) {
		scanf("%d%lld", &n, &k);
		init();
		for (int s = 1; s <= n; s++) {
			scanf("%lld", &v[s]);
			if (v[s] == 0)
				need[s] = 1e18;
			else 
				need[s] = k / v[s];
			sot[s - 1] = v[s];
			sot[n + s - 2] = need[s];
		}
		sort(sot, sot + 2 * n);
		for (int s = 1; s <= n; s++) {
			v[s] = lower_bound(sot, sot + 2 * n, v[s]) - sot + 1;
			need[s]= lower_bound(sot, sot + 2 * n, need[s]) - sot + 1;
		}
		for (int s = 1; s < n; s++) {
			int a, b;
			scanf("%d%d", &a, &b);
			G[a].push_back(b);
			in[b] = 1;
		}
		for (int s = 1; s <= n; s++) {
			if (!in[s]) {
				dfs(s);
				break;
			}
		}
		printf("%lld\n", ans);
	}
}