【洛谷3321_BZOJ3992】[SDOI2015]序列统计(原根_多项式)

题目:

洛谷3321

分析:

一个转化思路比较神(典型?)的题……

一个比较显然的\(O(n^3)\)暴力是用\(f[i][j]\)表示选了\(i\)个数,当前积在模\(m\)意义下为\(j\)的方案数,每次转移枚举\(S\)的元素,即(\(k^{-1}\)表示\(k\)在模\(m\)意义下的逆元):

\[f[i][j]=\sum_{k\in S} f[i-1][jk^{-1}] \]

事实上写的时候通常是从\(f[i][j]\)\(f[i+1][jk]\)贡献

然后通过Orz题解发现那个乘法\(jk^{-1}\)非常地丑,如果能变成加法就好了qwq。注意到保证\(m\)是一个质数,所以是可以求原根的。原根的好处在于模\(m\)的意义下\(g^i(i\in [0,m-1))\)\(a(a \in [1,m-1])\)一一对应。所以如果用原根的幂代替原数,原数的乘法就变成了指数的加法。即设\(f[i][j]\)表示选了\(i\)个数,积在模\(m\)意义下是\(g^j\)的方案数,则:

\[f[i][j]=\sum_{g^k \in S} f[i-1][j-k] \]

上面那个东西像不像卷积?一点都不像看到\(j-k\)难道你就不想再来一个跟\(k\)有关的东西吗?于是定义一个函数\(h(x)\)

\[h(x)=\begin{cases}1&(g^x\in S) \\ 0&(otherwise)\end{cases}\]

那么就成了:

\[f[i][j]=\sum_{k=0}^{m-2}h(k)f[i-1][j-k] \]

设一个多项式\(A\),第\(i\)项系数是\(h(i)\),则一开始\(f[1]\)就是\(A\),所以答案就是\(A^n\)\(x\)次项系数,写一个多项式快速幂即可。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <set>
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = x * 10 + c - '0', c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar('-'), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 8010, p = 1004535809;
    int n, m;
    set<int> s;
    namespace Polynomial
    {
        const int LEN = N << 2;
        int omega[LEN], winv[LEN], rev[LEN];
        inline int power(int a, int b, const int p = ::zyt::p)
        {
            int ans = 1;
            while (b)
            {
                if (b & 1)
                    ans = (ll)ans * a % p;
                a = (ll)a * a % p;
                b >>= 1;
            }
            return ans;
        }
        inline int inv(const int a, const int p = ::zyt::p)
        {
            return power(a, p - 2);
        }
        namespace Primitive_Root
        {
            pair<int, int>prime[N];
            int cnt;
            inline void get_prime(int n)
            {
                cnt = 0;
                for (int i = 2; i * i <= n; i++)
                {
                    if (n % i == 0)
                    {
                        prime[cnt++] = make_pair(i, 0);
                        while (n % i == 0)
                            ++prime[cnt - 1].second, n /= i;
                    }
                }
                if (n > 1)
                    prime[cnt++] = make_pair(n, 1);
            }
            inline int get_g(const int n)
            {
                get_prime(n - 1);
                for (int i = 2; i < n; i++)
                {
                    bool flag = true;
                    for (int j = 0; j < cnt && flag; j++)
                        flag &= (power(i, (n - 1) / prime[j].first, n) != 1);
                    if (flag)
                        return i;
                }
                return -1;
            }
        }
        void init(const int n, const int lg2)
        {
            static int g = 0;
            if (!g)
                g = Primitive_Root::get_g(p);
            int w = power(g, (p - 1) / n), wi = inv(w);
            omega[0] = winv[0] = 1;
            for (int i = 1; i < n; i++)
            {
                omega[i] = (ll)omega[i - 1] * w % p;
                winv[i] = (ll)winv[i - 1] * wi % p;
            }
            for (int i = 0; i < n; i++)
                rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
        }
        void ntt(int *a, const int *w, const int n)
        {
            for (int i = 0; i < n; i++)
                if (i < rev[i])
                    swap(a[i], a[rev[i]]);
            for (int l = 1; l < n; l <<= 1)
                for (int i = 0; i < n; i += (l << 1))
                    for (int k = 0; k < l; k++)
                    {
                        int tmp = (a[i + k] - (ll)a[i + l + k] * w[n / (l << 1) * k] % p + p) % p;
                        a[i + k] = (a[i + k] + (ll)a[i + l + k]	* w[n / (l << 1) * k] % p) % p;
                        a[i + l + k] = tmp;
                    }
        }
        void mul(const int *a, const int *b, int *c, const int n)
        {
            static int x[LEN], y[LEN];
            memcpy(x, a, sizeof(int[n]));
            memcpy(y, b, sizeof(int[n]));
            int m = 1, lg2 = 0;
            while (m < (n << 1))
                m <<= 1, ++lg2;
            init(m, lg2);
            memset(x + n, 0, sizeof(int[m - n]));
            memset(y + n, 0, sizeof(int[m - n]));
            ntt(x, omega, m), ntt(y, omega, m);
            for (int i = 0; i < m; i++)
                x[i] = (ll)x[i] * y[i] % p;
            ntt(x, winv, m);
            int invm = inv(m);
            for (int i = 0; i < n; i++)
                c[i] = (ll)(x[i] + x[i + n]) * invm % p;
        }
        void power(const int *a, int b, int *c, const int n)
        {
            static int x[N];
            memcpy(x, a, sizeof(int[n]));
            memset(c, 0, sizeof(int[n]));
            c[0] = 1;
            while (b)
            {
                if (b & 1)
                    mul(c, x, c, n);
                mul(x, x, x, n);
                b >>= 1;
            }
        }
    }
    int A[N];
    int work()
    {
        using Polynomial::power;
        using Polynomial::Primitive_Root::get_g;
        int x, ssize;
        read(n), read(m), read(x), read(ssize);
        for (int i = 0; i < ssize; i++)
        {
            int a;
            read(a);
            s.insert(a);
        }
        int g = get_g(m), gx = -1;
        for (int i = 0; i < m - 1; i++)
        {
            int tmp = power(g, i, m);
            if (s.count(tmp))
                A[i] = 1;
            if (tmp == x)
                gx = i;
        }
        power(A, n, A, m - 1);
        write(A[gx]);
        return 0;
    }
}
int main()
{
    return zyt::work();
}
posted @ 2019-01-27 13:15  Inspector_Javert  阅读(166)  评论(0编辑  收藏  举报