[THUPC 2023 初赛] 快速 LCM 变换
题目描述
小 I 今天学习了快速最小公倍数变换(Fast Least-Common-Multiple Transform, FLT),于是他想考考你。
给定一个长度为 \(n\) 的正整数序列 \(r_1,r_2,\cdots,r_n\)。你需要做以下操作恰好一次:
- 选择整数 \(i,j\) 使得 \(1 \le i < j \le n\)。在序列末尾加入 \((r_i+r_j)\),并将 \(r_i\) 和 \(r_j\) 从序列中删除。
可以注意到总共有 \(\frac{n(n-1)}{2}\) 种可能的操作,每种操作会得到一个长度为 \(n-1\) 的序列。
你需要对所有的这 \(\frac{n(n-1)}{2}\) 个序列,求出序列中所有元素的最小公倍数,并给出它们的和模 \(998244353\) 的值。
输入格式
输入的第一行包含一个正整数 \(n\),表示序列的长度。接下来一行 \(n\) 个正整数 \(r_1,r_2,\cdots,r_n\),描述初始序列。
输出格式
输出一行一个整数,表示所有序列的最小公倍数的和模 \(998244353\) 的值。
样例 #1
样例输入 #1
3
2 3 4
样例输出 #1
40
提示
样例解释 1
- \(i=1,j=2\) 时,得到的序列为 \(\{4,5\}\),最小公倍数为 \(20\);
- \(i=1,j=3\) 时,得到的序列为 \(\{3,6\}\),最小公倍数为 \(6\);
- \(i=2,j=3\) 时,得到的序列为 \(\{2,7\}\),最小公倍数为 \(14\)。
因此输出为 \(20+6+14=40\)。
子任务
对于所有测试数据,\(2 \le n \le 5 \times 10^5, 1 \le r_1,r_2,\cdots,r_n \le 10^6\)。
题目来源
来自 2023 清华大学学生程序设计竞赛暨高校邀请赛(THUPC2023)初赛。
题解等资源可在 https://github.com/THUSAAC/THUPC2023-Pre 查看。
求最小公倍数还要取模?只能拆成质数之幂再乘起来了。
那么一个质数的哪些幂次最终可能会被计入答案呢?乍一看,次数前三大的都用到。但仔细分析一下,其实只用次数前两大的数。
这是因为,如果同时删掉了前两大次数的数,那么会增加 \(p^{k_1}+p^{k_2}\),那么他一定是 \(p^{k_2}\) 的倍数。
设原来的最小公倍数为 \(s\),那么如果删掉某一个数之后,答案会变成 \(s\) 的多少倍是固定的,也是分开的。我们可以枚举质数,算出删某一个数 \(x\) 时的答案 \(p_x\)。
此时如果我删掉数字 \(x\) 和 \(y\),那么此时 \(s\) 会变成 \(p_xp_ys\),然后我们还要统计 \(x+y\) 时带来的贡献。
可以通过 NTT 计算出选择和为 \(a\) 时的各种方案的 \(s\) 倍数之和,此时要计算 \(a\) 会变成 \(s\) 的多少倍。分解质因数后看他的次数是否超过 \(s\) 的次数就行了。
注意要好好去重。
#include<bits/stdc++.h>
using namespace std;
const int P=998244353,N=2097152,iv3=332748118,iv2=P+1>>1;
int a[N],f[N],p[N],cmx[N],mx[N],r[N],n,nw[N],ret=1,ans,c[N];
vector<int>g[N];
int pown(int x,int y)
{
if(!y)
return 1;
int t=pown(x,y>>1);
if(y&1)
return 1LL*t*t%P*x%P;
return 1LL*t*t%P;
}
void ntt(int a[],int op)
{
for(int i=0;i<N;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int md=1;md<N;md<<=1)
{
int g=pown(op? 3:iv3,(P-1)/(md<<1));
for(int i=0;i<N;i+=md<<1)
{
int pw=1;
for(int j=0;j<md;j++,pw=1LL*pw*g%P)
{
int ax=pw*1LL*a[i+j+md]%P;
a[i+j+md]=(a[i+j]-ax+P)%P;
(a[i+j]+=ax)%=P;
}
}
}
}
int main()
{
// freopen("P9135_1.in","r",stdin);
for(int i=2;i<N;i++)
if(!g[i].size())
for(int j=1;j*i<N;j++)
g[i*j].push_back(i);
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",a+i),f[a[i]]=1,c[a[i]]++;
for(int i=1;i<N;i++)
r[i]=r[i>>1]>>1|((i&1)*N/2);
for(int i=1;i<=n;i++)
{
for(int j=0;j<g[a[i]].size();j++)
{
int cnt=0,k=a[i];
while(k%g[a[i]][j]==0)
++cnt,k/=g[a[i]][j];
if(cnt>mx[g[a[i]][j]])
{
cmx[g[a[i]][j]]=mx[g[a[i]][j]],mx[g[a[i]][j]]=cnt;
nw[g[a[i]][j]]=a[i];
}
else if(cnt>cmx[g[a[i]][j]])
cmx[g[a[i]][j]]=cnt;
}
}
for(int i=1;i<N;i++)
if(mx[i])
f[nw[i]]=1LL*f[nw[i]]*pown(pown(i,mx[i]-cmx[i]),P-2)%P,ret=1LL*ret*pown(i,mx[i])%P;
for(int i=0;i<N;i++)
p[i]=f[i]*1ll*c[i]%P;
ntt(p,1);
for(int i=0;i<N;i++)
p[i]=1LL*p[i]*p[i]%P;
ntt(p,0);
for(int i=0;i<N;i++)
p[i]=1LL*p[i]*pown(N,P-2)%P;
// for(int i=1;i<=8;i++)
// printf("%d \n",p[i]);
for(int i=1;i<N;i++)
{
if(i*2<N)
(p[i*2]+=P-1LL*c[i]*f[i]%P*f[i]%P)%=P;
// if(i<=8)
// printf("%d ",p[i]);
for(int j=0;j<g[i].size();j++)
{
int cnt=0,k=i;
while(k%g[i][j]==0)
++cnt,k/=g[i][j];
if(cnt>mx[g[i][j]])
p[i]=1LL*p[i]*pown(g[i][j],cnt-mx[g[i][j]])%P;
}
(ans+=p[i])%=P;
}
printf("%lld\n",ans*1LL*ret%P*iv2%P);
}