专项测试 数学4

A. 传统题

让求的答案为 \(\sum\limits_{i=1}^niF(i)\)

\(F(i)\) 为答案为 \(i\) 的方案数

直接求不好求,可以简单转化一下变成 \(\sum\limits_{i=1}^n(m^n-F(ans<i))\)

那么考虑求 \(\sum\limits_{i=1}^nF(ans<i)\rightarrow\sum\limits_{i=0}^{n-1}F(ans\leq i)\)

枚举最后的答案分成了 \(j\) 段就是 \(\sum\limits_{i=0}^n\sum\limits_{j=1}^nm*(m-1)^{j-1}\)

每段的颜色都不和前面的一段相同所以是 \(m*(m-1)^{j-1}\)

然后每种方案都要再乘上一个方案数表示 \(j\) 个小于等于 \(i\) 的数加和为 \(n\) 的方案数

没有限制的话就是 \(\binom{n-1}{j-1}\) 和昨天的一样插板

可以容斥一下,枚举一个 \(k\) 表示有 \(k\) 个数大于了 \(i\) ,那么先给这些数分配 \(i\) 然后就可以插板了 \(\sum\limits_{k=0}^j(-1)^k\binom{j}{k}\binom{n-ik-1}{j-1}\)

式子变成 \(\sum\limits_{i=0}^n\sum\limits_{j=1}^nm*(m-1)^{j-1}\sum\limits_{k=0}^j(-1)^k\binom{j}{k}\binom{n-ik-1}{j-1}\)

\(m\)\(k\) 提一提变成 \(m\sum\limits_{i=0}^{n-1}\sum\limits_{k=1}^n(-1)^k\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik-1}{j-1}\)

接下来一步神奇的操作把 \(\binom{n-ik-1}{j-1}\) 变成 \(\frac{1}{n-ik}\binom{n-ik}{j}*j\)

然后发现后面的就变成了 \(\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik}{j}j\)

假装给他按个组合意义变换一下

\(n-ik\) 个球里选 \(j\) 个,再从 \(j\) 个里面选一个特殊的球不染色,剩下的都用 \(m-1\) 种颜色去染色

再分类讨论看看,选出来不染色的球在 \(k\) 个的里面还是外面

  1. 在里面,于是把 \(k-1\) 个染色再选出一个不染色 \(k(m-1)^{k-1}\) ,剩下的球里选一个子集染色,相当于用 \(m\) 种颜色去染,额外的一种颜色相当于不选 \(m^{n-ik-k}\)

  2. 在外面,把里面的 \(k\) 个染色,外面的还是选子集和上面一样而且要再选一个不染的 \((m-1)^k(n-ik-k)m^{n-ik-k-1}\)

于是 \(f(k,ik)=\sum\limits_{j=1}^n(m-1)^{j-1}\binom{j}{k}\binom{n-ik}{j}j=k(m-1)^{k-1}m^{n-ik-k}+(m-1)^k(n-ik-k)m^{n-ik-k-1}\binom{n-ik}{k}\)

最后的组合数表示你总共选出来的那 \(k\)

最后就变成了 \(m\sum\limits_{i=0}^{n-1}\sum\limits_{k=1}^n(-1)^k\frac{1}{n-ik}f(k,ik)\)

根据 \(k+ik\leq n\) 去限制边界,就可以做到 \(O(n\log n)\)

Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,m,mod,ans,cnt;
int k2[300010];
int fac[300010],ifac[300010],inv[300010];
int Sm[300010],Sm1[300010];
inline int C(int n,int m){return fac[n]*ifac[m]%mod*ifac[n-m]%mod;}
inline int qpow(int x,int k){
	int res=1,base=x;
	while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
	return res;
}
inline int calc(int k,int ik){
	int res=0;
	if(k-1>=0&&n-ik-k>=0) res+=k*Sm1[k-1]%mod*Sm[n-ik-k]%mod;
	if(n-ik-k-1>=0&&k>=0) res+=Sm1[k]*(n-ik-k)%mod*Sm[n-ik-k-1]%mod; 
	return res%mod*C(n-ik,k)%mod;
}
signed main(){
#ifdef LOCAL
	freopen("in","r",stdin);
	freopen("out","w",stdout);
#endif
	n=read(),m=read(),mod=read();
	k2[0]=1;for(int i=1;i<=300000;i++) k2[i]=k2[i-1]*2%mod;
	inv[1]=1;for(int i=2;i<=300000;i++) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
	fac[0]=ifac[0]=1;for(int i=1;i<=300000;i++) fac[i]=fac[i-1]*i%mod,ifac[i]=ifac[i-1]*inv[i]%mod;
	Sm[0]=1,Sm1[0]=1;for(int i=1;i<=300000;i++) Sm[i]=Sm[i-1]*m%mod,Sm1[i]=Sm1[i-1]*(m-1)%mod;
	for(int k=0,r=1;k<=n;k++,r=-r) for(int i=0;i<n&&i*k+k<=n;i++) (ans+=r*calc(k,i*k)*inv[n-i*k]%mod+mod)%=mod;
	printf("%lld\n",(n*Sm[n]%mod-m*ans%mod+mod)%mod);
	return 0;
}

B. 生成树

考虑矩阵树定理实际求出来的东西就是 \(\sum(生成树的边权之积)\)

然后你所有的生成树都是由红绿蓝三种颜色构成的,你要限制其中一些边的选择数量

直接根据限制不好做,可以考虑把所有情况的生成树数量都求出来

相当于一共有 \(\frac{n*(n+1)}{2}\) 个不同变量,所以你可以枚举绿色和蓝色两种边的边权

然后每种都求一个矩阵树,这样你就可以得到 \(\frac{n*(n+1)}{2}\) 方程了,于是可以高斯消元解

再根据选择的边的数量加入答案就行

Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define mod 1000000007
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,m,g,b,t,ans;
struct E{int x,y,z;}edge[100010];
int a[50][50];
inline int qpow(int x,int k){
	int res=1,base=x;
	while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
	return res;
}
inline void build(int G,int B){
	memset(a,0,sizeof(a));
	for(int i=1,x,y,z;i<=m;i++){
		x=edge[i].x,y=edge[i].y,z=edge[i].z;
		if(z==1) a[x][x]+=1,a[y][y]+=1,a[x][y]-=1,a[y][x]-=1;
		if(z==2) a[x][x]+=G,a[y][y]+=G,a[x][y]-=G,a[y][x]-=G;
		if(z==3) a[x][x]+=B,a[y][y]+=B,a[x][y]-=B,a[y][x]-=B;
	}
}
inline int solve(){
	int res=1;
	for(int i=2;i<=n;i++) for(int j=i+1,t;j<=n;j++) while(a[j][i]){
		t=a[i][i]/a[j][i];
		for(int k=i;k<=n;k++) a[i][k]=(a[i][k]-t*a[j][k])%mod;
		swap(a[i],a[j]);
		res=-res;
	}
	for(int i=2;i<=n;i++) res=res*a[i][i]%mod;
	return (res+mod)%mod;
}
int mp[1010][1010],lim;
inline void gauss(int n){
	for(int i=1,p,INV;i<=n;i++){
		if(!mp[i][i]) for(int j=i+1;j<=n;j++) if(mp[j][i]){swap(mp[i],mp[j]);break;}
		INV=qpow(mp[i][i],mod-2);
		for(int j=1;j<=n+1;j++) mp[i][j]=mp[i][j]*INV%mod;
		for(int j=1;j<=n;j++){
			if(i==j||mp[j][i]==0) continue;
			INV=mp[j][i]*qpow(mp[i][i],mod-2)%mod;
			for(int k=1;k<=n+1;k++) mp[j][k]=(mp[j][k]-INV*mp[i][k]+mod)%mod;
		}
	}
}
namespace RG{
	signed main(){
		for(int i=1;i<=m;i++) edge[i].x=read(),edge[i].y=read(),edge[i].z=read();
		for(int i=0,p;i<n;i++){
			build(i,0);lim++;p=0;
			for(int k=0;k<n;k++) mp[lim][++p]=qpow(i,k)%mod;
			mp[lim][n+1]=solve();
		}
		gauss(lim);
		for(int k=0,p=0;k<n;k++){p++;if(k<=g) (ans=ans+mp[p][lim+1])%=mod;}
		printf("%lld\n",ans);
		return 0;
	}
}
signed main(){
#ifdef LOCAL
	freopen("in","r",stdin);
	freopen("out","w",stdout);
#endif
	n=read(),m=read(),g=read(),b=read();if(!b) return RG::main();
	for(int i=1;i<=m;i++) edge[i].x=read(),edge[i].y=read(),edge[i].z=read();
	for(int i=0;i<n;i++) for(int j=0,p;i+j<n;j++){
		build(i,j);lim++;p=0;
		for(int k=0;k<n;k++) for(int l=0;l+k<n;l++) mp[lim][++p]=qpow(i,k)*qpow(j,l)%mod;
		mp[lim][n*(n+1)/2+1]=solve();
	}
	gauss(lim);
	for(int k=0,p=0;k<n;k++) for(int l=0;l+k<n;l++){p++;if(k<=g&&l<=b) (ans=ans+mp[p][lim+1]+mod)%=mod;}
	printf("%lld\n",ans);
	return 0;
}

C. 最短路径

如果是树的话很好做可以直接 \(ntt\) 加点分治

现在给定的是一个基环树

所以先按照基环树的套路把环先找出来,然后对于环上的每个点都去做一边点分治把子树内的答案算出来

再考虑环上的点之间的贡献,随便找一边把环破开,发现如果从中间分开,那么左右两块内部的贡献不会走那一条边

于是可以分治,把左右两块内部的答案算出来,每次都从以中间为分治中心,然后再加上一个偏移量就行

再考虑左右两块之间的,把左右两块再分开,变成 \(4\) 个小块,从左到右分别标号为 \(1,2,3,4\)

那么 \(2,3\) 的贡献不跨过环 \(1,4\) 的跨过环,这两部分的贡献可以用上边分治用的方法求出

剩下的 \(1,3\)\(2,4\) 的贡献又是一个子问题还可以分治递归

最后就统计完了所有路径长度的数量

Code
#include<bits/stdc++.h>
#define int long long
#define rint signed
#define mod 998244353
#define i2 499122177
#define inf 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,k,len,L,INV,ans;
int dis[262200],dep[100010],pre[100010];
int r[262200],w[262200];
int siz[100010],mx[100010],rt,S;
int head[100010],ver[200010],to[200010],tot;
int s[100010],num;
vector<int>f[100010];
bool vis[100010],ic[100010];
inline void add(int x,int y){ver[++tot]=y;to[tot]=head[x];head[x]=tot;}
inline int qpow(int x,int k){
	int res=1,base=x%mod;
	while(k){if(k&1) res=res*base%mod;base=base*base%mod;k>>=1;}
	return res;
}
bool findcycle(int x,int fa){
	dep[x]=dep[fa]+1;
	for(int i=head[x];i;i=to[i]){
		int y=ver[i];
		if(y==fa) continue;
		if(!dep[y]){
			pre[y]=x;
			if(findcycle(y,x)) return true;
		}else if(dep[y]<dep[x]){
			int p=x;
			while(p!=y){s[++num]=p;ic[p]=1;p=pre[p];}
			s[++num]=y;ic[y]=1;
			return true;
		}
	}
	return false;
}
inline void ntt(vector<int> &a){
	for(int i=0;i<len;i++) if(i<r[i]) swap(a[i],a[r[i]]);
	for(int d=1,t=len>>1;d<len;d<<=1,t>>=1) for(int i=0;i<len;i+=(d<<1)) for(int j=0;j<d;j++){
		int tmp=w[t*j]*a[i+j+d]%mod;
		a[i+j+d]=(a[i+j]-tmp+mod)%mod;
		a[i+j]=(a[i+j]+tmp)%mod;
	}
}
inline vector<int> polymul(vector<int> f,vector<int> g){
	int l1=f.size(),l2=g.size();for(len=1,L=0;len<=l1+l2;len<<=1,L++);
	for(int i=0;i<len;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
	w[0]=1,w[1]=qpow(3,(mod-1)/len);for(int i=2;i<len;i++) w[i]=w[i-1]*w[1]%mod;
	f.resize(len);g.resize(len);
	ntt(f);ntt(g);
	for(int i=0;i<len;i++) f[i]=f[i]*g[i]%mod;
	w[0]=1,w[1]=qpow(w[1],mod-2);for(int i=2;i<len;i++) w[i]=w[i-1]*w[1]%mod;
	ntt(f);INV=qpow(len,mod-2);
	for(int i=0;i<len;i++) f[i]=f[i]*INV%mod;
	return f;
}
inline void polyadd(vector<int> &f,vector<int> g){
	if(f.size()<g.size()) f.resize(g.size());
	for(int i=0;i<g.size();i++) f[i]+=g[i];
}
void getrt(int x,int fa){
	siz[x]=1,mx[x]=0;
	for(int i=head[x];i;i=to[i]){
		int y=ver[i];
		if(y==fa||vis[y]) continue;
		getrt(y,x);
		siz[x]+=siz[y];mx[x]=max(mx[x],siz[y]);
	}
	mx[x]=max(mx[x],S-siz[x]);
	if(mx[x]<mx[rt]) rt=x;
}
void dfs(int x,int fa,int dep,vector<int> &f){
	if(dep>=f.size()) f.resize(dep+1);f[dep]++;siz[x]=1;
	for(int i=head[x];i;i=to[i]){
		int y=ver[i];
		if(y==fa||vis[y]) continue;
		dfs(y,x,dep+1,f);
		siz[x]+=siz[y];
	}
}
void solve(int x){
	vis[x]=1;vector<int> F;
	dfs(x,0,0,F);
	F=polymul(F,F);
	for(int i=0;i<F.size();i++) (dis[i]+=F[i])%=mod;
	for(int i=head[x];i;i=to[i]){
		int y=ver[i];vector<int> G;
		if(vis[y]) continue;
		dfs(y,x,1,G);
		G=polymul(G,G);
		for(int i=0;i<G.size();i++) (dis[i]+=-G[i]+mod)%=mod;
	}
	for(int i=head[x];i;i=to[i]){
		int y=ver[i];
		if(vis[y]) continue;
		S=siz[y];mx[rt=0]=inf;
		getrt(y,0);solve(rt);
	}
}
inline void calcs(int l1,int r1,int l2,int r2,int dist){//l<-mid->r
	int m1=0,m2=0;
	for(int i=l1;i<=r1;i++) m1=max(m1,r1-i+(int)f[i].size());
	for(int i=l2;i<=r2;i++) m2=max(m2,i-l2+(int)f[i].size());
	vector<int> a,b;a.resize(m1);b.resize(m2);
	for(int i=l1;i<=r1;i++) for(int j=0;j<f[i].size();j++) (a[r1-i+j]+=f[i][j])%=mod;
	for(int i=l2;i<=r2;i++) for(int j=0;j<f[i].size();j++) (b[i-l2+j]+=f[i][j])%=mod;
	a=polymul(a,b);
	for(int i=0;i<a.size();i++) (dis[i+dist]+=a[i])%=mod;
}
void solve1(int l,int r){
	if(l==r) return ;
	int mid=(l+r)>>1;
	solve1(l,mid);solve1(mid+1,r);
	calcs(l,mid,mid+1,r,1);
}
void solve2(int l1,int r1,int l2,int r2,int dis){
	if(l1==r1) return calcs(l1,r1,l2,r2,dis),void();
	int mid1=(l1+r1)>>1,mid2=(l2+r2)>>1;
	calcs(mid1+1,r1,l2,mid2,dis);
	solve2(l1,mid1,l2,mid2,dis+r1-mid1);
	solve2(mid1+1,r1,mid2+1,r2,dis+mid1-l1+1);
}
signed main(){
#ifdef LOCAL
	freopen("in","r",stdin);
	freopen("out","w",stdout);
#endif
	n=read(),k=read();bool fg=0;
	for(int i=1,x,y;i<=n;i++){
		x=read(),y=read();
		if(x==y){fg=1;continue;}
		add(x,y),add(y,x);
	}
	if(!fg) findcycle(1,0);else s[++num]=1;
	for(int i=1;i<=num;i++) vis[s[i]]=1;
	for(int i=1,x;i<=num;i++){
		x=s[i];vis[x]=0;
		dfs(x,0,0,f[i]);
		S=siz[x],mx[rt=0]=inf;
		getrt(x,0);solve(rt);
	}
	for(int i=1;i<=n;i++) dis[i]=dis[i]*i2%mod;
	if(num>1){
		solve1(1,num/2);solve1(num/2+1,num);
		if(num&1){
			solve2(1,num/2,num/2+1,num-1,1);
			solve2(num/2+2,num,1,num/2,1);
		}else{
			solve2(1,num/2,num/2+1,num,1);
			solve2(num/2+2,num,1,num/2-1,1);
		}
	}
	for(int i=1;i<=n;i++) ans=(ans+dis[i]*qpow(i,k)%mod)%mod;
	printf("%lld\n",ans*qpow(n*(n-1)/2,mod-2)%mod);
	return 0;
}

后话

因为取模丢了 \(130\) \(T2\) \(100\) \(T3\) \(30\) ,嘻嘻

posted @ 2022-01-08 21:26  Max_QAQ  阅读(58)  评论(0编辑  收藏  举报