luogu P4709 信息传递

https://www.luogu.com.cn/problem/P4709

首先肯定有如下定理:

  • 长度为 m m m的循环 k k k次幂得到的置换有 g c d ( m , k ) gcd(m,k) gcd(m,k)个大小相等的循环
    (可以看作每个数每次往后跳k个,显然恰好可以走 m / g c d ( m , k ) 步 m/gcd(m, k)步 m/gcd(m,k)
  • 由上面那个可得, f f f r r r个长度为 s s s的循环能拼在一起当且仅当 g c d ( r s , m ) = r gcd(rs, m)=r gcd(rs,m)=r
    可以发现大小不同的环互相不影响
    先考虑可以拼成一个循环的一组 r , s r,s r,s
    钦定最小的那个元素所在的循环为第一个循环的第一个元素
    那么剩下 r − 1 r-1 r1个环顺序随便共 ( r − 1 ) ! (r-1)! (r1)!
    每个环自己还可以转 s r − 1 s^{r-1} sr1
    记为 f ( r , s ) = ( r − 1 ) ! × s r − 1 f(r,s)=(r-1)!\times s^{r-1} f(r,s)=(r1)!×sr1

考虑一共有 c c c
那么可以设 d p [ i ] dp[i] dp[i]表示用了 i i i个环的方案数

d p [ i ] = ∑ r [ g c d ( r s , m ) = r ] ( i − 1 r − 1 ) f ( r , s ) d p [ i − r ] dp[i]=\sum\limits_{r}[gcd(rs,m)=r]\binom{i-1}{r-1}f(r,s)dp[i-r] dp[i]=r[gcd(rs,m)=r](r1i1)f(r,s)dp[ir]
因为环是无序的,所以规定第一个环必选,大概yy一下
时间复杂度 O ( n ∗ d ( n ) ) O(n*d(n)) O(nd(n))
code:

#include<bits/stdc++.h>
#define N 1000050
#define mod 998244353
#define ll long long
using namespace std;
int gcd(int x, int y) {
	return y? gcd(y, x % y) : x;
}
int qpow(ll x, int y) {
	ll ret = 1;
	for(; y; y >>= 1, x = x * x % mod) if(y & 1) ret = ret * x % mod;
	return ret;
}
ll fac[N], ifac[N];
int d[N], sz;
void init(int n) {
	fac[0] = 1;
	for(int i = 1; i <= n; i ++) fac[i] = fac[i - 1] * i % mod;
	ifac[n] = qpow(fac[n], mod - 2);
	for(int i = n - 1; i >= 0; i --) ifac[i] = ifac[i + 1] * (i + 1) % mod;
	for(int i = 1; i <= n; i ++) if(n % i == 0) d[++ sz] = i;
}
ll P(int r, int s) {
	return fac[r - 1] * qpow(s, r - 1) % mod;
}
ll C(int x, int y) {
	return fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
ll dp[N];
int n, a[N], cir[N], vis[N];
ll calc(int gs, int s) {
	dp[0] = 1;
	for(int i = 1; i <= gs; i ++) {
		dp[i] = 0;
		for(int j = 1; j <= sz && d[j] <= i; j ++) {
			int r = d[j];
			if(gcd(n, r * s) == r) dp[i] += C(i - 1, r - 1) * P(r, s) % mod * dp[i - r] % mod, dp[i] %= mod;
		}
	}
	return dp[gs];
}
int main() {
	scanf("%d", &n);
	init(n);
	for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
	for(int i = 1; i <= n; i ++) if(!vis[i]) {
		int x = i, size = 0;
		while(!vis[x]) {
			vis[x] = 1;
			x = a[x];
			++ size;
		}
		cir[size] ++;
	}
	ll ans = 1;
	for(int i = 1; i <= n; i ++) 
		if(cir[i]) ans = ans * calc(cir[i], i) % mod;
	printf("%lld", ans);
	return 0;
}
posted @ 2021-06-28 18:53  lahlah  阅读(47)  评论(0编辑  收藏  举报