题目链接
https://atcoder.jp/contests/agc034/tasks/agc034_f
题解
无论多水的题我都不会啊.jpg
首先考虑一个图上随机游走的经典问题,无向图求从\(0\)号点出发随机游走到每个点的期望时间。做法是显然答案等于从每个点走到\(0\)号点的期望时间,然后列方程高斯消元。
设答案向量为\(\textbf{x}\), 则有\(x_i=\sum_{j\ \text{xor}\ k=i}p_kx_j+1(1\le i\lt 2^n)\)且\(x_0=0\). 即\(\textbf{x}\)与\(\textbf{p}\)异或卷积的结果为\(\begin{bmatrix}x_0' & x_1-1 & x_2-1 & ... & x_{2^n-1}-1\end{bmatrix}\). 观察到\(\sum^{2^n-1}_{i=0}p_i=1\),故\(\sum^{2^n-1}_{i=0}x_i=x_0'+\sum^{2^n-1}_{i=1}(x_i-1)\), 即\(x_0'=x_0+2^n-1\). 再用\(p_0-1\)替换\(p_0\)得\(\textbf{x}\)与\(\textbf{p}\)异或卷积的结果为常数数列\(\textbf{a}=\begin{bmatrix}2^n-1&-1&-1&...&-1\end{bmatrix}\).
现在已知\(\textbf{p}\)和\(\textbf{a}\),要求出\(\textbf{x}\). FWT后的数列对应位置作除法即可。设\(\text{FWT}(\textbf{p})=\textbf{P}\) (其余字母同理), 则\(\forall 1\le i\le 2^n-1, P_i\lt \sum^{2^n-1}_{i=0}p_i=P_0=0\), 也即\(\textbf{P}\)序列有且仅有\(P_0\)为\(0\). \(P_0\)与\(A_0\)皆为\(0\), 我们无法还原出\(X_0\).
\(\textbf{x}=\text{IFWT}(\textbf{X})\), 设\(\textbf{X'}=\begin{bmatrix}0&X_1&X_2&...&X_{2^n-1}\end{bmatrix}\), 则\(\forall 0\le i\le 2^n-1, x'_i=x_i-\frac{X_0}{2^n}=x_i-(x_0-x'_0)=x_i+x'_0\), 故用\(x_i=x'_i-x'_0\)计算即可。
时间复杂度\(O(2^nn)\).
代码
#include<bits/stdc++.h>
#define llong long long
using namespace std;
inline int read()
{
int x = 0,f = 1; char ch = getchar();
for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
return x*f;
}
const int N = 18;
const int P = 998244353;
const llong INV2 = 499122177ll;
llong p[(1<<N)+3];
llong a[(1<<N)+3],b[(1<<N)+3];
int n,sum;
llong quickpow(llong x,llong y)
{
llong cur = x,ret = 1ll;
for(int i=0; y; i++)
{
if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
cur = cur*cur%P;
}
return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}
void fwt(int dgr,int coe,llong poly[],llong ret[])
{
memcpy(ret,poly,sizeof(llong)*(1<<dgr));
for(int i=0; i<dgr; i++)
{
for(int j=0; j<(1<<dgr); j+=(1<<i+1))
{
for(int k=0; k<(1<<i); k++)
{
llong x = poly[k+j],y = poly[k+(1<<i)+j];
poly[k+j] = x+y>=P?x+y-P:x+y; poly[k+(1<<i)+j] = x-y<0?x-y+P:x-y;
}
}
}
if(coe==-1) {llong tmp = mulinv(1<<dgr); for(int i=0; i<(1<<dgr); i++) ret[i] = ret[i]*tmp%P;}
}
int main()
{
scanf("%d",&n); for(int i=0; i<(1<<n); i++) {scanf("%lld",&p[i]); sum += p[i];} sum = mulinv(sum);
for(int i=0; i<(1<<n); i++) p[i] = p[i]*sum%P; p[0] = (p[0]-1+P)%P;
a[0] = (1<<n)-1; for(int i=1; i<(1<<n); i++) a[i] = P-1;
fwt(n,1,a,a); fwt(n,1,p,p);
b[0] = 0ll; for(int i=1; i<(1<<n); i++) b[i] = a[i]*mulinv(p[i])%P;
fwt(n,-1,b,b);
llong tmp = b[0]; for(int i=0; i<(1<<n); i++) b[i] = (b[i]-tmp+P)%P;
for(int i=0; i<(1<<n); i++) printf("%lld\n",b[i]);
return 0;
}