游戏

  • 要统计差值为k的数对(i,j)的数量,这种感觉类似于卷积,我们把和差放到幂次中体现,就可以用NTT做到O(ailogai)
  • 其中,对于差值为0的特殊情况,不仅需要减去数自匹配的n种情况,还要除以2
  • NTT要预处理step+倍增法优化,否则会TLE
  • 将游戏每轮的操作抽象为数学函数
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
int a[1000005];
int v[10000005],prime[10000005],m;
long long cnt[1000005],inv[10000005];
long long f[500005];
int rev[5000005],p[5000005][2];
int read1()
{
	char cc=getchar();
	while(!(cc>=48&&cc<=57))
	{
		if(cc=='-')
		{
			break;
		}
		cc=getchar();
	}
	bool f=false;
	int s=0;
	if(cc=='-')
	{
		f=true;
	}
	else
	{
		s=cc-48;
	}
	while(1)
	{
		cc=getchar();
		if(cc>=48&&cc<=57)
		{
			s=s*10+cc-48;
		}
		else
		{
			break;
		}
	}
	if(f==true)
	{
		s=-s;
	}
	return s;
}
int power(int n,int p)
{
	if(p==0)
	{
		return 1;
	}
	long long tmp=power(n,p/2);
	if(p%2==1)
	{
		return tmp*tmp%mod*n%mod;
	}
	return tmp*tmp%mod;
}
void NTT(vector<long long>&f,int opt)
{
	int n=f.size();
	for(int i=1;i<n;i++)
	{
		if(i<rev[i])
		{
			swap(f[i],f[rev[i]]);
		}
	}
	for(int m=2;m<=n;m*=2)
	{
		int k=m/2;
		for(int i=0;i<n;i+=m)
		{
			long long cur=1,step;
			if(opt==1)
			{
				step=p[m][0];
			}
			else
			{
				step=p[m][1];
			}
			for(int j=0;j<k;j++)
			{
				long long tmp=cur*f[i+j+k]%mod;
				f[i+j+k]=(f[i+j]-tmp)%mod;
				f[i+j]=(f[i+j]+tmp)%mod;
				cur=cur*step%mod;
			}
		}
	}
}
vector<long long> operator*(vector<long long>a,vector<long long>b)
{
	vector<long long>c(a.size()+b.size()-1);
	while(c.size()<(1<<22))
	{
		c.push_back(0);
	}
	while(a.size()<c.size())
	{
		a.push_back(0);
	}
	while(b.size()<c.size())
	{
		b.push_back(0);
	}
	NTT(a,1),NTT(b,1);
	for(int i=0;i<c.size();i++)
	{
		c[i]=a[i]*b[i]%mod;
	}
	NTT(c,-1);
	int p=power(c.size(),998244351);
	for(int i=0;i<c.size();i++)
	{
		c[i]=c[i]*p%mod;
	}
	return c;
}
vector<long long>c,c1(2000001),c2(2000001);
int main()
{
	for(int i=1;i<(1<<22);i++)
	{
		rev[i]=(rev[i>>1]>>1);
		if(i&1)
		{
			rev[i]+=(1<<21);
		}
	}
	for(int i=1;i<=22;i++)
	{
		p[1<<i][0]=power(3,998244352/(1<<i));
		p[1<<i][1]=power(3,998244352-998244352/(1<<i));
	}
	inv[1]=1;
	for(int i=2;i<=10000000;i++)
	{
		if(v[i]==0)
		{
			v[i]=i;
			prime[++m]=i;
			inv[i]=power(i,998244351);
		}
		for(int j=1;j<=m;j++)
		{
			if(i*prime[j]>10000000||prime[j]>v[i])
			{
				break;
			}
			v[i*prime[j]]=prime[j];
			inv[i*prime[j]]=inv[i]*inv[prime[j]]%mod;
		}
	}
	long long n,t;
	cin>>n>>t;
	for(int i=1;i<=n;i++)
	{
		a[i]=read1();
		c1[a[i]+1000000]++;
		c2[-a[i]+1000000]++;
	}
	c=c1*c2;
	for(int i=0;i<1000000;i++)
	{
		cnt[i]=c[-i+2000000];
	}
	cnt[0]-=n;
	cnt[0]=cnt[0]*power(2,998244351)%mod;
	long long p1=(n-2)*power((n*(n-1)/2)%mod,998244351)%mod,p1inv=power(p1,998244351);
	vector<long long>a(t+1);
	a[0]=power(p1,t);
	a[1]=power(p1,t-1)*(1-2*p1)%mod*t%mod;
	for(int i=1;i<t;i++)
	{
		a[i+1]=((1-2*p1)%mod*a[i]%mod*t%mod+2*p1*t%mod*a[i-1]%mod-i*a[i]%mod*(1-2*p1)%mod-a[i-1]*(i-1)%mod*p1%mod)%mod*inv[i+1]%mod*p1inv%mod;
	}
	long long ans=0;
	for(int i=max(0ll,t-1000000);i<=t;i++)
	{
		ans=(ans+cnt[t-i]*a[i]%mod)%mod;
	}
	cout<<(ans+mod)%mod<<endl;
	return 0;
}
posted @ 2024-07-31 20:08  D06  阅读(7)  评论(0编辑  收藏  举报