洛谷 P5276 模板题(uoi)
这题挺恶心的。
首先考虑一棵树的情况:用点分治加卷积统计答案。注意合并子树时要按深度从小到大合并,否则复杂度会爆掉。
我偷懒按子树大小(size)从小到大合并,复杂度应该仍然是两个 log。
然后考虑万恶的环。
先随便删掉环上一条边,按照树统计一下答案。
然后考虑必须经过环上被删掉的那条边、但又不经过整个环的路径的答案。
做法是再钦定环上另一条边不被经过,计算这部分答案。
然后递归做就行了。
最后加上经过整个环的答案。
时间复杂度 \(O(n\log^2 n)\)。
// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std;
// Polynomial stored as its coefficient list (values kept modulo M).
using poly = std::vector<int>;
using ll = long long;
// Scratch buffers reused by the NTT branch of operator*.
poly a, b;
// Number of entries in the NTT tables rev[] and w[].
const int P = 1 << 17;
// Modulus for all arithmetic, and the generator init() uses to build roots of unity.
const int M = 998244353;
const int G = 3;
// rev: bit-reversal permutation for the current transform length.
int rev[P];
// w: twiddle table filled by init(); stage with half-length i uses w[i .. 2i-1].
int w[P];
namespace{
int add(int x,int y){
return (x+=y)>=M?x-M:x;
}
int sub(int x,int y){
return (x-=y)<0?x+M:x;
}
int mul(int x,int y){
return (ll)x*y%M;
}
int fp(int x,int y){
int ret=1;
for (; y; y>>=1,x=mul(x,x))
if (y&1) ret=mul(ret,x);
return ret;
}
}
// inv2[e] holds the modular inverse of 2^e; init() fills every e with 2^e <= len.
int inv2[30];

// Precomputes the twiddle table w[] for transforms of length up to len
// (len is expected to be a power of two) and the inverse powers of two.
void init(int len) {
    for (int half = 1; half < len; half <<= 1) {
        // Stage with half-length `half` keeps its twiddles in w[half .. 2*half-1].
        w[half] = 1;
        if (half == 1) continue;
        const int step = fp(G, (M - 1) / (half << 1));
        w[half + 1] = step;
        for (int j = 2; j < half; ++j) w[half + j] = mul(w[half + j - 1], step);
    }
    inv2[0] = 1;
    inv2[1] = 499122177;  // inverse of 2 modulo M
    int exponent = 1;
    for (int size = 4; size <= len; size <<= 1) {
        ++exponent;
        inv2[exponent] = mul(inv2[exponent - 1], inv2[1]);
    }
}
// In-place iterative number-theoretic transform of a[0..len-1] modulo M.
// Requires: len is a power of two, rev[] holds the bit-reversal permutation
// for this len (operator* fills it), and init() has filled w[] up to len.
// Entries are assumed to be in [0, M) on input and stay in [0, M).
void NTT(int *a,int len){
// Permute into bit-reversed index order.
for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i],a[rev[i]]);
// i is the butterfly half-length of the current stage.
for (int i=1; i<len; i<<=1){
for (int j=0; j<len; j+=(i<<1)){
// l walks the lower half, b the upper half, ww this stage's twiddles w[i..2i-1].
int *l=a+j,*b=l+i,*ww=w+i;
for (int k=0; k<i; ++k){
int y=mul(*b,*(ww++));
// Upper output: l - y, adding M back when the sign bit is set
// (relies on arithmetic right shift of negative ints).
(*b)=(*l)-y;
(*b)+=((*b)>>31)&M;
++b;
// Lower output: l + y computed as l + y - M, then corrected if negative.
(*l)+=y-M;
(*l)+=((*l)>>31)&M;
++l;
}
}
}
}
// Inverse transform of a[0..len-1] with len == 2^bit: reversing a[1..len-1]
// turns the forward NTT into the inverse one, then every entry is scaled by
// 1/len = inv2[bit].
void INTT(int *a, int len, int bit) {
    std::reverse(a + 1, a + len);
    NTT(a, len);
    const int scale = inv2[bit];
    for (int *p = a; p != a + len; ++p) *p = mul(*p, scale);
}
// Product of two polynomials modulo M. The result has u.size()+v.size()-1
// coefficients, or is empty when either operand is empty (the zero polynomial).
// Small products use schoolbook multiplication. Larger ones use NTT through
// the global scratch buffers a and b and the table rev, so this is not reentrant.
poly operator *(const poly &u, const poly &v) {
    // Without this guard, u.size()+v.size()-1 underflows for two empty inputs.
    if (u.empty() || v.empty()) return poly();
    const size_t out = u.size() + v.size() - 1;
    // Schoolbook wins while |u|*|v| stays within a constant factor of |u|+|v|.
    if ((ll)u.size() * (ll)v.size() <= (ll)(u.size() + v.size()) * 30) {
        poly ret(out);
        for (size_t i = 0; i < u.size(); ++i)
            for (size_t j = 0; j < v.size(); ++j)
                ret[i + j] = add(ret[i + j], mul(u[i], v[j]));
        return ret;
    }
    a = u;
    b = v;
    int len = 1;
    int bit = 0;
    while ((size_t)len < out) {
        len <<= 1;
        ++bit;
    }
    // rev[] and w[] only have P entries; a longer transform would overrun them.
    assert(len <= P);
    a.resize(len);
    b.resize(len);
    for (int i = 0; i < len; ++i) rev[i] = rev[i >> 1] >> 1 | ((i & 1) ? len >> 1 : 0);
    NTT(a.data(), len);
    NTT(b.data(), len);
    for (int i = 0; i < len; ++i) a[i] = mul(a[i], b[i]);
    INTT(a.data(), len, bit);
    a.resize(out);
    return a;
}
// Coefficient-wise sum modulo M; the shorter operand is treated as zero-padded.
poly operator +(const poly &u, const poly &v) {
    const size_t width = std::max(u.size(), v.size());
    poly ret;
    ret.reserve(width);
    for (size_t idx = 0; idx != width; ++idx) {
        int lhs = 0;
        int rhs = 0;
        if (idx < u.size()) lhs = u[idx];
        if (idx < v.size()) rhs = v[idx];
        ret.push_back(add(lhs, rhs));
    }
    return ret;
}
// Adds v into u coefficient-wise modulo M, zero-extending u if v is longer.
void operator +=(poly &u, const poly &v) {
    u.resize(std::max(u.size(), v.size()));
    for (size_t i = 0; i != v.size(); ++i) u[i] = add(u[i], v[i]);
}
// Writes every coefficient followed by a space, then ends (and flushes) the line.
ostream& operator <<(ostream& out, const poly &a) {
    for (size_t i = 0; i < a.size(); ++i) out << a[i] << " ";
    out << endl;
    return out;
}
void test(){
poly a({1,2}),b({2,3,2333});
a=a*b;
cerr<<a;
}
// Vertex and edge counts read from the input.
int n;
int m;
// Array bound for per-vertex tables (vertices are numbered 1..n).
const int N = 100010;
// Accumulated answer, indexed by path length.
poly ans;
// Centroid decomposition on a tree.
// For every d it adds to ans[d] the number of unordered vertex pairs at
// distance d, and finally sets ans[0] = n.
namespace solve1{
vector<int> e[N];      // adjacency lists; edges to processed centroids get removed
int sz[N], tmp[N], rt; // subtree sizes, largest child size, best centroid so far

// Computes sz[] for the component containing x, rooted at x.
void Dfs(int x, int fa) {
    sz[x] = 1;
    for (auto i : e[x])
        if (i != fa) {
            Dfs(i, x);
            sz[x] += sz[i];
        }
}

// Largest piece left after removing x from a component of size y.
int calc(int y, int x) {
    return max(y - sz[x], tmp[x]);
}

// Moves rt to the vertex of x's subtree minimizing calc(totsize, .).
void Getrt(int x, int fa, const int totsize) {
    tmp[x] = 0;
    for (auto i : e[x])
        if (i != fa) {
            tmp[x] = max(sz[i], tmp[x]);
            Getrt(i, x, totsize);
        }
    if (calc(totsize, x) < calc(totsize, rt)) rt = x;
}

// Adds 1 to a[nowdis] for x, and to a[nowdis + t] for each vertex t edges below x.
void Getdeep(int x, int fa, poly &a, int nowdis) {
    ++a[nowdis];
    for (auto i : e[x])
        if (i != fa) Getdeep(i, x, a, nowdis + 1);
}

// Handles the component containing x. It counts the pairs whose path passes
// through the component's centroid, then recurses into the pieces left
// after removing that centroid.
void df(int x) {
    Dfs(x, 0);
    rt = x;
    int bbb = sz[rt];  // size of the whole component
    Getrt(x, 0, sz[x]);
    // The neighbour toward x was rt's parent in Dfs, so its sz is stale.
    // Replace it with the size of that side.
    for (auto i : e[rt])
        if (sz[i] > sz[rt]) sz[i] = bbb - sz[rt];
    // Merging branches from smallest to largest keeps b no longer than the
    // current branch's depth array, so each product is about branch-sized.
    sort(e[rt].begin(), e[rt].end(), [&](int x, int y) {
        return sz[x] < sz[y];
    });
    poly c, b(1, 1);  // b: depth counts of rt itself plus branches merged so far
    for (auto i : e[rt]) {
        c.assign(sz[i] + 1, 0);
        Getdeep(i, rt, c, 1);
        ans += b * c;
        b += c;
    }
    int fkrt = rt;  // rt is overwritten by the recursive calls below
    for (auto i : e[fkrt]) {
        e[i].erase(find(e[i].begin(), e[i].end(), fkrt));
        df(i);
    }
}

// Builds the tree from the parent array fa[1..n] (fa[v] == 0 means no parent)
// and runs the decomposition from vertex 1.
void main(int *fa) {
    for (int i = 1; i <= n; ++i)
        if (fa[i]) {
            e[fa[i]].push_back(i);
            e[i].push_back(fa[i]);
        }
    df(1);
    // A single-vertex tree yields no pairs, so ans can still be empty here.
    if (ans.empty()) ans.resize(1);
    ans[0] = n;
}
}
int vis[N];        // visit marks: set to 1 by noloop, then to 2 by findloop in main
int fa[N];         // parent of each vertex in noloop's DFS tree (0 for the root)
int k;             // output covers lengths 0..k
int f;             // non-zero: also print the per-length values
vector<int> g[N];  // adjacency lists of the input graph
// Depth-first search from x over g. It marks each reached vertex with
// vis = 1 and records its DFS-tree parent in fa. Edges to already-visited
// vertices (non-tree edges) are skipped.
void noloop(int x) {
    vis[x] = 1;
    for (int next : g[x]) {
        if (vis[next]) continue;
        fa[next] = x;
        noloop(next);
    }
}
// Truncates or zero-extends a to k+1 entries, then prints the sum of
// a[0..k] modulo M on one line. When f is non-zero it also prints the
// entries themselves.
void Output(poly &a, int k, int f) {
    a.resize(k + 1);
    int total = 0;
    for (size_t i = 0; i < a.size(); ++i) total = add(total, a[i]);
    cout << total << endl;
    if (f != 0) cout << a;
}
// Reads the graph, counts on a DFS spanning tree, and for a graph with one
// extra edge adds the contributions that use the cycle.
int main(){
    init(1 << 17);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin.ignore(233, '\n');  // skip the first input line (up to 233 characters)
    cin >> n >> m >> k >> f;
    for (int i = 1; i <= m; ++i) {
        int x, y;
        cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // Count everything on the DFS spanning tree first.
    noloop(1);
    solve1::main(fa);
    if (m == n - 1) {
        Output(ans, k, f);
        return 0;
    }
    // Otherwise there is one extra edge (assumes m == n - TODO confirm constraints).
    // Find the cycle. s holds the current DFS path, and pp becomes the
    // ancestor hit by the first back edge. The DFS uses the same vertex order
    // as noloop, so for a simple graph that back edge is exactly the edge
    // missing from the spanning tree counted above.
    poly s;
    int pp = 0;
    function<void(int,int)> findloop = [&](int x, int f) {
        vis[x] = 2;
        s.push_back(x);
        for (auto i : g[x])
            if (i != f) {
                if (vis[i] != 2) findloop(i, x);
                else pp = i;
                if (pp) return;
            }
        s.pop_back();
    };
    findloop(1, 0);
    // Keep only the cycle s[0] = pp, ..., s.back(), and cut its closing edge.
    s.erase(s.begin(), find(s.begin(), s.end(), pp));
    g[s.front()].erase(find(g[s.front()].begin(), g[s.front()].end(), s.back()));
    g[s.back()].erase(find(g[s.back()].begin(), g[s.back()].end(), s.front()));
    // Adds 1 to c[d] for every vertex at distance d from x in the current g.
    function<void(int,int,poly&,int)> ddd = [&](int x, int fa, poly &c, int dis) {
        if (dis >= c.size()) c.resize(dis + 1);
        ++c[dis];
        for (auto j : g[x])
            if (j != fa) ddd(j, x, c, dis + 1);
    };
    // u[i + len] += v[i] for every i, growing u as needed.
    auto Fakeadd = [&](poly &u, const poly &v, int len) {
        if (u.size() < v.size() + len) u.resize(v.size() + len);
        for (size_t i = 0; i < v.size(); ++i) u[i + len] = add(u[i + len], v[i]);
    };
    auto waylength = [&](int x, int y) {
        return y - x;
    };
    // solve(l, r, nowlen): s[l..r] is still a path, and an outside route of
    // length nowlen joins s[l] to s[r]. It adds every pair (u, v) attached at
    // s[a] and s[b] with l <= a < b <= r, using the path
    // u..s[a]..s[l] ~route~ s[r]..s[b]..v.
    function<void(int,int,int)> solve = [&](int l, int r, int nowlen) {
        if (l == r) return;
        int mid = (l + r) >> 1;
        // Cut the cycle edge s[mid] - s[mid+1].
        g[s[mid]].erase(find(g[s[mid]].begin(), g[s[mid]].end(), s[mid + 1]));
        g[s[mid + 1]].erase(find(g[s[mid + 1]].begin(), g[s[mid + 1]].end(), s[mid]));
        poly c, d;
        ddd(s[l], 0, c, 0);
        ddd(s[r], 0, d, 0);
        Fakeadd(ans, c * d, nowlen);
        // Left half: the route to s[mid] now also runs s[r] .. s[mid].
        solve(l, mid, waylength(mid, r) + nowlen);
        // Right half: the route from s[mid+1] now also runs s[mid+1] .. s[l].
        solve(mid + 1, r, waylength(l, mid + 1) + nowlen);
    };
    solve(0, s.size() - 1, 1);
    // Every cycle edge is cut now, so ddd(i) only sees the tree hanging at i.
    // Each vertex of that tree other than i contributes at length depth + |cycle|.
    for (auto i : s) {
        poly c;
        ddd(i, 0, c, 0);
        c[0] = 0;
        Fakeadd(ans, c, s.size());
    }
    // The cycle itself contributes once at length |cycle|.
    ans[s.size()] = add(ans[s.size()], 1);
    Output(ans, k, f);
}