I. Latitude Compressor 题解 - 2022 Hubei Provincial Collegiate Programming Contest
赛时基本对了,结果被我推的错误数值表 hack 了队友,并且由于对多项式的操作不熟,导致最后一步一直没想出来
【大意】
给定 \(n, m\) ,求集合 \(\left\{ q | \left(\begin{matrix}1&2&3&\cdots&n\\p_1&p_2&p_3&\cdots&p_n\end{matrix}\right) ^ m = \left(\begin{matrix}1&2&3&\cdots&n\\q_1&q_2&q_3&\cdots&q_n\end{matrix}\right) \right\}\)
其中 \(\left(\begin{matrix}1&2&3&\cdots&n\\p_1&p_2&p_3&\cdots&p_n\end{matrix}\right)\) 表示 \(p_1p_2\cdots p_n\) 对应的置换
【分析】
根据群论的基础知识,排列对应的置换可以拆解为若干个轮换,每个轮换又刚好对应到不同的环上
设某个环大小为 \(a_i\) ,则在置换 \(m\) 次后,其会分解为 \(\gcd(a_i, m)\) 个环,并且每个环有 \(\displaystyle {a_i\over \gcd(a_i, m)}\) 个元素
我们不妨考虑这 \(n\) 个元素在 \(m\) 次置换后构成了哪些环:
我们令 \(b_{x, y}\) 表示这 \(n\) 个元素在 \(m\) 次置换后,是否恰好含有 \(y\) 个 \(x\) 元环的逻辑值(可以有其他非 \(x\) 元环)
那么,设 \(n\) 个元素构成的环大小分别为 \(a_1, a_2, \cdots, a_t\) ,则 \(b_{x, y}\) 即为是否存在一组 \(a_1, a_2, \cdots, a_t\) 满足 \(\forall a_i\) 都有 \(\displaystyle {a_i\over \gcd(a_i, m)}=x\) 且 \(\displaystyle \sum_{i}\gcd(a_i, m)=y\)
队友是说,对于每个 \(x\) ,求出 \(\displaystyle t_x=\min_{d\mid x} [\gcd(x, {m\over d})=1]\) ,那么就有 \(b_{x, y}=[t_x\mid y]\)
考虑 \(n\) 个不同元素构成 \(n\) 元环的方案数,为 \(\displaystyle {n!\over n}=(n-1)!\)
因此,在 \(m\) 次置换后, \(ij\) 个元素恰好构成 \(j\) 个 \(i\) 元环的 EGF 即为 \(\displaystyle {1\over j!}\cdot ({(i-1)!x^i\over i!})^j={x^{ij}\over i^j\cdot j!}\)
故任意多的 \(i\) 元环的 EGF 为 \(\displaystyle \sum_{j=0}^\infty {x^{ij}\over i^j\cdot j!}\)
因此,总的 EGF 为 \(\displaystyle \prod_{i=1}^n (\sum_{j=0}^\infty {x^{ij}\over i^j\cdot j!})\)
故答案为 \(\displaystyle n!\cdot [x^n]\prod_{i=1}^n (\sum_{j=0}^\infty {x^{ij}\over i^j\cdot j!})\)
我们考虑令 \(\displaystyle F_i(x)=\sum_{j=0}^\infty {x^{j}\over i^j\cdot j!}\) ,则答案为 \(\displaystyle G(x)=\prod_{i=1}^n F_i(x^i)\pmod {x^{n+1}}\)
因此 \(\displaystyle \ln G(x)=\sum_{i=1}^n \ln F_i(x^i)\pmod {x^{n+1}}\)
令 \(\displaystyle \ln F_i(x)=\sum_{j=0}^{n/i} f_j x^j\) ,则 \(\displaystyle \ln G(x)=\sum_{i=1}^n (\sum_{j=0}^{n/i}f_jx^{ij})\)
因此对于每个 \(F_i(x)\) ,求解其具体数值的复杂度为 \(O(n/i)\) ;求解 \(\ln F_i(x)\) 的复杂度为 \(\displaystyle O((n/i)\log (n/i))=O({n\over i}\log n)\) ;再因此累加到 \(\ln G(x)\) 上,复杂度为 \(O(n/i)\)
因此,总复杂度为 \(\displaystyle T(n)=\sum_{i=1}^n [O(n/i)+O({n\over i}\log n)+O(n/i)]=\sum_{i=1}^n C\cdot {n\over i}\log n=C\cdot n\log n\cdot \sum_{i=1}^n {1\over i}=O(n\log ^2 n)\)
最后只需要 \(O(n\log n)\) 由 \(\ln G(x)\) 进行多项式 exp 求解出 \(G(x)\) 得到答案
而求解 \(b_{x, y}\) 需要的 \(t_x\) ,需要先预处理 \(m\) 的所有因子,数量级为 \(o(\sqrt m)\)
之后,再对每一个 \(x\) ,枚举 \(m,x\) 的因子求解 \(\gcd(x, {m\over d})=1\) 的 \(d\) 最小值,复杂度即为 \(O(n)\cdot o(\sqrt m)\cdot O(\log n)=O(n\log n\cdot \sqrt m)\)
因此,总复杂度为 \(T(n, m)=O(n\log n(\log n+\sqrt m))\)
【代码】
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define fi first
#define sz(a) (int)a.size()
#define de(x) cout << #x <<" = "<<x<<endl
#define dd(x) cout << #x <<" = "<<x<<" "
#define all(x) x.begin(), x.end()
#define pw(x) (1ll<<(x))
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
#define rsz(a, x) (a.resize(x))
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;
const int P=998244353;
const int LimBit=18;
const int M=1<<LimBit<<1;
const int MAXN=5e4+10;
inline int kpow(int a, int x, int p=P) { int ans=1; for(;x;x>>=1, a=(ll)a*a%p) if(x&1) ans=(ll)ans*a%p; return ans; }
inline int exgcd(int a, int b, int &x, int &y) {
static int g;
return b?(exgcd(b, a%b, y, x), y-=a/b*x, g):(x=1, y=0, g=a);
}
inline int inv(int a, int p=P) {
static int x, y;
return exgcd(a, p, x, y)==1?(x<0?x+p:x):(-1);
}
namespace Poly {
const int G=3;
struct vir {
int v;
vir(int v_=0):v(v_>=P?v_-P:v_) {
}
inline vir operator + (const vir &x) const { return vir(v+x.v); }
inline vir operator - (const vir &x) const { return vir(v+P-x.v); }
inline vir operator * (const vir &x) const { return vir((ll)v*x.v%P); }
inline vir operator - () const { return vir(P-v); }
inline vir operator ! () const { return vir(inv(v)); }
inline operator int() const { return v; }
};
struct poly : public vector<vir> {
inline friend ostream& operator << (ostream& out, const poly &p) {
if(!p.empty()) out<<(int)p[0];
for(int i=1; i<sz(p); ++i) out<<" "<<(int)p[i];
return out;
}
};
int N, N_, Stk[M], curStk, rev[M];
vir invN, Inv[M], w[2][M];
inline void init() {
N_=-1;
curStk=0;
Inv[1]=1;
for(int i=2; i<M; ++i)
Inv[i]=-vir(P/i)*Inv[P%i];
}
inline void work() {
if(N_==N) return ;
N_=N;
int d=__builtin_ctz(N);
vir x(kpow(G, (P-1)/N)), y=!x;
w[0][0]=w[1][0]=1;
for(int i=1; i<N; ++i) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<d-1);
w[0][i]=x*w[0][i-1], w[1][i]=y*w[1][i-1];
}
invN=!vir(N);
}
inline void FFT(vir a[M], int f) {
static auto make = [=](vir w, vir &a, vir &b) { w=w*a; a=b-w; b=b+w; };
for(int i=0; i<N; ++i) if(i<rev[i]) swap(a[i], a[rev[i]]);
for(int i=1; i<N; i<<=1)
for(int j=0, t=N/(i<<1); j<N; j+=i<<1)
for(int k=0, l=0; k<i; ++k, l+=t)
make(w[f][l], a[j+k+i], a[j+k]);
if(f) for(int i=0; i<N; ++i) a[i]=a[i]*invN;
}
vir p1[M], p0[M];
inline void get_mul(poly &a, poly &b, int na, int nb) {
for(N=1; N<na+nb-1; N<<=1);
for(int i=0; i<na; ++i) p1[i]=(int)a[i]; for(int i=na; i<N; ++i) p1[i]=0;
for(int i=0; i<nb; ++i) p0[i]=(int)b[i]; for(int i=nb; i<N; ++i) p0[i]=0;
work(); FFT(p1, 0); FFT(p0, 0);
for(int i=0; i<N; ++i) p1[i]=p1[i]*p0[i];
FFT(p1, 1);
rsz(a, na+nb-1); for(int i=0; i<sz(a); ++i) a[i]=p1[i];
}
poly a, b;
inline void get_inv(poly &f, poly &g, int n) {
int pos=curStk;
for(int i=n; i>1; i=i+1>>1) Stk[++curStk]=i;
rsz(g, 1); g[0]=!f[0];
b=f; rsz(b, n);
for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
get_mul(a=g, g, l+1>>1, l+1>>1); rsz(a, l);
get_mul(a, b, l, l);
rsz(g, l);
for(int i=0; i<l; ++i) g[i]=g[i]+g[i]-a[i];
}
}
inline void get_der(poly &f, poly &g) {
rsz(g, sz(f));
for(int i=0; i<sz(f)-1; ++i) g[i]=f[i+1]*vir(i+1);
rsz(g, sz(f)-1);
}
inline void get_int(poly &f, poly &g, int C=0) {
rsz(g, sz(f)+1);
for(int i=sz(f); i; --i) g[i]=f[i-1]*Inv[i]; g[0]=C;
}
inline void get_ln(poly &f, poly &g, int n, int ln0=0) {
get_inv(f, g, n);
get_der(f, a); rsz(a, n);
get_mul(g, a, n, n); rsz(g, n);
get_int(g, g, ln0); rsz(g, n);
}
poly c;
inline void get_exp(poly &f, poly &g, int n, int exp0=1) {
int pos=curStk;
for(int i=n; i>1; i=i+1>>1) Stk[++curStk]=i;
rsz(g, 1); g[0]=exp0;
for(int l=Stk[curStk]; curStk>pos; l=Stk[--curStk]) {
get_ln(g, c, l);
for(int i=0; i<l; ++i) c[i]=f[i]-c[i];
c[0]=c[0]+vir(1);
get_mul(g, c, l+1>>1, l); rsz(g, l);
}
}
}
using Poly::poly;
using Poly::vir;
poly a, b, c;
vector<int> divv;
int n, m;
int mind[MAXN];
vir fact[MAXN], inft[MAXN];
int fxy(int x, int y) { return y%mind[x]==0; }
inline void work() {
fact[0]=1;
for(int i=1; i<=n; ++i) fact[i]=vir(i)*fact[i-1];
inft[n]=!fact[n];
for(int i=n; i>=1; --i) inft[i-1]=vir(i)*inft[i];
rsz(a, n+1);
for(int i=1, x; i<=n; ++i) {
x=n/i;
rsz(b, x+1);
vir tmp=1, r=!vir(i);
for(int j=0; j<=x; ++j, tmp=tmp*r)
if(fxy(i, j))
b[j]=inft[j]*tmp;
else
b[j]=0;
Poly::get_ln(b, c, x+1);
for(int j=0, k=0; j<=x; ++j, k+=i)
a[k]=a[k]+c[j];
}
Poly::get_exp(a, b, n+1);
vir x=fact[n]*b[n];
cout<<(int)x;
}
inline int solve(int x) {
for(auto d : divv)
if(__gcd(x, d)==1)
return m/d;
}
inline void divit(int m) {
for(int i=1; i*i<=m; ++i) if(m%i==0) {
divv.push_back(i);
if(i!=m/i) divv.push_back(m/i);
}
sort(divv.begin(), divv.end());
reverse(divv.begin(), divv.end());
}
inline void init() {
Poly::init();
cin>>n>>m;
divit(m);
for(int x=1; x<=n; ++x)
mind[x]=solve(x);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
init();
work();
cout.flush();
return 0;
}