【bzoj4836】[Lydsy2017年4月月赛]二元运算 分治+FFT

题目描述

定义二元运算 opt 满足
 
现在给定一个长为 n 的数列 a 和一个长为 m 的数列 b ,接下来有 q 次询问。每次询问给定一个数字 c 
你需要求出有多少对 (i, j) 使得 a_i  opt b_j=c 。

输入

第一行是一个整数 T (1≤T≤10) ,表示测试数据的组数。
对于每组测试数据:
第一行是三个整数 n,m,q (1≤n,m,q≤50000) 。
第二行是 n 个整数,表示 a_1,a_2,?,a_n (0≤a_1,a_2,?,a_n≤50000) 。
第三行是 m 个整数,表示 b_1,b_2,?,b_m (0≤b_1,b_2,?,b_m≤50000) 。
第四行是 q 个整数,第 i 个整数 c_i (0≤c_i≤100000) 表示第 i 次查询的数。

输出

对于每次查询,输出一行,包含一个整数,表示满足条件的 (i, j) 对的个数。

样例输入

2
2 1 5
1 3
2
1 2 3 4 5
2 2 5
1 3
2 4
1 2 3 4 5

样例输出

1
0
1
0
0
1
0
1
0
1


题解

分治+FFT

如果只有第一种运算就是裸的FFT求卷积;只有第二种运算可以把B序列翻转,然后求卷积即可。

但是有x与y大小关系的限制使得我们不能直接求卷积来得出答案。

考虑分治,对于每个区间$[l,r]$,处理出A中的$[l,mid]$与B中的$[mid+1,r]$对答案的贡献以及A中的$[mid+1,r]$与B中的$[l,mid]$对答案的贡献,这两个是有严格的x与y的大小关系的,分别使用FFT求卷积解决。再递归处理子区间即可。

时间复杂度$O(Tn\log^2n)$。一开始len开了正常的2倍导致无限TLE,QAQ

#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 140010
#define rint register int
using namespace std;
typedef long long ll;
const double pi = acos(-1);
struct data
{
	double x , y;
	data() {}
	data(double x0 , double y0) {x = x0 , y = y0;}
	data operator+(const data &a)const {return data(x + a.x , y + a.y);}
	data operator-(const data &a)const {return data(x - a.x , y - a.y);}
	data operator*(const data &a)const {return data(x * a.x - y * a.y , x * a.y + y * a.x);}
}ta[N] , tb[N];
ll a[N] , b[N] , c[N];
inline int read()
{
	static int ret; static char ch = getchar();
	ret = 0;
	while(ch < '0' || ch > '9') ch = getchar();
	while(ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0' , ch = getchar();
	return ret;
}
void fft(data *a , int n , int flag)
{
	rint i , j , k;
	for(i = k = 0 ; i < n ; i ++ )
	{
		if(i > k) swap(a[i] , a[k]);
		for(j = n >> 1 ; (k ^= j) < j ; j >>= 1);
	}
	for(k = 2 ; k <= n ; k <<= 1)
	{
		data wn(cos(2 * pi * flag / k) , sin(2 * pi * flag / k));
		for(i = 0 ; i < n ; i += k)
		{
			data w(1 , 0) , t;
			for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn)
				t = w * a[j + (k >> 1)] , a[j + (k >> 1)] = a[j] - t , a[j] = a[j] + t;
		}
	}
	if(flag == -1) for(i = 0 ; i < n ; i ++ ) a[i].x /= n;
}
void work(ll *a , ll *b , int n , bool flag)
{
	rint i;
	for(i = 0 ; i < n ; i ++ ) ta[i].x = a[i] , ta[i].y = ta[i + n].x = ta[i + n].y = 0;
	for(i = 0 ; i < n ; i ++ ) tb[i].x = (flag ? b[n - 1 - i] : b[i]) , tb[i].y = tb[i + n].x = tb[i + n].y = 0;
	fft(ta , n << 1 , 1) , fft(tb , n << 1 , 1);
	for(i = 0 ; i < n << 1 ; i ++ ) ta[i] = ta[i] * tb[i];
	fft(ta , n << 1 , -1); 
}
void solve(int l , int r)
{
	if(l == r)
	{
		c[0] += a[l] * b[l];
		return;
	}
	int mid = (l + r) >> 1 , n = r - l + 1;
	rint i;
	work(a + l , b + mid + 1 , n >> 1 , 0);
	for(i = 0 ; i < n ; i ++ ) c[i + l + mid + 1] += (ll)(ta[i].x + 0.5);
	work(a + mid + 1 , b + l , n >> 1 , 1);
	for(i = 0 ; i < n ; i ++ ) c[i + 1] += (ll)(ta[i].x + 0.5);
	solve(l , mid) , solve(mid + 1 , r);
}
int main()
{
	int T;
	T = read();
	while(T -- )
	{
		memset(a , 0 , sizeof(a)) , memset(b , 0 , sizeof(b)) , memset(c , 0 , sizeof(c));
		int n , m , q , x , k = 0 , len;
		rint i;
		n = read() , m = read() , q = read();
		for(i = 1 ; i <= n ; i ++ ) x = read() , a[x] ++ , k = max(k , x);
		for(i = 1 ; i <= m ; i ++ ) x = read() , b[x] ++ , k = max(k , x);
		for(len = 1 ; len <= k ; len <<= 1);
		solve(0 , len - 1);
		while(q -- ) printf("%lld\n" , c[read()]);
	}
	return 0;
}

 

 

posted @ 2017-08-26 10:02  GXZlegend  阅读(383)  评论(0编辑  收藏  举报