[CSP-S模拟测试]:走路(期望DP+分治消元)
题目传送门(内部题100)
输入格式
第一行两个整数$n,m$,接下来$m$行每行两个整数$u,v$表示一条$u$连向$v$的边。不保证没有重边和自环。
输出格式
$n-1$行每行一个整数,第$i$行表示$k=i+1$时的答案。对$998244353$取模。
样例
样例输入:
3 6
1 1
1 2
2 1
2 2
1 3
3 1
样例输出:
4
5
数据范围与提示
数据范围:
对于$25\%$的数据,$n\leqslant 50$。
对于另外$20\%$的数据,前$m-1$条边满足$u<v$。
对于另外$15\%$的数据,不存在$u,v$使得$u!=v$且$\min(u,v)>1$。
对于$100\%$的数据,$1\leqslant n\leqslant 300,1\leqslant m\leqslant 10^5,1\leqslant u,v\leqslant n$。
提示:
对于质数$p$和有理数$\frac{a}{b}(b\mod p>0)$,存在恰好一个整数$c$满足$0\leqslant c<p$且$a\equiv bc(\mod p)$,我们称$c$为$\frac{a}{b}$对$p$取模的结果。
题解
概率正着推,期望倒着推。
不妨设$f[i]$表示从$i$出发走到$k$的期望步数,那么可以列出式子:
$$f[i]=\sum \frac{f[j]}{du[i]}+1$$
$f[k]=0$
其中$j$是$i$的所有出边所能到达的点,$du[i]$则为$i$的出度。
那么枚举所有的$k$,然后暴力高斯消元即可拿到$25$分。
考虑优化,因为高斯消元中一段的值可以对很多$k$做贡献,所以我们可以用分治消元。
原理就是:对于区间$[l,r]$,先消$[mid+1,r]$,然后继续递归$[l,mid]$,递归之后再消$[l,mid]$,接着递归$[mid+1,r]$。
时间复杂度:$\Theta(n^3\log n)$。
期望得分:$100$分。
实际得分:$100$分。
代码时刻
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
int n,m;
int du[301],top;
long long Map[301][302],ans[301];
long long qpow(long long x,long long y)
{
long long res=1;
while(y)
{
if(y&1)res=res*x%mod;
x=x*x%mod;y>>=1;
}
return res;
}
void gauss(int x,int l,int r)
{
long long inv=qpow(Map[x][x],mod-2);
for(int i=1;i<=n+1;i++)Map[x][i]=Map[x][i]*inv%mod;
for(int i=1;i<=n;i++)
{
if(i==x)continue;
long long flag=Map[i][x];
for(int j=l;j<=r;j++)
Map[i][j]=(Map[i][j]-Map[x][j]*flag)%mod;
Map[i][n+1]=(Map[i][n+1]-Map[x][n+1]*flag)%mod;
}
}
void solve(int l,int r)
{
if(l==r){ans[l]=(Map[1][n+1]+mod)%mod;return;}
int mid=(l+r)>>1;int wzc[301][302];
for(int i=1;i<=n;i++)
for(int j=1;j<=n+1;j++)
wzc[i][j]=Map[i][j];
for(int i=mid+1;i<=r;i++)gauss(i,l,r);
solve(l,mid);
for(int i=1;i<=n;i++)
for(int j=1;j<=n+1;j++)
Map[i][j]=wzc[i][j];
for(int i=l;i<=mid;i++)gauss(i,l,r);
solve(mid+1,r);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++)
{
int u,v;
scanf("%d%d",&u,&v);
Map[u][v]++;du[u]++;
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
Map[i][j]=Map[i][j]*qpow(du[i],mod-2)%mod;
Map[i][i]+=(Map[i][n+1]=-1);
}
solve(1,n);
for(int i=2;i<=n;i++)printf("%d\n",ans[i]);
}
rp++