【XSY2680】玩具谜题 NTT 牛顿迭代

题目描述

  小南一共有\(n\)种不同的玩具小人,每种玩具小人的数量都可以被认为是无限大。每种玩具小人都有特定的血量,第\(i\)种玩具小人的血量就是整数\(i\)。此外,每种玩具小人还有自己的攻击力,攻击力可以是任意非负整数,且两种不同的玩具小人的攻击力可以相同。我们把第\(i\)种玩具小人的血量和攻击力表示成\(a_i\)\(b_i\)

  为了让玩具小人们进行战斗,小南打算把一些小人选出来,编成队伍。一个队伍可以表示成一个由玩具小人组成的序列:\((p_1,p_2,\ldots,p_l)\),其中\(p_i\)表示队伍中第\(i\)个玩具小人的种类,\(l\)为队伍的长度。对于不同的\(i\)\(p_i\)可以相同。两个队伍被认为相同,当且仅当长度相同,且每个位置的玩具小人种类都分别相同。

  一个队伍也有血量和攻击力两个属性,记为\(a_t,b_t\)。队伍的血量就是每个玩具小人的血量之和,而队伍攻击力可能会由于队伍内部产生矛盾而减小,对于长度为\(l\)的队伍,队伍的攻击力为每个玩具小人的攻击力之乘积除以\(l\)的阶乘。同时,当\(l\)大于等于某个常数\(c\)时,攻击力会有一个额外的加成:乘以\((1+\frac{l!}{(l−c)!})\)。也就是说:

\[a_t=\sum_{i=1}^la_{p_i}\\ b_t=\begin{cases} \frac{1}{l!}\sum_{i=1}^lb_{p_i}&,l<c\\ (\frac{1}{l!}+\frac{1}{(l-c)!})\sum_{i=1}^lb_{p_i}&,l\geq c \end{cases} \]

  然而,小南的玩具小人们对小南的独裁统治感到愤怒,准备联合起来发起民主运动。为了旗帜鲜明地反对动乱,小南必须了解清楚玩具小人们的战斗力。不幸的是,由于玩具小人数量过多,小南已经忘记每种玩具小人的战斗力具体是多少了。现在,小南掌握的情报只有对于每个\(1\)\(n\)之间的整数\(i\),所有血量等于\(i\)的不同队伍的战斗力之和对\(998244353\)取模的值是多少(\(s_i\))。他希望你根据已有的情报,还原出每种玩具小人的战斗力对\(998244353\)取模的结果 。如果镇压成功了,小南会请你到北京去做一回***(当然是北京玩具协会的***)。

  \(n\leq 60000,0\leq c\leq n\)

题解

  设\(F=\sum_{i\geq 1}b_i,S=\sum_{i\geq 0}s_i\),如果\(c=0\),那么\(s_0=2\)

\[\begin{align} \sum_{i\geq 0}\frac{F^i}{i!}+\sum_{i\geq 0}\frac{F^i}{i!}&=S\\ 2e^F&=S\\ F=\ln\frac{S}{2} \end{align} \]

  否则\(s_0=1\)

\[\begin{align} \sum_{i\geq 1}\frac{F^i}{i!}+\sum_{i\geq c}\frac{F^i}{(i-c)!}&=S-1\\ \sum_{i\geq 1}\frac{F^i}{i!}+F^c\sum_{i\geq0}\frac{F^i}{i!}&=S-1\\ (F^c+1)e^F&=S \end{align} \]

  然后就是牛顿迭代解方程。我们需要满足

\[g(F)=(F^c+1)e^F-S=0 \]

  的\(F\)。设当前求出了

\[g(F_0)\equiv0\pmod {x^{\frac{n}{2}}} \]

  的\(F_0\),现在我们要求\(F\)满足

\[g(F)\equiv 0\pmod {x^n} \]

  考虑在\(F_0\)出对\(g\)泰勒展开

\[g(F)=g(F_0)+g'(F_0)(F-F_0)+\frac{g''(F_0)}{2}{(F-F_0)}^2+\cdots \]

  后面的项都是\(0\),因为\(F-F_0\)的最小非零项的次数至少是\(\frac{n}{2}\),所以后面的部分在模\(x^n\)意义下一定会被消掉。

  式子就变成了

\[\begin{align} g(F)&\equiv g(F_0)+g'(F_0)(F-F_0)\pmod {x^n}\\ F&\equiv F_0-\frac{g(F_0)}{g'(F_0)}\pmod {x^n}\\ F&\equiv F_0-\frac{({F_0}^c+1)e^{F_0}-S}{(c{F_0}^{c-1}+{F_0}^c+1)e^{F_0}}\pmod {x^n} \end{align} \]

  套各种多项式算法可以做到

\[T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n) \]

  常数巨大。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
void put(int x)
{
	if(!x)
	{
		putchar('0');
		return;
	}
	static int c[20];
	int t=0;
	while(x)
	{
		c[++t]=x%10;
		x/=10;
	}
	while(t)
		putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const ll p=998244353;
const ll g=3;
const int maxn=65536;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll inv[200000];
namespace ntt
{
	int rev[200000];
	int m;
	void ntt(ll *a,int n,int t)
	{
		ll u,v,w,wn;
		int i,j,k;
		if(n!=m)
		{
			m=n;
			rev[0]=0;
			for(i=1;i<n;i++)
				rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		}
		for(i=0;i<n;i++)
			if(rev[i]<i)
				swap(a[i],a[rev[i]]);
		for(i=2;i<=n;i<<=1)
		{
			wn=fp(g,(p-1)/i);
			if(t==-1)
				wn=fp(wn,p-2);
			for(j=0;j<n;j+=i)
			{
				w=1;
				for(k=j;k<j+i/2;k++)
				{
					u=a[k];
					v=a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
					w=w*wn%p;
				}
			}
		}
		if(t==-1)
		{
			ll inv=fp(n,p-2);
			for(i=0;i<n;i++)
				a[i]=a[i]*inv%p;
		}
	}
	void getinv(ll *a,ll *b,int n)
	{
		if(n==1)
		{
			b[0]=fp(a[0],p-2);
			return;
		}
		getinv(a,b,n>>1);
		static ll a1[200000],a2[200000];
		int i;
		for(i=0;i<n;i++)
			a1[i]=a[i];
		for(;i<n<<1;i++)
			a1[i]=0;
		for(i=0;i<n>>1;i++)
			a2[i]=b[i];
		for(;i<n<<1;i++)
			a2[i]=0;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(i=0;i<n<<1;i++)
			a1[i]=(2*a2[i]-a1[i]*a2[i]%p*a2[i])%p;
		ntt(a1,n<<1,-1);
		for(i=0;i<n;i++)
			b[i]=a1[i];
	}
	void getln(ll *a,ll *b,int n)
	{
		static ll a1[200000],a2[200000];
		int i;
		for(i=1;i<n;i++)
			a1[i-1]=a[i]*i%p;
		a1[n-1]=0;
		getinv(a,a2,n);
		for(i=n;i<n<<1;i++)
			a1[i]=a2[i]=0;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(i=0;i<n<<1;i++)
			a1[i]=a1[i]*a2[i]%p;
		ntt(a1,n<<1,-1);
		b[0]=0;
		for(i=1;i<n;i++)
			b[i]=a1[i-1]*inv[i]%p;
	}
	void getexp(ll *a,ll *b,int n)
	{
		if(n==1)
		{
			b[0]=1;
			return;
		}
		getexp(a,b,n>>1);
		static ll a1[200000],a2[200000];
		int i;
		for(i=0;i<n>>1;i++)
			a1[i]=b[i];
		for(;i<n<<1;i++)
			a1[i]=0;
		for(i=n>>1;i<n;i++)
			b[i]=0;
		getln(b,a2,n);
		for(i=0;i<n;i++)
			a2[i]=-a2[i];
		for(i=n;i<n<<1;i++)
			a2[i]=0;
		a2[0]++;
		for(i=0;i<n;i++)
			a2[i]=(a2[i]+a[i])%p;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(i=0;i<n<<1;i++)
			a1[i]=a1[i]*a2[i]%p;
		ntt(a1,n<<1,-1);
		for(i=0;i<n;i++)
			b[i]=a1[i];
	}
	void getpow(ll *a,ll *b,int n,ll k)
	{
		int d=0;
		while(d<n&&!a[d])
			d++;
		int i;
		if(d>=n)
		{
			for(i=0;i<n;i++)
				b[i]=0;
			if(!k)
				b[0]=1;
			return;
		}
		static ll a1[200000],a2[200000];
		ll c=a[d];
		ll e=fp(c,p-2);
		for(i=0;i<n;i++)
			if(i+d<n)
				a1[i]=a[i+d]*e%p;
			else
				a1[i]=0;
		getln(a1,a2,n);
		for(i=0;i<n;i++)
			a2[i]=a2[i]*k%p;
		getexp(a2,a1,n);
		for(i=0;i<n&&i<d*k;i++)
			b[i]=0;
		c=fp(c,k);
		for(i=d*k;i<n;i++)
			b[i]=a1[i-d*k]*c%p;
	}
	void mul(ll *a,ll *b,ll *c,int n)
	{
		int i;
		static ll a1[200000],a2[200000];
		for(i=0;i<n;i++)
		{
			a1[i]=a[i];
			a2[i]=b[i];
		}
		for(;i<n<<1;i++)
			a1[i]=a2[i]=0;
		ntt(a1,n<<1,1);
		ntt(a2,n<<1,1);
		for(i=0;i<n<<1;i++)
			a1[i]=a1[i]*a2[i]%p;
		ntt(a1,n<<1,-1);
		for(i=0;i<n;i++)
			c[i]=a1[i];
	}
}
using namespace ntt;
ll a[200000],b[200000];
void init()
{
	int i;
	inv[0]=inv[1]=1;
	for(i=2;i<=maxn;i++)
		inv[i]=-p/i*inv[p%i]%p;
}
int c;
void gao(ll *a,ll *b,int n)
{
	if(n==1)
	{
		b[0]=0;
		return;
	}
	gao(a,b,n>>1);
	int i;
	for(i=n>>1;i<n;i++)
		b[i]=0;
	static ll a1[200000],a2[200000],a3[200000],a4[200000],a5[200000],a6[200000],a7[200000];
	//a1=F^(c-1)
	getpow(b,a1,n,c-1);
	//a2=F^c=a1F
	mul(a1,b,a2,n);
	//a3=e^F
	getexp(b,a3,n);
	for(i=0;i<n;i++)
		a4[i]=a2[i];
	a4[0]++;
	mul(a4,a3,a5,n);
	for(i=0;i<n;i++)
		a5[i]=(a5[i]-a[i])%p;
	for(i=0;i<n;i++)
		a6[i]=(a2[i]+c*a1[i])%p;
	a6[0]++;
	mul(a6,a3,a7,n);
	getinv(a7,a6,n);
	mul(a6,a5,a7,n);
	for(i=0;i<n;i++)
		b[i]=(b[i]-a7[i])%p;
}
void gao2(ll *a,ll *b,int n)
{
	int i;
	for(i=0;i<n;i++)
		a[i]=a[i]*inv[2]%p;
	getln(a,b,n);
}
int n;
int main()
{
	init();
	open("c");
	scanf("%d%d",&n,&c);
	int m=1;
	while(m<=n)
		m<<=1;
	int i;
	for(i=1;i<=n;i++)
		scanf("%lld",&a[i]);
	for(i=n+1;i<m;i++)
		a[i]=0;
	if(!c)
	{
		a[0]=2;
		gao2(a,b,m);
	}
	else
	{
		a[0]=1;
		gao(a,b,m);
	}
	for(i=1;i<=n;i++)
	{
		b[i]=(b[i]+p)%p;
		printf("%lld\n",b[i]);
	}
	return 0;
}
posted @ 2018-03-06 11:48  ywwyww  阅读(271)  评论(0编辑  收藏  举报