树上背包NTT优化

主要结合两道例题讲,复杂度的计算很重要。

LOJ6290 花朵

非常容易可以考虑到树上背包的做法,但是过不了。

怎么将这个 \(\text{dp}\) 优化呢?考虑背包实际上就是一个卷积的形式,所以我们可以用多项式科技优化卷积过程。

可以想到的,我们不能将背包直接卷积,因为复杂度由 \(O(nm)\) 变成 \(O((n+m)\log(n+m))\) ,在 \(n=1\) 的时候复杂度反而会变慢。

具体的,我们将原树轻重链剖分,对于一个点,我们先通过类曼哈顿的贪心,每一次将最小的两个轻子树合并,等到一条重链上的点的所有节点的轻子树都合并完成了,我们再在这条重链上做分治 \(\text{fft}\)

复杂度是 \(O(n\log^3n)\) ,具体证明的话应该是考虑重链和轻链分开算。

对于重链来说,考虑每个点会向上跳 \(\log n\) 次,每次链上做合并的时候贡献 \(\log n\) 次,每次是 \(\log n\) ,所以复杂度是 \(O(n\log^3 n)\) 的。

对于轻链来说,考虑每一个点的贡献,由于每一个点在一次合并操作中的均摊复杂度可以近似看成 \(O(\log n)\) ,所以我们只需要计算出每一个点的操作次数即可。考虑一个点在进行轻子树合并的时候大小一定翻倍,所以其从自己位置一直合并到根的合并次数是一定小于 \(\log n\) 的,所以这里的复杂度是 \(O(n\log^2n)\) 的。


卧槽发现有个直接套距阵的老哥,也太帅了吧。

我顿悟了,距阵掌握度 +1 。


关于矩阵实现的细节很重要。对于矩阵的优化,一定需要写出对应的状态转移方程,然后根据转移前的矩阵和转移后的矩阵写出转移矩阵。

#include<bits/stdc++.h>
using namespace std;
const int N=131072;
const int MOD=998244353,G=3;
int ADD(int x,int y){return x+y>=MOD?x+y-MOD:x+y;}
int TIME(int x,int y){return (int)(1ll*x*y%MOD);}
int ksm(int x,int k=MOD-2){int res=1;for(;k;k>>=1,x=TIME(x,x))if(k&1)res=TIME(res,x);return res;}
int rev[N],lst=0;
void get_rev(int lg){
	if(lst==lg) return ;else lst=lg;
	for(int i=0;i<(1<<lg);++i)
		rev[i]=((rev[i>>1]>>1)|((i&1)<<(lg-1)));
}
struct Polynomial{
	vector<int> f;
	int &operator [] (int x){return assert(x<(int)f.size()),f[x];}
	int len(){return (int)f.size();}void clear(){return f.clear();}
	void resize(int n){
		while((int)f.size()>n) f.pop_back();
		while((int)f.size()<n) f.push_back(0);
	}
	void NTT(int lg,bool tag){
		int n=(1<<lg);get_rev(lg),resize(n);
		for(int i=0;i<n;++i) if(i<rev[i]) swap(f[i],f[rev[i]]);
		for(int len=2;len<=n;len<<=1){
			int m=(len>>1),g=ksm(G,(MOD-1)/len);if(tag) g=ksm(g);
			for(int i=0;i<n;i+=len){
				for(int j=0,gg=1;j<m;++j,gg=TIME(gg,g)){
					int tmp=TIME(f[i+j+m],gg);
					f[i+j+m]=ADD(f[i+j],MOD-tmp),f[i+j]=ADD(f[i+j],tmp);
				}
			}
		}
		if(tag) for(int i=0,tmp=ksm(n);i<n;++i) f[i]=TIME(f[i],tmp);
	}
	void print(){
		for(int i=0;i<len();++i) printf("%d ",f[i]);
		printf("\n");
	}
};
Polynomial init(int x){
	Polynomial res;return res.resize(2),res[1]=x,res;
}
Polynomial operator * (Polynomial f,Polynomial g){
	if(!f.len()||!g.len()) return Polynomial();
	int n=f.len()+g.len()-1,lg=0;while((1<<lg)<n) lg++;
	f.NTT(lg,false),g.NTT(lg,false);
	for(int i=0;i<(1<<lg);++i) f[i]=TIME(f[i],g[i]);
	return f.NTT(lg,true),f.resize(n),f;
}
Polynomial operator + (Polynomial f,Polynomial g){
	if(f.len()<g.len()) swap(f,g);
	for(int i=0;i<g.len();++i) f[i]=ADD(f[i],g[i]);
	return f;
}
struct Matrix{
	Polynomial f[2][2];
	void clear(){
		for(int i=0;i<2;++i){
			for(int j=0;j<2;++j)
				f[i][j].clear();
		}
	}
	void print(){
		for(int i=0;i<2;++i){
			for(int j=0;j<2;++j)
			printf("f[%d][%d]=",i,j),f[i][j].print();
		}
	}
};
bool operator < (Matrix a,Matrix b){
	int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
	int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
	return tmp1<tmp2;
}
bool operator > (Matrix a,Matrix b){
	int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
	int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
	return tmp1>tmp2;
}
Matrix operator * (Matrix a,Matrix b){
	Matrix res;res.clear();
	for(int i=0;i<2;++i){
		for(int k=0;k<2;++k){
			for(int j=0;j<2;++j)
			res.f[i][j]=res.f[i][j]+a.f[i][k]*b.f[k][j];
		}
	}
	return res;
}
int n,m,p[N];
struct Edge{int nxt,to;}e[N<<1];int fir[N];
void add(int u,int v,int i){e[i]=(Edge){fir[u],v},fir[u]=i;}
struct Node{int fa,son,siz;}tr[N];
void dfs1(int u){
	tr[u].siz=1;
	for(int i=fir[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==tr[u].fa) continue;
		tr[v].fa=u,dfs1(v),tr[u].siz+=tr[v].siz;
		if(tr[v].siz>tr[tr[u].son].siz) tr[u].son=v;
	}
}
Matrix cdq(vector<Matrix> &bag,int l,int r){
	if(l==r) return bag[l];
	int mid=(l+r)>>1;
	return cdq(bag,l,mid)*cdq(bag,mid+1,r);
}
priority_queue<Matrix,vector<Matrix>,greater<Matrix> > q;
Matrix merge(vector<Matrix> &bag){
	while(!q.empty()) q.pop();
	for(int i=0;i<(int)bag.size();++i) q.push(bag[i]);
	while(q.size()>1){
		Matrix a=q.top();q.pop();
		Matrix b=q.top();q.pop();
		q.push(a*b);
	}
	return q.top();
}
Matrix dfs2(int u){
	vector<Matrix> bag;
	for(;u;u=tr[u].son){
		vector<Matrix> BAG;
		for(int i=fir[u];i;i=e[i].nxt){
			int v=e[i].to;if(v==tr[u].fa||v==tr[u].son) continue;
			Matrix tmp=dfs2(v),TMP;TMP.clear();
			TMP.f[1][1]=tmp.f[0][0];
			TMP.f[0][0]=TMP.f[1][1]+tmp.f[1][0];
			BAG.push_back(TMP);
		}
		Matrix tmp,TMP;TMP.clear();
		if(BAG.empty()){
			tmp.clear();
			tmp.f[0][0].resize(1),tmp.f[0][0][0]=1;
			tmp.f[1][1].resize(1),tmp.f[1][1][0]=1;
		}
		else tmp=merge(BAG);
		TMP.f[0][0]=TMP.f[0][1]=tmp.f[0][0];
		TMP.f[1][0]=tmp.f[1][1]*init(p[u]);
		bag.push_back(TMP);
	}
	return cdq(bag,0,(int)bag.size()-1);
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;++i) scanf("%d",&p[i]);
	for(int i=1;i<n;++i){
		int u,v;scanf("%d%d",&u,&v);
		add(u,v,i<<1),add(v,u,i<<1|1);
	}
	dfs1(1);Matrix tmp=dfs2(1);Polynomial res;
	res=tmp.f[0][0]+tmp.f[1][0],res.resize(m+1);
	return printf("%d\n",res[m]),0;
}

GYM102331J Jiry Matchings

我们考虑这里的合并的复杂度是 \(O(n+m)\) 的,也是需要轻重链剖分的。

posted @ 2021-12-25 09:33  Point_King  阅读(556)  评论(0编辑  收藏  举报