WC2019 数树

多谢yn大佬的指点。

op=0

挺简单的:两棵树都已知,用哈希表统计公共边数\(m\),答案就是\(y^{n-m}\)(两棵树的公共边把\(n\)个点分成\(n-m\)个连通块)。

op=1

\(T1,T2\)为两棵树
考虑我们知道\(T1\),不知道\(T2\)
与op=0相比,答案只是多了一层对\(T2\)的求和\(\sum\)。

\[ans= \sum_{T2} y^{\,n-|T1 \cap T2|} \]

考虑重合了\(m\)条边

\[ans= \sum_{T2} y^{n-m} \]

\[ans= \sum_{m=0}^{n-1} y^{n-m} f(m) \]

\(f(m)\)为与\(T1\)重合了\(m\)条边的\(T2\)数量
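考虑\(f(m)\)不好算,令\(g(m)\)为“钦定”重合\(m\)条边的方案数:在\(T1\)中任选\(m\)条边组成边集\(S\),统计包含\(S\)的\(T2\)个数,再对所有\(S\)求和。
每个恰好与\(T1\)重合\(i\)条边的\(T2\)会被计算\(\tbinom{i}{m}\)次,所以\(g(m)= \sum_{i=m}^{n-1} \tbinom{i}{m} f(i)\)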
考虑把\(ans\)的表达式搞一搞
\(ans=y^n \sum_{m=0}^{n-1} y^{-m} f(m)\)
\(k=\frac{1}{y}-1\)
\(ans=y^n \sum_{m=0}^{n-1} (k+1)^m f(m)\)
\(ans=y^n \sum_{m=0}^{n-1} \sum_{i=0}^m \tbinom{m}{i} k^i f(m)\)
交换求和顺序,\(ans=y^n \sum_{i=0}^{n-1} k^i \sum_{m=i}^{n-1} \tbinom{m}{i} f(m)=y^n \sum_{i=0}^{n-1} k^i g(i)\)

然后考虑怎么算\(g(m)\)
\(l=n-m\)
那么\(g(m)\)等价于:在\(T1\)中选\(m\)条边,把\(T1\)划分成\(l\)个连通块(每个连通块也是一棵树),再把这些连通块连成一棵树的方案数,对所有选法求和。
这里要用到Cayley定理的一个经典推广:
\(l\)个大小分别为\(a[i]\)的树连成一棵大树的方案数是\((\sum_{i=1}^l a[i])^{l-2} \prod_{i=1}^l a[i]\)
证明可以用Prüfer序列或者矩阵树定理。

那么
\(g(m)= \sum_{S\subseteq T1,\ |S|=m} n^{l-2} \prod_{i=1}^l a[i]\),其中\(a[1..l]\)是边集\(S\)划分出的\(l\)个连通块的大小。

考虑一起DP这个式子以及\(ans\)关于\(g\)的表达式。
首先观察式子,发现有一个乘连通块大小的操作,那么可以认为是在每个连通块内选一个关键点的方案数。
由于\(k^m=k^{n-l}\),每多一个连通块,\(n\)的指数就加一,而\(k\)的指数减一,即每个连通块贡献一个因子\(\frac{n}{k}\);可以把这个贡献放在连通块的关键点上。
那么设\(f[i][0]\)为dp完了\(i\)的子树,\(i\)所在连通块还没有放关键点的贡献和,
同理\(f[i][1]\)表示放了。
那么考虑一对父子关系\(x,son\)
\(x,son\)在一个连通块的情况,那么
\(newf[x][0]+=f[x][0]*f[son][0]\)
\(newf[x][1]+=f[x][1]*f[son][0]+f[x][0]*f[son][1]\)
若不在一个连通块,那么孩子所在连通块应该是已有关键点的。
\(newf[x][0]+=f[x][0]*f[son][1]\)
\(newf[x][1]+=f[x][1]*f[son][1]\)
这样dp就行了(每个点初始化为\(f[i][0]=1,\ f[i][1]=\frac{n}{k}\)),最后乘上分离出的常数:\(ans=f[1][1]\cdot y^n k^n n^{-2}\)。

op=2

说实话这应该比\(op=1\)简单。
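考虑有标号树的\(EGF\):\(f(x)=\sum_{i\ge 1} i^{i-2}\frac{x^i}{i!}\)
那么\(\frac{f(x)^m}{m!}\)就是由\(m\)棵树组成的森林的\(EGF\)(除以\(m!\)是因为这些树之间没有顺序)。
然后还是考虑\(g(l)\)表示“钦定”\(l\)条边相同的方案数之和:钦定一个有\(m=n-l\)个连通块的森林,两棵树都必须包含它。
每棵树包含该森林的方案数都是\(n^{m-2}\prod a_i\),两棵树要平方,于是每个大小为\(i\)的连通块贡献\(i^{i-2}\cdot i^2=i^i\):\(g(l)=n^{2m-4}\,n!\,[x^n]\frac{1}{m!}\Big(\sum_{i\ge 1} i^{i}\frac{x^i}{i!}\Big)^m\)
那么把所有与\(m\)有关的东西都弄到\(EGF\)里就好了:令\(k=\frac{1}{y}-1\),\(ans=y^n\sum_{l}k^l g(l)=y^n k^n n^{-4}\,n!\,[x^n]\exp\Big(\frac{n^2}{k}\sum_{i\ge 1} i^{i}\frac{x^i}{i!}\Big)\)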

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// M: prime modulus for all arithmetic (998244353 = 119*2^23+1);
// G = 3: primitive root of M, used by fft() to build roots of unity.
const int M=998244353,G=3;
// Dense polynomial with int coefficients: g[i] is the coefficient of x^i.
// size() is the number of stored coefficients (trailing zeros included).
struct poly{
    vector<int> g;
    // Empty polynomial (no stored coefficients).
    __attribute((always_inline)) poly(){}
    // x coefficients, all zero.
    __attribute((always_inline)) poly(size_t x):g(x){}
    // Number of stored coefficients.
    __attribute((always_inline)) size_t size() const { return g.size(); }
    // Truncates or zero-pads to exactly x coefficients.
    __attribute((always_inline)) void resize(size_t x){ g.resize(x); }
    // Resizes *this in place to x coefficients and returns a copy of it.
    __attribute((always_inline)) poly cat(size_t x){
        resize(x);
        return *this;
    }
    __attribute((always_inline)) int& operator [](int x){ return g[x]; }
    __attribute((always_inline)) int operator [](int x) const { return g[x]; }
    // Makes this the constant polynomial x (one coefficient).
    void operator =(int x){
        g.assign(1,x);
    }
    // Reads a coefficient count followed by that many coefficients.
    friend istream& operator >>(istream &is,poly &A){
        int n; is>>n;
        A.resize(n);
        for (int &c:A.g) is>>c;
        return is;
    }
    // Writes every coefficient followed by a single space (no newline).
    friend ostream& operator <<(ostream &os,const poly &A){
        for (int c:A.g) os<<c<<" ";
        return os;
    }
};
__attribute((always_inline)) int fp(int x,int y){
    int ret=1;
    for (; y; y>>=1,x=(ll)x*x%M) if (y&1) ret=(ll)ret*x%M;
    return ret;
}
// Brings a value in [0, 2M) back into [0, M) with one conditional subtraction.
__attribute((always_inline)) int D(int x){
    if (x>=M) x-=M;
    return x;
}
// Brings a value in (-M, M) up into [0, M) with one conditional addition.
__attribute((always_inline)) int U(int x){
    if (x<0) x+=M;
    return x;
}
// Capacity of the global NTT scratch arrays below: every transform length u
// used with them must satisfy u <= P (P = 270000 >= 2^18).
const int P=270000;
// a, b: operand buffers filled by the polynomial multiply/power routines
// before calling fft(); w: twiddle table rewritten by fft(); rev: bit-reversal
// permutation filled by the callers of fft().
int a[P],w[P],b[P],rev[P];
void fft(int *a,int n){
    for (int i=0; i<n; ++i) if (i>rev[i]) swap(a[i],a[rev[i]]);
    for (int i=1; i<n; i<<=1){
        w[0]=1; w[1]=fp(G,(M-1)/(i<<1));
        for (int j=2; j<i; ++j) w[j]=(ll)w[j-1]*w[1]%M;
        for (int j=0; j<n; j+=(i<<1))
        for (int k=j; k<j+i; ++k){
            int x=a[k],y=(ll)a[k+i]*w[k-j]%M;
          	a[k]=x+y>=M?x+y-M:x+y;
          	a[k+i]=x-y<0?x-y+M:x-y;
//            a[k]=D(x+y);
//            a[k+i]=U(x-y);
        }
    }
}
// Writes A's coefficients to stdout (each followed by a space) and ends the
// line with std::endl, which also flushes stdout.
void print(const poly &A){
    cout<<A<<endl;
}
// Scales every coefficient of A by x modulo M (results lie in [0, M) when
// x and the coefficients are non-negative).
poly operator *(const poly &A,const int &x){
    poly B(A.size());
    transform(A.g.begin(),A.g.end(),B.g.begin(),
              [&](int c){ return (int)((ll)c*x%M); });
    return B;
}
// Returns A + x, i.e. adds the constant x (must lie in [0, M)) to the constant
// term. An empty A is the zero polynomial, so the result is {x}; indexing
// B[0] of an empty coefficient vector would be undefined behaviour.
poly operator +(const poly &A,const int &x){//x>=0 x<M 
    poly B=A;
    if (!B.size()) B.resize(1);
    B[0]=D(B[0]+x);
    return B;
}
// Returns A - x, i.e. subtracts the constant x (must lie in [0, M)) from the
// constant term. An empty A is the zero polynomial, so the result is {-x mod M};
// indexing B[0] of an empty coefficient vector would be undefined behaviour.
poly operator -(const poly &A,const int &x){//x>=0 x<M
    poly B=A;
    if (!B.size()) B.resize(1);
    B[0]=U(B[0]-x);
    return B;
}
// x * A: scalar multiplication is commutative, so reuse the A * x overload.
poly operator *(const int &x,const poly &A){
    return A*x;
}
// Returns x + A, i.e. adds the constant x (must lie in [0, M)) to A's constant
// term. An empty A is the zero polynomial, so the result is {x}; indexing
// B[0] of an empty coefficient vector would be undefined behaviour.
poly operator +(const int &x,const poly &A){//x>=0 x<M 
    poly B=A;
    if (!B.size()) B.resize(1);
    B[0]=D(B[0]+x);
    return B;
}
// Returns the polynomial x - A (x must lie in [0, M)): every coefficient of A
// is negated modulo M and x is added to the constant term. Previously this
// returned A - x, i.e. every coefficient had the wrong sign. An empty A is the
// zero polynomial, giving {x}.
poly operator -(const int &x,const poly &A){//x>=0 x<M
    poly B(A.size()?A.size():1);
    for (size_t i=0; i<A.size(); ++i) B[i]=U(-A[i]);
    B[0]=D(B[0]+x);
    return B;
}
// Polynomial product modulo M via NTT. Returns C with
// C.size() == A.size()+B.size()-1 and C[k] = sum_{i+j=k} A[i]*B[j] mod M.
// If either operand is empty the (zero) product is returned as an empty poly.
// Uses the global scratch arrays a, b, rev (and w through fft), so it is not
// reentrant, and the transform length must fit in P.
poly operator *(const poly &A,const poly &B){
    // With an empty operand A.size()+B.size()-1 wrapped around (size_t) and
    // the length search below never terminated.
    if (!A.size() || !B.size()) return poly();
    const size_t len=A.size()+B.size()-1;
    size_t u=1;
    int bit=0;                        // u == 2^bit
    while (u<len) u<<=1, ++bit;
    assert(u<=(size_t)P);             // scratch arrays hold P entries
    for (size_t i=0; i<A.size(); ++i) a[i]=A[i];
    memset(a+A.size(),0,sizeof(*a)*(u-A.size()));
    for (size_t i=0; i<B.size(); ++i) b[i]=B[i];
    memset(b+B.size(),0,sizeof(*b)*(u-B.size()));
    // Bit-reversal permutation on `bit` bits. For u == 1 there is nothing to
    // permute; this avoids the former shift by (size_t)-1.
    rev[0]=0;
    for (size_t i=1; i<u; ++i) rev[i]=rev[i>>1]>>1|((i&1)<<(bit-1));
    fft(a,u); fft(b,u);
    for (size_t i=0; i<u; ++i) a[i]=(ll)a[i]*b[i]%M;
    // Inverse NTT: forward transform of the index-reversed spectrum, then 1/u.
    reverse(a+1,a+u);
    fft(a,u);
    const int ni=fp(u,M-2);
    poly ret(len);
    for (size_t i=0; i<len; ++i) ret[i]=(ll)a[i]*ni%M;
    return ret;
}
// Returns the first x coefficients of _, zero-padded when _ is shorter.
// _ itself is untouched; `_=cat(_,x)` has the same effect as `_.resize(x)`.
poly cat(const poly &_,size_t x){//_=cat(_,x) is equal to _.resize(x)
    poly out(x);
    const size_t keep=min(x,_.size());
    copy(_.g.begin(),_.g.begin()+keep,out.g.begin());
    return out;
}
// Exact power A^x for x >= 0 using a single NTT: forward transform, raise each
// point value to the x-th power, inverse transform.
// The returned poly has 2*A.size()*x-1 coefficients (only the first
// (A.size()-1)*x+1 can be nonzero). x == 0 yields {1}; an empty A (the zero
// polynomial) yields an empty poly for x >= 1.
// Previously x == 0 made the size expression wrap around (size_t) and the
// length search never terminate. The transform length is also sized to the
// true degree, so it fits in the scratch arrays for larger inputs.
poly pow(const poly &A,const int &x){//A^x has limit 
    assert(x>=0);
    if (x==0){ poly one(1); one[0]=1; return one; }
    if (!A.size()) return poly();
    const size_t need=(A.size()-1)*(size_t)x+1;   // number of terms of A^x
    size_t u=1;
    int bit=0;                                     // u == 2^bit
    while (u<need) u<<=1, ++bit;
    assert(u<=(size_t)P);
    for (size_t i=0; i<A.size(); ++i) a[i]=A[i];
    memset(a+A.size(),0,sizeof(*a)*(u-A.size()));
    rev[0]=0;
    for (size_t i=1; i<u; ++i) rev[i]=rev[i>>1]>>1|((i&1)<<(bit-1));
    fft(a,u);
    for (size_t i=0; i<u; ++i) a[i]=fp(a[i],x);
    // Inverse NTT: forward transform of the index-reversed spectrum, then 1/u.
    reverse(a+1,a+u);
    fft(a,u);
    const int ni=fp(u,M-2);
    poly ret((A.size()*(size_t)x<<1)-1);
    for (size_t i=0; i<need; ++i) ret[i]=(ll)a[i]*ni%M;
    return ret;
}
// Returns A^x truncated to its first C coefficients, by square-and-multiply
// with truncation after every product. x == 0 yields {1}.
// NOTE(review): assumes x >= 0 — for negative x the loop never ends because
// e>>=1 keeps a negative e at -1.
poly pow(const poly &A,const int &x,const int &C){
    poly acc(1);
    acc[0]=1;
    poly base=cat(A,C);
    int e=x;
    while (e){
        if (e&1){
        #ifdef FAST
            acc=(acc*base).cat(C);
        #else
            acc=cat(acc*base,C);
        #endif
        }
        e>>=1;
        #ifdef FAST
            base=(base*base).cat(C);
        #else
            base=cat(base*base,C);
        #endif
    }
    return acc;
}
// Coefficient-wise sum modulo M; the result has max(A.size(), B.size()) terms.
poly operator +(const poly &A,const poly &B){
    const bool aLonger=A.size()>=B.size();
    const poly &lng=aLonger?A:B;
    const poly &sht=aLonger?B:A;
    poly ret=lng;
    for (size_t i=0; i<sht.size(); ++i) ret[i]=D(ret[i]+sht[i]);
    return ret;
}
// Coefficient-wise difference A - B modulo M; A is zero-extended when B is
// longer, so the result has max(A.size(), B.size()) terms.
poly operator -(const poly &A,const poly &B){
    poly ret=cat(A,max(A.size(),B.size()));
    for (size_t i=0; i<B.size(); ++i) ret[i]=U(ret[i]-B[i]);
    return ret;
}
// Multiplicative inverse of the power series A, returned to precision u
// (u = smallest power of two >= A.size(), so at least A.size() correct terms).
// Requires A[0] != 0.
// Newton iteration: if B = 1/A mod x^{i/2}, then 2B - A*B^2 = 1/A mod x^i.
// This needs A and B^2 to precision i; the previous code truncated both to
// i/2, which produced wrong results (e.g. 1/(1+x) came out as 1).
poly getrev(const poly &A){
    assert(A.size() && A[0]);
    size_t u=1; while (u<A.size()) u<<=1;
    poly B(1); B[0]=fp(A[0],M-2);
    #ifdef FAST
    	for (size_t i=2; i<=u; i<<=1) B=B+B-(cat(A,i)*(B*B).cat(i)).cat(i);
    #else
    	for (size_t i=2; i<=u; i<<=1) B=B+B-cat(cat(A,i)*cat(B*B,i),i);
    #endif
    return B;
}
// Formal derivative: B[i] = (i+1) * A[i+1] mod M, with A.size()-1 terms.
// A constant or empty A gives an empty poly. For an empty A the old
// A.size()-1 wrapped around (size_t) and requested a huge allocation.
poly diff(const poly &A){
    if (A.size()<=1) return poly();
    poly B(A.size()-1);
    for (size_t i=0; i<B.size(); ++i) B[i]=(ll)A[i+1]*(i+1)%M; 
    return B;
}
// Table of modular inverses, inv[i] = 1/i mod M, read by inte() when compiled
// with FAST. As a global it starts zero-filled, so entries must be populated
// before inte() reads them.
#ifdef FAST
    int inv[P];
#endif
// Formal integral with zero constant term: B[0] = 0, B[i] = A[i-1] / i mod M,
// so B.size() == A.size()+1.
poly inte(const poly &A){//slow
    poly B(A.size()+1);
    #ifdef FAST
    	// Extend the global inverse table on demand with the linear recurrence
    	// inv[i] = -(M/i) * inv[M%i] (mod M). Before, the table was read without
    	// ever being filled, so every coefficient came out as 0.
    	static size_t ready=1;           // inv[1..ready] are valid
    	assert(B.size()<=(size_t)P);
    	inv[1]=1;
    	for (; ready+1<B.size(); ++ready){
    		const int i=(int)(ready+1);
    		inv[i]=(ll)(M-M/i)*inv[M%i]%M;
    	}
    	for (size_t i=1; i<B.size(); ++i) B[i]=(ll)A[i-1]*inv[i]%M;
    #else
    	for (size_t i=1; i<B.size(); ++i) B[i]=(ll)A[i-1]*fp(i,M-2)%M;
    #endif
    return B;
}
// Power-series logarithm of A modulo x^{A.size()}; requires A[0] == 1.
// ln A = integral(A' * (1/A)). The result is truncated to A.size() terms.
// The inverse is only used to that precision, so any further terms the
// product produces are not coefficients of ln A and are dropped.
poly ln(const poly &A){//A[0]=1
    #ifdef FAST
        return inte(diff(A)*getrev(A).cat(A.size())).cat(A.size());
    #else
    	return cat(inte(diff(A)*cat(getrev(A),A.size())),A.size());
    #endif
}
// Power-series exponential of A; requires A[0] == 0. Returns u coefficients,
// where u is the smallest power of two >= A.size().
// Newton iteration: if B = exp(A) mod x^{i/2}, then
//   B * (1 + A - ln B) = exp(A) mod x^i,
// with A and ln B both taken to precision i (ln of B zero-padded to i terms).
// The previous code used precision i/2 for both. There A - ln B vanishes, so
// B never changed (e.g. exp(a*x) came out as 1).
poly exp(const poly &A){//A[0]=0
    assert(!A.size() || !A[0]);
    size_t u=1; while (u<A.size()) u<<=1;
    poly B(1); B[0]=1;
    #ifdef FAST
    	for (size_t i=2; i<=u; i<<=1) B=B+(B*(cat(A,i)-ln(cat(B,i)).cat(i))).cat(i);
    #else
        for (size_t i=2; i<=u; i<<=1) B=B+cat(B*(cat(A,i)-cat(ln(cat(B,i)),i)),i);
    #endif
    return B;
}
// Power-series square root of A, returned to precision u (u = smallest power of
// two >= A.size()).
// NOTE(review): the constant term is the real square root of A[0], so this only
// works when A[0] is a nonzero perfect-square integer (e.g. 1). A general
// modular square root is not implemented; the assert rejects other inputs
// instead of silently returning garbage.
// Newton iteration: B_new = (A + B^2) / (2B) mod x^i, with A, B^2 and 1/(2B)
// all taken to precision i. The old code truncated them to i/2, which is wrong.
poly sqrt(const poly &A){//???
    assert(A.size());
    size_t u=1; while (u<A.size()) u<<=1;
    poly B(1); B[0]=(int)std::sqrt((double)A[0]);
    assert(B[0] && (ll)B[0]*B[0]%M==A[0]);
    #ifdef FAST
    	for (size_t i=2; i<=u; i<<=1) B=((cat(A,i)+(B*B).cat(i))*getrev(cat(B+B,i)).cat(i)).cat(i);
    #else
    	for (size_t i=2; i<=u; i<<=1) B=cat((cat(A,i)+cat(B*B,i))*cat(getrev(cat(B+B,i)),i),i);
    #endif
    return B;
}
// (x + y) mod M, intended for x, y already in [0, M).
int ADD(int x,int y){
	const int s=x+y;
	return s<M?s:s-M;
}
// (x - y) mod M, intended for x, y already in [0, M).
int SUB(int x,int y){
	const int d=x-y;
	return d>=0?d:d+M;
}
int MUL(int x,int y){
	return (ll)x*y%M;
}
// Size of the per-vertex tables (solve1::f, solve1::e) and factorial tables
// (solve2's fac/invfac); presumably n <= 1e5 plus slack — confirm against limits.
const int N=1e5+10;
struct solve2{
	// op=2: neither tree is given; prints the sum over all pairs (T1, T2) of
	// y^{n-|T1 cap T2|} modulo M. With k = 1/y - 1 this equals
	//   y^n * k^n * n^{-4} * n! * [x^n] exp( (n^2/k) * sum_{i>=1} i^i x^i / i! ).
	void main(int n,int y){
		if (y==1){
			// Every pair contributes 1, giving (n^{n-2})^2 pairs. For n == 1 the
			// exponent 2*(n-2) is negative; the single pair gives 1.
			cout<<(n==1?1:fp(n,2*(n-2)));
			return;
		}
		int pre=fp(y,n);                       // y^n
		const int yp=fp(y,M-2);                // 1/y
		pre=MUL(pre,fp(yp-1,n));               // * k^n
		pre=MUL(pre,fp(fp(n,4),M-2));          // * n^{-4}
		const int k=MUL(MUL(n,n),fp(yp-1,M-2));// n^2/k: weight per component
		static int fac[N];
		static int invfac[N];
		fac[0]=1;
		for (int i=1; i<=n; ++i) fac[i]=MUL(fac[i-1],i);
		invfac[n]=fp(fac[n],M-2);
		for (int i=n-1; i>=0; --i) invfac[i]=MUL(invfac[i+1],i+1);
		// g[i] = (n^2/k) * i^i / i!; g[0] stays 0 so exp() is applicable.
		poly g(n+1);
		for (int i=1; i<=n; ++i) g[i]=MUL(MUL(k,fp(i,i)),invfac[i]);
		g=exp(g);
		cout<<MUL(MUL(g[n],pre),fac[n]);
	}
}A;
struct solve1{
	// f[v][0] / f[v][1]: DP sum over v's subtree where v's component has not /
	// has already chosen its key vertex (each key vertex carries weight k).
	int f[N][2];
	int k;
	vector<int> e[N];
	// Runs the DP on the subtree rooted at x whose parent is fa (vertex ids are
	// 1..n, so fa = 0 means "no parent"). Iterative, because recursion would be
	// ~n levels deep on a path-shaped tree.
	void dfs(int x,int fa){
		// Pre-order list of (vertex, parent): every vertex after its parent.
		vector<pair<int,int> > order;
		vector<pair<int,int> > stk(1,make_pair(x,fa));
		while (!stk.empty()){
			const pair<int,int> cur=stk.back();
			stk.pop_back();
			order.push_back(cur);
			const int v=cur.first;
			f[v][0]=1;
			f[v][1]=k;
			for (int c:e[v]) if (c!=cur.second) stk.push_back(make_pair(c,v));
		}
		// Merge each vertex into its parent, children before parents. The merge
		// is multiplication in Z[eps]/(eps^2), so sibling order does not matter.
		for (size_t t=order.size(); t-->1; ){   // order[0] is x itself
			const int v=order[t].first;
			const int p=order[t].second;
			// v and p in the same component.
			int n0=MUL(f[v][0],f[p][0]);
			int n1=ADD(MUL(f[v][0],f[p][1]),MUL(f[v][1],f[p][0]));
			// Edge (v,p) cut: v's component must already have its key vertex.
			n0=ADD(n0,MUL(f[v][1],f[p][0]));
			n1=ADD(n1,MUL(f[v][1],f[p][1]));
			f[p][0]=n0;
			f[p][1]=n1;
		}
	}
	// op=1: T1 is read from stdin. Prints the sum over all trees T2 of
	// y^{n-|T1 cap T2|} modulo M, i.e. f[1][1] * (y*(1/y-1))^n / n^2.
	void main(int n,int y){
		if (y==1){
			// Every T2 contributes 1, giving n^{n-2} trees. For n == 1 the
			// exponent is negative; the single tree gives 1.
			cout<<(n==1?1:fp(n,n-2));
			return;
		}
		for (int i=1; i<n; ++i){
			int u,v;
			scanf("%d%d",&u,&v);
			e[u].push_back(v);
			e[v].push_back(u);
		}
		const int py=fp(y,M-2);                // 1/y
		k=MUL(fp(py-1,M-2),n);                 // n/(1/y-1) per key vertex
		dfs(1,0);
		cout<<MUL(MUL(f[1][1],fp(MUL(py-1,y),n)),fp(MUL(n,n),M-2));
	}
}B;
struct solve0{
	unordered_set<ll> mp;
	// op=0: both trees are read from stdin. Each edge of the first tree is
	// stored as (smaller endpoint)*1000000 + (larger endpoint). The answer is
	// y^(n - number of edges of the second tree also present in the first).
	void main(int n,int y){
		auto key=[](int u,int v){
			if (u>v) swap(u,v);
			return u*1000000ll+v;
		};
		for (int i=1; i<n; ++i){
			int u,v;
			scanf("%d%d",&u,&v);
			mp.insert(key(u,v));
		}
		int common=0;
		for (int i=1; i<n; ++i){
			int u,v;
			scanf("%d%d",&u,&v);
			common+=(int)mp.count(key(u,v));   // count() is 0 or 1
		}
		cout<<fp(y,n-common);
	}
}C;
int n,y,op;
// Reads n, y and the sub-task id op, then dispatches:
// op 2 -> A (solve2), op 1 -> B (solve1), anything else -> C (solve0).
int main(){
	scanf("%d%d%d",&n,&y,&op);
	switch (op){
		case 2: A.main(n,y); break;
		case 1: B.main(n,y); break;
		default: C.main(n,y); break;
	}
}
posted @ 2019-02-25 20:54  Yuhuger  阅读(241)  评论(0)  编辑  收藏  举报