树上背包NTT优化
主要结合两道例题讲,复杂度的计算很重要。
LOJ6290 花朵
非常容易可以考虑到树上背包的做法,但是过不了。
怎么将这个 \(\text{dp}\) 优化呢?考虑背包实际上就是一个卷积的形式,所以我们可以用多项式科技优化卷积过程。
可以想到的,我们不能将背包直接卷积,因为复杂度由 \(O(nm)\) 变成 \(O((n+m)\log(n+m))\) ,在 \(n=1\) 的时候复杂度反而会变慢。
具体的,我们将原树轻重链剖分,对于一个点,我们先通过类曼哈顿的贪心,每一次将最小的两个轻子树合并,等到一条重链上的点的所有节点的轻子树都合并完成了,我们再在这条重链上做分治 \(\text{fft}\) 。
复杂度是 \(O(n\log^3n)\) ,具体证明的话应该是考虑重链和轻链分开算。
对于重链来说,考虑每个点会向上跳 \(\log n\) 次,每次链上做合并的时候贡献 \(\log n\) 次,每次是 \(\log n\) ,所以复杂度是 \(O(n\log^3 n)\) 的。
对于轻链来说,考虑每一个点的贡献,由于每一个点在一次合并操作中的均摊复杂度可以近似看成 \(O(\log n)\) ,所以我们只需要计算出每一个点的操作次数即可。考虑一个点在进行轻子树合并的时候大小一定翻倍,所以其从自己位置一直合并到根的合并次数是一定小于 \(\log n\) 的,所以这里的复杂度是 \(O(n\log^2n)\) 的。
卧槽发现有个直接套距阵的老哥,也太帅了吧。
我顿悟了,距阵掌握度 +1 。
关于矩阵实现的细节很重要。对于矩阵的优化,一定需要写出对应的状态转移方程,然后根据转移前的矩阵和转移后的矩阵写出转移矩阵。
#include<bits/stdc++.h>
using namespace std;
const int N=131072;
const int MOD=998244353,G=3;
int ADD(int x,int y){return x+y>=MOD?x+y-MOD:x+y;}
int TIME(int x,int y){return (int)(1ll*x*y%MOD);}
int ksm(int x,int k=MOD-2){int res=1;for(;k;k>>=1,x=TIME(x,x))if(k&1)res=TIME(res,x);return res;}
int rev[N],lst=0;
void get_rev(int lg){
if(lst==lg) return ;else lst=lg;
for(int i=0;i<(1<<lg);++i)
rev[i]=((rev[i>>1]>>1)|((i&1)<<(lg-1)));
}
struct Polynomial{
vector<int> f;
int &operator [] (int x){return assert(x<(int)f.size()),f[x];}
int len(){return (int)f.size();}void clear(){return f.clear();}
void resize(int n){
while((int)f.size()>n) f.pop_back();
while((int)f.size()<n) f.push_back(0);
}
void NTT(int lg,bool tag){
int n=(1<<lg);get_rev(lg),resize(n);
for(int i=0;i<n;++i) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int len=2;len<=n;len<<=1){
int m=(len>>1),g=ksm(G,(MOD-1)/len);if(tag) g=ksm(g);
for(int i=0;i<n;i+=len){
for(int j=0,gg=1;j<m;++j,gg=TIME(gg,g)){
int tmp=TIME(f[i+j+m],gg);
f[i+j+m]=ADD(f[i+j],MOD-tmp),f[i+j]=ADD(f[i+j],tmp);
}
}
}
if(tag) for(int i=0,tmp=ksm(n);i<n;++i) f[i]=TIME(f[i],tmp);
}
void print(){
for(int i=0;i<len();++i) printf("%d ",f[i]);
printf("\n");
}
};
Polynomial init(int x){
Polynomial res;return res.resize(2),res[1]=x,res;
}
Polynomial operator * (Polynomial f,Polynomial g){
if(!f.len()||!g.len()) return Polynomial();
int n=f.len()+g.len()-1,lg=0;while((1<<lg)<n) lg++;
f.NTT(lg,false),g.NTT(lg,false);
for(int i=0;i<(1<<lg);++i) f[i]=TIME(f[i],g[i]);
return f.NTT(lg,true),f.resize(n),f;
}
Polynomial operator + (Polynomial f,Polynomial g){
if(f.len()<g.len()) swap(f,g);
for(int i=0;i<g.len();++i) f[i]=ADD(f[i],g[i]);
return f;
}
struct Matrix{
Polynomial f[2][2];
void clear(){
for(int i=0;i<2;++i){
for(int j=0;j<2;++j)
f[i][j].clear();
}
}
void print(){
for(int i=0;i<2;++i){
for(int j=0;j<2;++j)
printf("f[%d][%d]=",i,j),f[i][j].print();
}
}
};
bool operator < (Matrix a,Matrix b){
int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
return tmp1<tmp2;
}
bool operator > (Matrix a,Matrix b){
int tmp1=max({a.f[0][0].len(),a.f[0][1].len(),a.f[1][0].len(),a.f[1][1].len()});
int tmp2=max({b.f[0][0].len(),b.f[0][1].len(),b.f[1][0].len(),b.f[1][1].len()});
return tmp1>tmp2;
}
Matrix operator * (Matrix a,Matrix b){
Matrix res;res.clear();
for(int i=0;i<2;++i){
for(int k=0;k<2;++k){
for(int j=0;j<2;++j)
res.f[i][j]=res.f[i][j]+a.f[i][k]*b.f[k][j];
}
}
return res;
}
int n,m,p[N];
struct Edge{int nxt,to;}e[N<<1];int fir[N];
void add(int u,int v,int i){e[i]=(Edge){fir[u],v},fir[u]=i;}
struct Node{int fa,son,siz;}tr[N];
void dfs1(int u){
tr[u].siz=1;
for(int i=fir[u];i;i=e[i].nxt){
int v=e[i].to;if(v==tr[u].fa) continue;
tr[v].fa=u,dfs1(v),tr[u].siz+=tr[v].siz;
if(tr[v].siz>tr[tr[u].son].siz) tr[u].son=v;
}
}
Matrix cdq(vector<Matrix> &bag,int l,int r){
if(l==r) return bag[l];
int mid=(l+r)>>1;
return cdq(bag,l,mid)*cdq(bag,mid+1,r);
}
priority_queue<Matrix,vector<Matrix>,greater<Matrix> > q;
Matrix merge(vector<Matrix> &bag){
while(!q.empty()) q.pop();
for(int i=0;i<(int)bag.size();++i) q.push(bag[i]);
while(q.size()>1){
Matrix a=q.top();q.pop();
Matrix b=q.top();q.pop();
q.push(a*b);
}
return q.top();
}
Matrix dfs2(int u){
vector<Matrix> bag;
for(;u;u=tr[u].son){
vector<Matrix> BAG;
for(int i=fir[u];i;i=e[i].nxt){
int v=e[i].to;if(v==tr[u].fa||v==tr[u].son) continue;
Matrix tmp=dfs2(v),TMP;TMP.clear();
TMP.f[1][1]=tmp.f[0][0];
TMP.f[0][0]=TMP.f[1][1]+tmp.f[1][0];
BAG.push_back(TMP);
}
Matrix tmp,TMP;TMP.clear();
if(BAG.empty()){
tmp.clear();
tmp.f[0][0].resize(1),tmp.f[0][0][0]=1;
tmp.f[1][1].resize(1),tmp.f[1][1][0]=1;
}
else tmp=merge(BAG);
TMP.f[0][0]=TMP.f[0][1]=tmp.f[0][0];
TMP.f[1][0]=tmp.f[1][1]*init(p[u]);
bag.push_back(TMP);
}
return cdq(bag,0,(int)bag.size()-1);
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;++i) scanf("%d",&p[i]);
for(int i=1;i<n;++i){
int u,v;scanf("%d%d",&u,&v);
add(u,v,i<<1),add(v,u,i<<1|1);
}
dfs1(1);Matrix tmp=dfs2(1);Polynomial res;
res=tmp.f[0][0]+tmp.f[1][0],res.resize(m+1);
return printf("%d\n",res[m]),0;
}
GYM102331J Jiry Matchings
我们考虑这里的合并的复杂度是 \(O(n+m)\) 的,也是需要轻重链剖分的。