【bzoj4836】[Lydsy2017年4月月赛]二元运算 分治+FFT
题目描述
定义二元运算 opt 满足
现在给定一个长为 n 的数列 a 和一个长为 m 的数列 b ,接下来有 q 次询问。每次询问给定一个数字 c
你需要求出有多少对 (i, j) 使得 a_i opt b_j=c 。
输入
第一行是一个整数 T (1≤T≤10) ,表示测试数据的组数。
对于每组测试数据:
第一行是三个整数 n,m,q (1≤n,m,q≤50000) 。
第二行是 n 个整数,表示 a_1,a_2,?,a_n (0≤a_1,a_2,?,a_n≤50000) 。
第三行是 m 个整数,表示 b_1,b_2,?,b_m (0≤b_1,b_2,?,b_m≤50000) 。
第四行是 q 个整数,第 i 个整数 c_i (0≤c_i≤100000) 表示第 i 次查询的数。
输出
对于每次查询,输出一行,包含一个整数,表示满足条件的 (i, j) 对的个数。
样例输入
2
2 1 5
1 3
2
1 2 3 4 5
2 2 5
1 3
2 4
1 2 3 4 5
样例输出
1
0
1
0
0
1
0
1
0
1
题解
分治+FFT
如果只有第一种运算就是裸的FFT求卷积;只有第二种运算可以把B序列翻转,然后求卷积即可。
但是有x与y大小关系的限制使得我们不能直接求卷积来得出答案。
考虑分治,对于每个区间$[l,r]$,处理出A中的$[l,mid]$与B中的$[mid+1,r]$对答案的贡献以及A中的$[mid+1,r]$与B中的$[l,mid]$对答案的贡献,这两个是有严格的x与y的大小关系的,分别使用FFT求卷积解决。再递归处理子区间即可。
时间复杂度$O(Tn\log^2n)$。一开始len开了正常的2倍导致无限TLE,QAQ
#include <cmath> #include <cstdio> #include <cstring> #include <algorithm> #define N 140010 #define rint register int using namespace std; typedef long long ll; const double pi = acos(-1); struct data { double x , y; data() {} data(double x0 , double y0) {x = x0 , y = y0;} data operator+(const data &a)const {return data(x + a.x , y + a.y);} data operator-(const data &a)const {return data(x - a.x , y - a.y);} data operator*(const data &a)const {return data(x * a.x - y * a.y , x * a.y + y * a.x);} }ta[N] , tb[N]; ll a[N] , b[N] , c[N]; inline int read() { static int ret; static char ch = getchar(); ret = 0; while(ch < '0' || ch > '9') ch = getchar(); while(ch >= '0' && ch <= '9') ret = ret * 10 + ch - '0' , ch = getchar(); return ret; } void fft(data *a , int n , int flag) { rint i , j , k; for(i = k = 0 ; i < n ; i ++ ) { if(i > k) swap(a[i] , a[k]); for(j = n >> 1 ; (k ^= j) < j ; j >>= 1); } for(k = 2 ; k <= n ; k <<= 1) { data wn(cos(2 * pi * flag / k) , sin(2 * pi * flag / k)); for(i = 0 ; i < n ; i += k) { data w(1 , 0) , t; for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn) t = w * a[j + (k >> 1)] , a[j + (k >> 1)] = a[j] - t , a[j] = a[j] + t; } } if(flag == -1) for(i = 0 ; i < n ; i ++ ) a[i].x /= n; } void work(ll *a , ll *b , int n , bool flag) { rint i; for(i = 0 ; i < n ; i ++ ) ta[i].x = a[i] , ta[i].y = ta[i + n].x = ta[i + n].y = 0; for(i = 0 ; i < n ; i ++ ) tb[i].x = (flag ? b[n - 1 - i] : b[i]) , tb[i].y = tb[i + n].x = tb[i + n].y = 0; fft(ta , n << 1 , 1) , fft(tb , n << 1 , 1); for(i = 0 ; i < n << 1 ; i ++ ) ta[i] = ta[i] * tb[i]; fft(ta , n << 1 , -1); } void solve(int l , int r) { if(l == r) { c[0] += a[l] * b[l]; return; } int mid = (l + r) >> 1 , n = r - l + 1; rint i; work(a + l , b + mid + 1 , n >> 1 , 0); for(i = 0 ; i < n ; i ++ ) c[i + l + mid + 1] += (ll)(ta[i].x + 0.5); work(a + mid + 1 , b + l , n >> 1 , 1); for(i = 0 ; i < n ; i ++ ) c[i + 1] += (ll)(ta[i].x + 0.5); solve(l , mid) , solve(mid + 1 , r); } int main() { int T; T = read(); while(T -- ) { memset(a , 0 , sizeof(a)) , memset(b , 0 , sizeof(b)) , memset(c , 0 , sizeof(c)); int n , m , q , x , k = 0 , len; rint i; n = read() , m = read() , q = read(); for(i = 1 ; i <= n ; i ++ ) x = read() , a[x] ++ , k = max(k , x); for(i = 1 ; i <= m ; i ++ ) x = read() , b[x] ++ , k = max(k , x); for(len = 1 ; len <= k ; len <<= 1); solve(0 , len - 1); while(q -- ) printf("%lld\n" , c[read()]); } return 0; }