【广西省赛#7】G.Grand XOR Counting Problem Challenge

Description

给一个数组 \({a_i}, i=1, \cdots, n\),对 \(j=0, 1,\cdots, m-1\) ,计算其中有多少个大小为 \(k\) 的子序列满足其异或和为 \(j\)

  • \(n\leq 10^5\)
  • $ m\leq 65536$

Solution

首先答案是

\[[y^k]\prod_{i=1}^n (1+x^{a_i}y) \]

这里对 \(y\) 做的是多项式乘法,对 \(x\) 做的是异或卷积。

正常做法就是直接 FWT,算完连乘积再 IFWT 回去。

用类似《黎明前的巧克力》那道题的套路,对每个单项式 \(x^{a_i}\),它 FWT 的结果每项都是 \(\pm 1\)

根据 FWT 的线性性,所有单项式相加后 FWT 的结果等于每个单项式 FWT 的结果之和。我们算出单项式之和的 FWT 的每一项系数 \(c_i\),然后根据

\[x_{i,1}+x_{i,-1}=n,x_{i,1}-x_{i,-1}=c_i \]

即可计算出每个位置 \(\pm 1\) 的数量。

最后 FWT 的连乘积,第 \(i\) 项就是 \([y^k](1+y)^{x_{1,i}}(1-y)^{x_{-1,i}}\)

把它求出来,然后 IFWT 回去,就是最终答案。

瓶颈在算上面这个 \(y^k\) 的系数,直接算是 \(O(km)\) 的,会T。

\(t=x_{-1,i}\),下面对所有的 \(t=0,1,\dots,n\) 计算 \(y^k\) 的系数。

\[[y^k](1+y)^{n-t}(1-y)^t\\=\sum_{i=0}^k(-1)^i\binom ti\binom{n-t}{k-i}\\=\sum_{i=0}^t(-1)^i\binom ti\binom{n-t}{k-i}\\=t!(n-t)!\sum_{i=0}^t\frac {(-1)^i}{i!(k-i)!(t-i)!(n-t-k+i)!} \]

\[f_i=\frac {(-1)^i}{i!(k-i)!}, g_i=\frac 1{i!(n-k-i)!} \]

然后卷积即可。

这样总复杂度为 \(O(m\log m+n\log n)\)

Code

#define LOCAL
#include "bits/stdc++.h"
using namespace std;
using ui=unsigned; using db=long double; using ll=long long; using ull=unsigned long long; using lll=__int128;
using pii=pair<int,int>; using pll=pair<ll,ll>;
template<class T1, class T2> istream &operator>>(istream &cin, pair<T1, T2> &a) { return cin>>a.first>>a.second; }
template <std::size_t Index=0, typename... Ts> typename std::enable_if<Index==sizeof...(Ts), void>::type tuple_read(std::istream &is, std::tuple<Ts...> &t) { }
template <std::size_t Index=0, typename... Ts> typename std::enable_if<Index < sizeof...(Ts), void>::type tuple_read(std::istream &is, std::tuple<Ts...> &t) { is>>std::get<Index>(t); tuple_read<Index+1>(is, t); }
template <typename... Ts>std::istream &operator>>(std::istream &is, std::tuple<Ts...> &t) { tuple_read(is, t); return is; }
template<class T1> istream &operator>>(istream &cin, vector<T1> &a) { for (auto &x:a) cin>>x; return cin; }
template<class T1> istream &operator>>(istream &cin, valarray<T1> &a) { for (auto &x:a) cin>>x; return cin; }
template<class T1, class T2> bool cmin(T1 &x, const T2 &y) { if (y<x) { x=y; return 1; } return 0; }
template<class T1, class T2> bool cmax(T1 &x, const T2 &y) { if (x<y) { x=y; return 1; } return 0; }
istream &operator>>(istream &cin, lll &x) { x=0; static string s; cin>>s; for (char c:s) x=x*10+(c-'0'); return cin; }
ostream &operator<<(ostream &cout, lll x) { static char s[60]; int tp=1; s[0]='0'+(x%10); while (x/=10) s[tp++]='0'+(x%10); while (tp--) cout<<s[tp]; return cout; }
#if !defined(ONLINE_JUDGE)&&defined(LOCAL)
#include "my_header/IO.h"
#include "my_header/defs.h"
#else
#define dbg(...) ;
#define dbgx(...) ;
#define dbg1(x) ;
#define dbg2(x) ;
#define dbg3(x) ;
#define DEBUG(msg) ;
#define REGISTER_OUTPUT_NAME(Type, ...) ;
#define REGISTER_OUTPUT(Type, ...) ;
#endif
#define all(x) (x).begin(),(x).end()
#define print(...) cout<<format(__VA_ARGS__)
#define println(...) cout<<format(__VA_ARGS__)<<'\n'
#define err(...) cerr<<format(__VA_ARGS__)
#define errln(...) cerr<<format(__VA_ARGS__)<<'\n'

namespace NTT
{
	const ull g=3, p=998244353;
	const int N=1<<19;//务必修改
	ull w[N];
	int r[N];
	ull ksm(ull x, ull y)
	{
		ull r=1;
		while (y)
		{
			if (y&1) r=r*x%p;
			x=x*x%p;
			y>>=1;
		}
		return r;
	}
	void init(int n)
	{
		static int pr=0, pw=0;
		if (pr==n) return;
		int b=__lg(n)-1, i, j, k;
		for (i=1; i<n; i++) r[i]=r[i>>1]>>1|(i&1)<<b;
		if (pw<n)
		{
			for (j=1; j<n; j=k)
			{
				k=j*2;
				ull wn=ksm(g, (p-1)/k);
				w[j]=1;
				for (i=j+1; i<k; i++) w[i]=w[i-1]*wn%p;
			}
			pw=n;
		}
		pr=n;
	}
	int cal(int x) { return 1<<__lg(max(x, 1)*2-1); }
	struct Q:vector<ull>
	{
		bool flag;
		Q &operator%=(int n) { resize(n); return *this; }
		Q operator%(int n) const
		{
			if (size()<=n)
			{
				auto f=*this;
				return f%=n;
			}
			return Q(vector(begin(), begin()+n));
		}
		int deg() const
		{
			int n=size()-1;
			while (n>=0&&begin()[n]==0) --n;
			return n;
		}
		explicit Q(int x=1, bool f=0):flag(f), vector<ull>(cal(x)) { }//小心:{}会调用这条而非下一条
		Q(const vector<ull> &o, bool f=0):Q(o.size(), f) { copy(all(o), begin()); }
		void dft()
		{
			int n=size(), i, j, k;
			ull y, *f, *g, *wn, *a=data();
			init(n);
			for (i=1; i<n; i++) if (i<r[i]) ::swap(a[i], a[r[i]]);
			for (k=1; k<n; k*=2)
			{
				wn=w+k;
				for (i=0; i<n; i+=k*2)
				{
					g=(f=a+i)+k;
					for (j=0; j<k; j++)
					{
						y=g[j]*wn[j]%p;
						g[j]=f[j]+p-y;
						f[j]+=y;
					}
				}
				if (k*2==n||k==1<<14) for (i=0; i<n; i++) a[i]%=p;
			}
			if (flag)
			{
				y=ksm(n, p-2);
				for (i=0; i<n; i++) a[i]=a[i]*y%p;
				reverse(a+1, a+n);
			}
			flag^=1;
		}
	};
	Q &operator*=(Q &f, Q g)//卷积
	{
		if (f.flag|g.flag)
		{
			int n=f.size(), i;
			assert(n==g.size());
			if (!f.flag) f.dft();
			if (!g.flag) g.dft();
			for (i=0; i<n; i++) (f[i]*=g[i])%=p;
			f.dft();
		}
		else
		{
			int n=cal(f.size()+g.size()-1), i, j;
			int m1=f.deg(), m2=g.deg();
			if ((ull)m1*m2>(ull)n*__lg(n)*8)
			{
				(f%=n).dft(); (g%=n).dft();
				for (i=0; i<n; i++) (f[i]*=g[i])%=p;
				f.dft();
			}
			else
			{
				vector<ull> r(max(0, m1+m2+1));
				for (i=0; i<=m1; i++) for (j=0; j<=m2; j++) (r[i+j]+=f[i]*g[j])%=p;
				f=Q(n);
				copy(all(r), f.begin());
			}
		}
		return f;
	}
}
using NTT::p;
using poly=NTT::Q;

int cy[100005];

void init(int n, int k)
{
    poly x(n+1), y(n+1);
    vector<ull> fac(n+1), inv(n+1);
    for(int i=fac[0]=inv[0]=inv[1]=1; i<=n; ++i) fac[i]=(ull)fac[i-1]*i%p;
    for(int i=2; i<=n; ++i) inv[i]=(ull)inv[p%i]*(p-p/i)%p;
    for(int i=2; i<=n; ++i) inv[i]=(ull)inv[i-1]*inv[i]%p;
    for(int i=0; i<=k; ++i) x[i]=(ull)((i&1)?(p-1ull):1ull)*inv[i]%p*inv[k-i]%p;
    for(int i=0; i+k<=n; ++i) y[i]=(ull)inv[i]*inv[n-k-i]%p;
    x*=y;
    for(int i=0; i<=n; ++i) cy[i]=(ull)fac[i]*fac[n-i]%p*(x[i]%p+p)%p;
}

void fwt_xor(vector<ui> &A)
{
	ui n=A.size(),*a=A.data(),i,j,k,l,*f,*g;
	for (i=1;i<n;i=l)
	{
		l=i*2;
		for (j=0;j<n;j+=l)
		{
			f=a+j;g=a+j+i;
			for (k=0;k<i;k++)
			{
				if ((f[k]+=g[k])>=p) f[k]-=p;
				g[k]=(f[k]+2*(p-g[k]))%p;
			}
		}
	}
}
void ifwt_xor(vector<ui> &A)
{
	ui n=A.size(),*a=A.data(),i,j,k,l,*f,*g,x=p+1>>1,y=1;
	for (i=1;i<n;i=l)
	{
		l=i*2;
		for (j=0;j<n;j+=l)
		{
			f=a+j;g=a+j+i;
			for (k=0;k<i;k++)
			{
				if ((f[k]+=g[k])>=p) f[k]-=p;
				g[k]=(f[k]+2*(p-g[k]))%p;
			}
		}
		y=(ull)y*x%p;
	}
	for (i=0;i<n;i++) a[i]=(ull)a[i]*y%p;
}


int main()
{
	ios::sync_with_stdio(0); cin.tie(0);
	cout<<fixed<<setprecision(15);
	ll n, k, b, x;
    cin >> n >> k >> b;
    init(n, k);
	int B = NTT::cal(b);
    vector<ui> f(B), g(B);
    for(int i=1;i<=n;++i) cin>>x, f[x]++;
    fwt_xor(f);
    for(int i=0;i<B;++i) 
    {
        int pos=((n+p-f[i])%p+p)%p/2;
        g[i]=cy[pos];
    }
    ifwt_xor(g);
    for(int i=0;i<b;++i) cout<<g[i]<<" ";
}
posted @ 2024-10-16 15:39  PaperCloud  阅读(11)  评论(0编辑  收藏  举报