【LOJ565】【LibreOJ Round #10】mathematican 的二进制 DP 分治FFT

题目大意

  有一个无限长的二进制串,初始时它的每一位都为 \(0\)。现在有 \(m\) 个操作,其中第 \(i\) 个操作是将这个二进制串的数值加上 \(2^{a_i}\)。我们称每次操作的代价是这次操作改变的位的数量。

  我们以一定概率执行这些操作:第 \(i\) 个操作有 \(p_i\) 的概率执行,否则不执行。

  请求出所有执行的操作的代价和的期望。

  \(n\leq 100000,m\leq 200000,0\leq a_i\leq n\)

题解

  容易发现,如果进行了 \(k\) 次操作且把这个数从 \(0\) 修改成了 \(v\),那么代价就是 \(2k-\operatorname{bitcount}(v)\)

  可以用势能分析相关知识解释随便看看就看出来了。

  前半部分就是 \(2\sum_{i=1}^mp_i\)

  后半部分可以每位分开算:计算 \(v\) 的每一位为 \(1\) 的概率。

  设 \(f_{i,j}\) 为只考虑 \(a_k\leq i\) 的那些操作,修改完后 \(\lfloor\frac{v}{2^i}\rfloor=j\) 的概率。

   \(g_{i,j}\) 为只考虑 \(a_k=i\) 的那些操作,修改完后 \(\lfloor\frac{v}{2^i}\rfloor=j\) 的概率,也就是执行了 \(j\) 个操作的概率。

  \(g_{i,j}\) 可以用分治 NTT 求出。

  \(f_{i,j}=\sum_{\lfloor\frac{k}{2}\rfloor+l=j}f_{i-1,k}g_{i,l}\),可以用 NTT 优化。

  那么答案的第二部分就是 \(\sum_{i\geq 0}\sum_{2\nmid j}f_{i,j}\)

  时间复杂度是多少?

  记 \(c_i=\sum_{j=1}^m[a_j=i]\)

  算 \(g\) 的复杂度显然是 \(O(m\log^2 m)\)

  算 \(f\) 的复杂度是

\[\begin{align} &O(\log m\times \sum_{i=0}^n\sum_{j=0}^{i}\frac{c_j}{2^{i-j}})\\ =&O(\log m\times \sum_{i=0}^nc_i\sum_{j\geq 0}2^{-j})\\ =&O(m\log m) \end{align} \]

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<functional>
#include<cmath>
#include<vector>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
typedef long long ll;
typedef unsigned long long ull;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const ll p=998244353;
const int N=600000;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll v[N],x[N],y[N];
int a[N];
int n,m;
ll e[N];
ll *f[N*2];
ll *g[N];
namespace fft
{
	const int W=524288;
	ll w[N];
	int rev[N];
	void init()
	{
		ll s=fp(3,(p-1)/W);
		w[0]=1;
		for(int i=1;i<W;i++)
			w[i]=w[i-1]*s%p;
	}
	void ntt(ll *a,int n,int t)
	{
		for(int i=1;i<n;i++)
		{
			rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
			if(rev[i]>i)
				swap(a[i],a[rev[i]]);
		}
		for(int i=2;i<=n;i<<=1)
			for(int j=0;j<n;j+=i)
				for(int k=0;k<i/2;k++)
				{
					ll u=a[j+k];
					ll v=a[j+k+i/2]*w[W/i*k]%p;
					a[j+k]=(u+v)%p;
					a[j+k+i/2]=(u-v)%p;
				}
		if(t==-1)
		{
			reverse(a+1,a+n);
			ll inv=fp(n,p-2);
			for(int i=0;i<n;i++)
				a[i]=a[i]*inv%p;
		}
	}
	void mul(ll *a,ll *b,ll *c,int n,int m,int l)
	{
		static ll a1[N],a2[N];
		int k=1;
		while(k<=n+m)
			k<<=1;
		for(int i=0;i<=n;i++)
			a1[i]=a[i];
		for(int i=n+1;i<k;i++)
			a1[i]=0;
		for(int i=0;i<=m;i++)
			a2[i]=b[i];
		for(int i=m+1;i<k;i++)
			a2[i]=0;
		ntt(a1,k,1);
		ntt(a2,k,1);
		for(int i=0;i<k;i++)
			a1[i]=a1[i]*a2[i]%p;
		ntt(a1,k,-1);
		for(int i=0;i<=l;i++)
			c[i]=a1[i];
	}
	int cnt;
	int len[N*2];
	void solve(int &now,int l,int r)
	{
		now=++cnt;
		if(l==r)
		{
			len[now]=1;
			f[now]=new ll[len[now]+1];
			f[now][0]=1-e[l];
			f[now][1]=e[l];
			return;
		}
		int ls;
		int rs;
		int mid=(l+r)>>1;
		solve(ls,l,mid);
		solve(rs,mid+1,r);
		len[now]=len[ls]+len[rs];
		f[now]=new ll[len[now]+1];
		mul(f[ls],f[rs],f[now],len[ls],len[rs],len[now]);
	}
}
int rt[N];
int len[N];
int len2[N];
ll h[N];
std::vector<ll> c[N];
int main()
{
	fft::init();
	open("loj565");
	scanf("%d%d",&n,&m);
	n++;
	ll ans=0;
	ll s=0;
	for(int i=1;i<=m;i++)
	{
		scanf("%d%lld%lld",&a[i],&x[i],&y[i]);
		a[i]++;
		v[i]=x[i]*fp(y[i],p-2)%p;
		ans+=2*v[i];
		c[a[i]].push_back(v[i]);
	}
	g[0]=new ll[1];
	g[0][0]=1;
	for(int i=1;i<=n+20;i++)
	{
		for(auto v:c[i])
			e[++len[i]]=v;
		if(len[i])
			fft::solve(rt[i],1,len[i]);
		else
		{
			rt[i]=++fft::cnt;
			f[rt[i]]=new ll[1];
			f[rt[i]][0]=1;
		}
		for(int j=0;j<=len2[i-1]>>1;j++)
			h[j]=0;
		for(int j=0;j<=len2[i-1];j++)
			h[j>>1]=(h[j>>1]+g[i-1][j])%p;
		len2[i-1]>>=1;
		len2[i]=len2[i-1]+len[i];
		g[i]=new ll[len2[i]+1];
		fft::mul(h,f[rt[i]],g[i],len2[i-1],len[i],len2[i]);
		for(int j=1;j<=len2[i];j+=2)
			s=(s+g[i][j])%p;
	}
	ans=(ans-s)%p;
	ans=(ans+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2018-06-25 20:56  ywwyww  阅读(672)  评论(0编辑  收藏  举报