luogu P4709 信息传递
https://www.luogu.com.cn/problem/P4709
首先肯定有如下定理:
- 长度为
m
m
m的循环
k
k
k次幂得到的置换有
g
c
d
(
m
,
k
)
gcd(m,k)
gcd(m,k)个大小相等的循环
(可以看作每个数每次往后跳k个,显然恰好可以走 m / g c d ( m , k ) 步 m/gcd(m, k)步 m/gcd(m,k)步) - 由上面那个可得,
f
f
f中
r
r
r个长度为
s
s
s的循环能拼在一起当且仅当
g
c
d
(
r
s
,
m
)
=
r
gcd(rs, m)=r
gcd(rs,m)=r
可以发现大小不同的环互相不影响
先考虑可以拼成一个循环的一组 r , s r,s r,s
钦定最小的那个元素所在的循环为第一个循环的第一个元素
那么剩下 r − 1 r-1 r−1个环顺序随便共 ( r − 1 ) ! (r-1)! (r−1)!种
每个环自己还可以转 s r − 1 s^{r-1} sr−1
记为 f ( r , s ) = ( r − 1 ) ! × s r − 1 f(r,s)=(r-1)!\times s^{r-1} f(r,s)=(r−1)!×sr−1
考虑一共有
c
c
c个
那么可以设
d
p
[
i
]
dp[i]
dp[i]表示用了
i
i
i个环的方案数
则
d
p
[
i
]
=
∑
r
[
g
c
d
(
r
s
,
m
)
=
r
]
(
i
−
1
r
−
1
)
f
(
r
,
s
)
d
p
[
i
−
r
]
dp[i]=\sum\limits_{r}[gcd(rs,m)=r]\binom{i-1}{r-1}f(r,s)dp[i-r]
dp[i]=r∑[gcd(rs,m)=r](r−1i−1)f(r,s)dp[i−r]
因为环是无序的,所以规定第一个环必选,大概yy一下
时间复杂度
O
(
n
∗
d
(
n
)
)
O(n*d(n))
O(n∗d(n))
code:
#include<bits/stdc++.h>
#define N 1000050
#define mod 998244353
#define ll long long
using namespace std;
int gcd(int x, int y) {
return y? gcd(y, x % y) : x;
}
int qpow(ll x, int y) {
ll ret = 1;
for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
return ret;
}
ll fac[N], ifac[N];
int d[N], sz;
void init(int n) {
fac[0] = 1;
for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
ifac[n] = qpow(fac[n], mod - 2);
for(int i = n - 1; i >= 0; i --) ifac[i] = ifac[i + 1] * (i + 1) % mod;
for(int i = 1; i <= n; i ++) if(n % i == 0) d[++ sz] = i;
}
ll P(int r, int s) {
return fac[r - 1] * qpow(s, r - 1) % mod;
}
ll C(int x, int y) {
return fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
ll dp[N];
int n, a[N], cir[N], vis[N];
ll calc(int gs, int s) {
dp[0] = 1;
for(int i = 1; i <= gs; i ++) {
dp[i] = 0;
for(int j = 1; j <= sz && d[j] <= i; j ++) {
int r = d[j];
if(gcd(n, r * s) == r) dp[i] += C(i - 1, r - 1) * P(r, s) % mod * dp[i - r] % mod, dp[i] %= mod;
}
}
return dp[gs];
}
int main() {
scanf("%d", &n);
init(n);
for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
for(int i = 1; i <= n; i ++) if(!vis[i]) {
int x = i, size = 0;
while(!vis[x]) {
vis[x] = 1;
x = a[x];
++ size;
}
cir[size] ++;
}
ll ans = 1;
for(int i = 1; i <= n; i ++)
if(cir[i]) ans = ans * calc(cir[i], i) % mod;
printf("%lld", ans);
return 0;
}