洛谷 5291 [十二省联考2019]希望(52分)——思路+树形DP
题目:https://www.luogu.org/problemnew/show/P5291
考场上写了 16 分的。不过只得了 4 分。
对于一个救援范围,其中合法的点集也是一个连通块。 2n 枚举一个救援范围,然后换根 DP 一下范围内的每个点开始的最长链,那些最长链 <=L 的点就是该范围的合法点集。
这样得到每个合法点集出现的方案, 与卷积 k 次即可。卷积的时候先 FWT 成点值,然后快速幂一样乘 k 次,再 FWT 回来即可。
但只有 4 分。过不了大样例。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} const int N=1e6+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1]; void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} namespace S1{ const int K=20,M=(1<<16)+5; int bin[K],dp[K],pr[K],sc[K],nd[K],tot; int ts,len,f[M],g[M]; bool vis[K],col[K]; void chk_dfs(int cr,int fa) { vis[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr); } void dfs(int cr,int fa) { dp[cr]=0; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa) dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1); } void dfsx(int cr,int fa,int tmp) { if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1]; int l=tot; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa) nd[++tot]=v; int r=tot; if(l==r)return; pr[l+1]=dp[nd[l+1]]+1; for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1); sc[r]=dp[nd[r]]+1; for(int i=r-1;i>l;i--)sc[i]=Mn(sc[i+1],dp[nd[i]]+1); for(int i=l+1;i<=r;i++) { int tp=tmp;//=tmp if(i>l+1)tp=pr[i-1];if(i<r)tp=Mx(tp,sc[i+1]); dfsx(nd[i],cr,tp+1); } } void fwt(int *a,bool fx) { for(int R=2;R<=len;R<<=1) for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) { if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]); else a[i+j]=upt(a[i+j]-a[i+m+j]); } } void solve() { bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1; for(int s=1;s<bin[n];s++) { for(int i=1;i<=n;i++) { vis[i]=0; if(s&bin[i-1])col[i]=1; else col[i]=0; } int cr=0; for(int i=1;i<=n;i++) if(col[i]){chk_dfs(i,0);cr=i;break;} bool fg=0; for(int i=1;i<=n;i++) if(col[i]&&!vis[i]){fg=1;break;} if(fg)continue; ts=tot=0; dfs(cr,0); dfsx(cr,0,0); if(ts){f[ts]++; g[ts]++;} } k--; len=bin[n]; fwt(g,0); fwt(f,0); while(k) { if(k&1) { for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod; } for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod; k>>=1; } int ans=0; fwt(f,1); for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]); printf("%d\n",ans); } } int main() { freopen("hope.in","r",stdin); freopen("hope.out","w",stdout); n=rdn();L=rdn();k=rdn(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); if(n<=16){S1::solve();return 0;} return 0; }
后来发现两个地方写错了:
1.换根的时候做了前缀 max 和后缀 max ,其中后缀取 max 写成取 min 了;
2.往孩子换根的时候用了一个 tp 对父亲来的 tmp 、前缀 max 、后缀 max 取 max ,结果 tp=tmp 之后写成 tp = pr[ ] 而非 tp = Mx( tp , pr[ ] ) 。
改了这两个地方就有 16 分了。
希望以后写代码的时候更仔细。别走神或不集中之类的。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} const int N=1e6+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int n,L,k,hd[N],xnt,to[N<<1],nxt[N<<1]; void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} namespace S1{ const int K=20,M=(1<<16)+5; int bin[K],dp[K],pr[K],sc[K],nd[K],tot; int ts,len,f[M],g[M]; bool vis[K],col[K]; void chk_dfs(int cr,int fa) { vis[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa)chk_dfs(v,cr); } void dfs(int cr,int fa) { dp[cr]=0; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa) dfs(v,cr), dp[cr]=Mx(dp[cr],dp[v]+1); } void dfsx(int cr,int fa,int tmp) { if(Mx(dp[cr],tmp)<=L)ts|=bin[cr-1]; int l=tot; for(int i=hd[cr],v;i;i=nxt[i]) if(col[v=to[i]]&&v!=fa) nd[++tot]=v; int r=tot; if(l==r)return; pr[l+1]=dp[nd[l+1]]+1; for(int i=l+2;i<=r;i++)pr[i]=Mx(pr[i-1],dp[nd[i]]+1); sc[r]=dp[nd[r]]+1; for(int i=r-1;i>l;i--)sc[i]=Mx(sc[i+1],dp[nd[i]]+1);////mx not mn!!! for(int i=l+1;i<=r;i++) { int tp=tmp;//=tmp if(i>l+1)tp=Mx(tp,pr[i-1]);if(i<r)tp=Mx(tp,sc[i+1]);//mx!!! dfsx(nd[i],cr,tp+1); } } void fwt(int *a,bool fx) { for(int R=2;R<=len;R<<=1) for(int i=0,m=R>>1;i<len;i+=R) for(int j=0;j<m;j++) { if(!fx)a[i+j]=upt(a[i+j]+a[i+m+j]); else a[i+j]=upt(a[i+j]-a[i+m+j]); } } void solve() { bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1; for(int s=1;s<bin[n];s++) { for(int i=1;i<=n;i++) { vis[i]=0; if(s&bin[i-1])col[i]=1; else col[i]=0; } int cr=0; for(int i=1;i<=n;i++) if(col[i]){chk_dfs(i,0);cr=i;break;} bool fg=0; for(int i=1;i<=n;i++) if(col[i]&&!vis[i]){fg=1;break;} if(fg)continue; ts=tot=0; dfs(cr,0); dfsx(cr,0,0); if(ts){f[ts]++; g[ts]++;} } k--; len=bin[n]; fwt(g,0); fwt(f,0); while(k) { if(k&1) { for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod; } for(int i=0;i<len;i++)g[i]=(ll)g[i]*g[i]%mod; k>>=1; } int ans=0; fwt(f,1); for(int s=1;s<bin[n];s++)ans=upt(ans+f[s]); printf("%d\n",ans); } } int main() { freopen("hope.in","r",stdin); freopen("hope.out","w",stdout); n=rdn();L=rdn();k=rdn(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); if(n<=16){S1::solve();return 0;} return 0; }
然后参照题解写了 52 分的。
很重要的转化是令 \( f[i] \) 表示 i 是合法点的救援范围个数,那么 k 个救援范围包含 i 的方案就是 \( f[i]^k \) ;考虑到一个方案的合法点集是连通块,即点数比边数大一,所以令 \( g[i] \) 表示边 i 的两端点是合法点的救援范围个数,答案就是 \( \sum\limits_{i=1}^{n}f[i]^k - \sum\limits_{i=1}^{n-1}g[i]^k \) 。
然后就可以写 n*L 的 DP 了。再把链和 L=n 的部分做一下就有 52 分。
不太会 k=1 时候的长链剖分。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } int Mx(int a,int b){return a>b?a:b;} int Mn(int a,int b){return a<b?a:b;} const int N=1e6+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,L,k,hd[N],xnt=1,to[N<<1],nxt[N<<1],rd[N],f[N],g[N]; void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;rd[y]++;} namespace S1{ const int N=1005; int dfs(int cr,int fa,int lm) { int ret=1; if(!lm)return ret; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) ret=(ll)ret*(dfs(v,cr,lm-1)+1)%mod; return ret; } void dfsx(int cr,int fa) { for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { int ret=dfs(cr,v,L-1); ret=(ll)ret*dfs(v,cr,L-1)%mod; g[i>>1]=ret; dfsx(v,cr); } } void solve() { for(int i=1;i<=n;i++) f[i]=dfs(i,0,L); dfsx(1,0); int ans=0; for(int i=1;i<=n;i++)ans=upt(ans+pw(f[i],k)); for(int i=1;i<n;i++)ans=upt(ans-pw(g[i],k)); printf("%d\n",ans); } } namespace S2{ const int N=1e5+5,M=105; int nd[N],tot; struct Node{ int v[M],s[M],cd; void init(){v[0]=s[0]=1;} void frs() { for(int i=1;i<=cd;i++) s[i]=upt(s[i-1]+v[i]); } void cz() { cd=Mn(cd+1,L); for(int i=cd;i;i--)v[i]=v[i-1]; frs(); } }dp[N],pr[N],sc[N],up[N]; void mrg(Node &d0,Node d1) { int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm+1)); d0.cd=tc; for(int j=yc+1;j<=tc;j++) d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1 for(int j=1;j<=tc;j++) { int tp; if(j-1<=lm)tp=d1.s[j-1]; else tp=d1.s[lm]; tp++;///for choosen't d0.v[j]=(ll)d0.v[j]*tp%mod; if(j-1<=lm) d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j-1])%mod; } d0.frs(); } void mg2(Node &d0,Node d1) { int yc=d0.cd, lm=d1.cd, tc=Mn(L,Mx(yc,lm)); d0.cd=tc; for(int j=yc+1;j<=tc;j++) d0.v[j]=0, d0.s[j]=d0.s[yc];//0 not 1 for(int j=0;j<=tc;j++) { int tp; if(j<=lm)tp=d1.s[j]; else tp=d1.s[lm]; tp++;///for choosen't d0.v[j]=(ll)d0.v[j]*tp%mod; if(j&&j<=lm) d0.v[j]=(d0.v[j]+(ll)d0.s[j-1]*d1.v[j])%mod; } d0.frs(); } void dfs(int cr,int fa) { dp[cr].init(); for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr); mrg(dp[cr],dp[v]); } } void dfsx(int cr,int fa) { int tp=up[cr].cd; f[cr]=(tp>=L?up[cr].s[L]:up[cr].s[tp]); tp=dp[cr].cd; f[cr]=(ll)f[cr]*(tp>=L?dp[cr].s[L]:dp[cr].s[tp])%mod; int l=tot; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { nd[++tot]=i; if(tot==l+1)pr[tot].init(); else pr[tot]=pr[tot-1]; mrg(pr[tot],dp[v]); } int r=tot; for(int i=r;i>l;i--) { if(i==r)sc[i].init(); else sc[i]=sc[i+1]; mrg(sc[i],dp[to[nd[i]]]); } for(int i=l+1;i<=r;i++) { pr[i].v[0]=pr[i].s[0]=0;pr[i].frs(); sc[i].v[0]=sc[i].s[0]=0;sc[i].frs(); } for(int i=l+1;i<=r;i++) { int v=to[nd[i]],bh=nd[i]>>1; up[v]=up[cr]; if(i>l+1) mg2(up[v],pr[i-1]); if(i<r) mg2(up[v],sc[i+1]); int tp=up[v].cd; g[bh]=(tp>=L-1?up[v].s[L-1]:up[v].s[tp]); tp=dp[v].cd; g[bh]=(ll)g[bh]*(tp>=L-1?dp[v].s[L-1]:dp[v].s[tp])%mod; up[v].cz(); dfsx(v,cr); } } void solve() { dfs(1,0); up[1].init(); dfsx(1,0); int ans=0; for(int i=1;i<=n;i++) ans=upt(ans+pw(f[i],k)); for(int i=1;i<n;i++) ans=upt(ans-pw(g[i],k)); printf("%d\n",ans); } } namespace S3{ const int N=2e5+5; int dp[N],nd[N],pr[N],sc[N],tot; void dfs(int cr,int fa) { dp[cr]=1; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { dfs(v,cr); dp[cr]=(ll)dp[cr]*(dp[v]+1)%mod; } } void dfsx(int cr,int fa,int tmp) { f[cr]=(ll)dp[cr]*(tmp+1)%mod; int l=tot; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { nd[++tot]=i; if(tot==l+1)pr[tot]=1; else pr[tot]=pr[tot-1]; pr[tot]=(ll)pr[tot]*(dp[v]+1)%mod; } int r=tot; for(int i=r;i>l;i--) { if(i==r)sc[i]=1; else sc[i]=sc[i+1]; sc[i]=(ll)sc[i]*(dp[to[nd[i]]]+1)%mod; } for(int i=l+1;i<=r;i++) { int v=to[nd[i]], tp=tmp+1, bh=nd[i]>>1; if(i>l+1)tp=(ll)tp*pr[i-1]%mod; if(i<r)tp=(ll)tp*sc[i+1]%mod; g[bh]=(ll)tp*dp[v]%mod; dfsx(v,cr,tp); } } void solve() { dfs(1,0); dfsx(1,0,0); int ans=0; for(int i=1;i<=n;i++) ans=upt(ans+pw(f[i],k)); for(int i=1;i<n;i++) ans=upt(ans-pw(g[i],k)); printf("%d\n",ans); } } namespace S4{ void solve() { int ans=0; for(int i=1;i<=n;i++) { int t0=Mn(L+1,i), t1=Mn(L+1,n-i+1); ans=upt(ans+pw((ll)t0*t1%mod,k)); } for(int i=1;i<n;i++) { int t0=Mn(L,i), t1=Mn(L,n-i); ans=upt(ans-pw((ll)t0*t1%mod,k)); } printf("%d\n",ans); } } int main() { n=rdn();L=rdn();k=rdn(); for(int i=1,u,v;i<n;i++) { u=rdn();v=rdn();add(u,v);add(v,u);} if(n<=1000){S1::solve();return 0;} if((ll)n*L<=1e7){S2::solve();return 0;} if(L==n){S3::solve();return 0;} bool fg=0; for(int i=1;i<=n;i++)if(rd[i]>2){fg=1;break;} if(!fg){S4::solve();return 0;} return 0; }