2019 Multi-University Training Contest 1 - 1012 - NTT

题目连接:
http://acm.hdu.edu.cn/showproblem.php?pid=6589

题解连接:
https://www.cnblogs.com/xusirui/p/11229450.html
https://www.cnblogs.com/FST-stay-night/p/11227505.html

NTT来自:
https://www.cnblogs.com/Sakits/p/8416918.html

题解1说,要先暴力模拟看看规律。

一个很明显的直觉是可以看看最终的序列由哪些a[i]贡献而成。但是我连暴力都不会写啊。

还不如手推:
当x为1时:
观察求前缀和的过程

0次: a[0], a[1], a[2], a[3]

1次: a[0], a[1]+a[0], a[2]+a[1]+a[0], a[3]+a[2]+a[1]+a[0]

2次: a[0], a[1]+2a[0], a[2]+2a[1]+3a[0], a[3]+2a[2]+3a[1]+4a[0]

3次: a[0], a[1]+3a[0], a[2]+3a[1]+6a[0], a[3]+3a[2]+6a[1]+10a[0]

4次: a[0], a[1]+4a[0], a[2]+4a[1]+10a[0], a[3]+4a[2]+10a[1]+20a[0]

每一项前面的系数看起来有什么规律?
0次的时候就跳过吧……
1次的时候,各个都是1?其实是 C(i,0) 。
2次的时候,是从1开始递增的。其实是 C(i+1,1) 。
3次的时候,第i项的系数看起来像 C(i+2,2) 。
4次的时候,第i项的系数看起来像 C(i+3,3) 。

所以第m次时候,系数应该是c[i]=C(m-1+i,m-1)。

m次: c[0]a[0], c[0]a[1]+c[1]a[0], c[0]a[2]+c[1]a[1]+c[2]a[0], c[0]a[3]+c[1]a[2]+c[2]a[1]+c[3]a[0]

那么其实就是数组:

a[0],a[1],a[2],a[3],a[4]...

c[0],c[1],c[2],c[3],c[4]...

做卷积的结果。

所以就预处理组合数一波,然后直接NTT。

然后其实x=2和x=3是对几个数组分开求这个前缀和。

标程给出一个更方便的做法。直接跳着赋值,例如在x=2的时候,赋值c'[0]=c[0],c'[1]=0,c'[2]=c[1],c'[3]=0,c'[4]=c[2],c'[5]=0

那么直接卷积就是:

m次: c[0]a[0], c[0]a[1], c[0]a[2]+c[1]a[0], c[0]a[3]+c[1]a[1], c[0]a[4]+c[1]a[2]+c[2]a[0]

从标程瞎改的快一倍的AC代码。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MAXN = 2e6, mod = 998244353;

inline int pow_mod(ll x, int n) {
    ll res;
    for(res = 1; n; n >>= 1, x = x * x % mod)
        if(n & 1)
            res = res * x % mod;
    return res;
}

inline int add_mod(int x, int y) {
    x += y;
    return x >= mod ? x - mod : x;
}

inline int sub_mod(int x, int y) {
    x -= y;
    return x < 0 ? x + mod : x;
}

void NTT(int a[], int n, int op) {
    for(int i = 1, j = n >> 1; i < n - 1; ++i) {
        if(i < j)
            swap(a[i], a[j]);
        int k = n >> 1;
        while(k <= j) {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
    for(int len = 2; len <= n; len <<= 1) {
        int g = pow_mod(3, (mod - 1) / len);
        for(int i = 0; i < n; i += len) {
            int w = 1;
            for(int j = i; j < i + (len >> 1); ++j) {
                int u = a[j], t = 1ll * a[j + (len >> 1)] * w % mod;
                a[j] = add_mod(u, t), a[j + (len >> 1)] = sub_mod(u, t);
                w = 1ll * w * g % mod;
            }
        }
    }
    if(op == -1) {
        reverse(a + 1, a + n);
        int inv = pow_mod(n, mod - 2);
        for(int i = 0; i < n; ++i)
            a[i] = 1ll * a[i] * inv % mod;
    }
}

int A[MAXN + 5], B[MAXN + 5];
int Asize, Bsize;

int pow2(int x) {
    int res = 1;
    while(res < x)
        res <<= 1;
    return res;
}

void convolution(int A[], int B[], int Asize, int Bsize) {
    int n = pow2(Asize + Bsize - 1);
    for(int i = Asize; i < n; ++i)
        A[i] = 0;
    for(int i = Bsize; i < n; ++i)
        B[i] = 0;
    NTT(A, n, 1);
    NTT(B, n, 1);
    for(int i = 0; i < n; ++i)
        A[i] = 1ll * A[i] * B[i] % mod;
    NTT(A, n, -1);
    return;
}

const int MAXM = 2e6;

int fact[MAXM + 5], ifact[MAXM + 5];

int C(int n, int m) {
    return m <= n ? (ll)fact[n] * ifact[m] % mod * ifact[n - m] % mod : 0;
}

void init_C() {
    fact[0] = 1;
    for(int i = 1; i <= MAXM; ++i)
        fact[i] = 1ll * fact[i - 1] * i % mod;
    ifact[MAXM] = pow_mod(fact[MAXM], mod - 2);
    for(int i = MAXM - 1; i >= 0; --i)
        ifact[i] = 1ll * ifact[i + 1] * (i + 1) % mod;
}

int main() {
#ifdef Yinku
    freopen("Yinku.in", "r", stdin);
#endif // Yinku
    init_C();
    int T;
    scanf("%d", &T);
    while(T--) {
        int n, m;
        scanf("%d%d", &n, &m);
        for(int i = 0; i < n; ++i) {
            scanf("%d", &A[i]);
        }
        int cnt[] = {0, 0, 0, 0};
        for(int i = 1; i <= m; ++i) {
            int x;
            scanf("%d", &x);
            cnt[x]++;
        }
        for(int c = 1; c <= 3; ++c) {
            if(cnt[c]) {
                memset(B, 0, sizeof(B[0])*n);
                for(int i = 0; i * c < n; ++i) {
                    B[i * c] = C(cnt[c] - 1 + i, i);
                }
                convolution(A, B, n, n);
            }
        }
        ll ans = 0;
        for(int i = 0; i < n; ++i)
            ans ^= 1ll * (i + 1) * A[i];
        printf("%lld\n", ans);
    }
    return 0;
}

posted @ 2019-07-23 15:44  韵意  阅读(195)  评论(0编辑  收藏  举报