21.5.13 t2
tag:背包dp,数论
首先可以把给定的排列分成若干循环,将长度相同的分为一组,则可以分别处理每组然后乘起来。
对于一组数量为 \(cnt_a\) 长度为 \(a\) 的循环,再分成若干组,假设其中一组有 \(b\) 个,则必须满足 \(\gcd(ab,k)=b\)。而这样一组的贡献为 \((b-1)!a^{b-1}\)
所以相当于是将 \(cnt_a\) 划分成若干整数 \(b_i\),满足 \(\gcd(ab_i,k)=b_i\),然后一种划分方案的贡献为 \(\binom{cnt_a}{b_1\ \cdots\ b_k}\Pi(b_i-1)!a^{b_i-1}=cnt_a!\Pi\frac 1{b_i}a^{b_i-1}\)
所以想到一种做法,设 \(f_i\) 表示前 \(i\) 个分成若干组,然后可以枚举最后一组分多少个进行dp。
考虑它的复杂度,为 \(O(n*\)合法的\(b\)的个数\()\),实际上是 \(\sigma_0(k)\) 级别的,然后跑得飞快(卡常是在输入部分)
简易证明:
分别考虑每一个质数 \(p\),设 \(a,b,k\) 分别包含 \(n_a,n_b,n_k\) 个 \(p\)。
则有 \(\min(n_a+n_b,n_k)=n_b\)
- 若 \(n_a=0\),则 \(n_b\le n_k\)
- 若 \(n_a\not=0\),则 \(n_b=n_k\)
所以 \(b|k\)
#include<bits/stdc++.h>
using namespace std;
namespace IO {
#define getc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++
//#define getc() getchar()
char buf[1<<21],*p1,*p2,ch;
void rd(int &x){
x=0;char c=getc();
while(c<48||c>57) c=getc();
while(c>=48&&c<=57) x=x*10+c-48,c=getc();
}
}
using IO::rd;
template<typename T>
inline void Read(T &n){
char ch; bool flag=false;
while(!isdigit(ch=getchar())) if(ch=='-')flag=true;
for(n=ch^48; isdigit(ch=getchar()); n=(n<<1)+(n<<3)+(ch^48));
if(flag) n=-n;
}
#define Read rd
enum{
MAXN = 10000005,
MOD = 998244353
};
inline int ksm(int base, int k=MOD-2){
int res=1;
while(k){
if(k&1)
res = 1ll*res*base%MOD;
base = 1ll*base*base%MOD;
k >>= 1;
}
return res;
}
inline void upd(int &a, long long b){a = (a+b)%MOD;}
int n, k;
int a[MAXN];
char vis[MAXN];
int cnt[MAXN], inv[MAXN];
int q[MAXN], val[MAXN], f[MAXN], jc[MAXN];
int gcd(int a, int b){return b?gcd(b,a%b):a;}
inline int calc(int len){
int top=0, num = cnt[len];
for(register int i=1; i<=k; i++) if(k%i==0 and gcd(len,k/i)==1) q[++top] = i, val[top] = ksm(len,i-1);
f[0] = 1; int tp = 0; q[top+1] = 1e9;
for(register int i=1; i<=num; i++){
f[i] = 0;
for(register int j=1; q[j]<=i; j++)
f[i] = (f[i]+1ll*f[i-q[j]]*val[j])%MOD;
f[i] = 1ll*f[i]*inv[i]%MOD;
}
return 1ll*f[num]*jc[num]%MOD;
}
int main(){
// freopen("3.in","r",stdin);
// freopen("3.out","w",stdout);
Read(n); Read(k);
for(register int i=1; i<=n; i++) Read(a[i]);
jc[0] = 1; for(register int i=1; i<=n; i++) jc[i] = 1ll*jc[i-1]*i%MOD;
inv[1] = 1; for(register int i=2; i<=n; i++) inv[i] = 1ll*(MOD-MOD/i)*inv[MOD%i]%MOD;
for(register int i=1; i<=n; i++) if(!vis[i]){
int len=1, x=i; vis[i] = 1;
while(!vis[a[x]]) vis[x = a[x]] = true, len++;
cnt[len]++;
}
int ans=1;
for(register int i=1; i<=n; i++) if(cnt[i]) ans = 1ll*ans*calc(i)%MOD;
cout<<ans<<'\n';
return 0;
}