bzoj 3561: DZY Loves Math VI
3561: DZY Loves Math VI
Time Limit: 10 Sec Memory Limit: 256 MBSubmit: 489 Solved: 323
[Submit][Status][Discuss]
Description
给定正整数n,m。求
Input
一行两个整数n,m。
Output
一个整数,为答案模1000000007后的值。
Sample Input
5 4
Sample Output
424
HINT
数据规模:
1<=n,m<=500000,共有3组数据。
Source
常年不做数论了,略感生疏。。。
首先和gcd有关的可以枚举公因数:
ANS=∑(d= 1 to n) d^d *f(n/d,m/d,d)。
其中f(x,y,z)表示1<=i<=x,1<=j<=y中gcd(i,j)==1的(xy)^z之和。
f直接求不好求,我们考虑反演得到f。
设g(x,y,z)为1<=i<=x,1<=j<=y的(xy)^z之和。
拆括号可以得到:g(x,y,z)=(∑ i^z)*(∑ j^z)。
g可以用f表示: g(x,y,z)=∑f(x/i,y/i,z)* i^(2*z)。
然后用g反演得f: f(x,y,z)=∑g(x/i,y/i,z) * i^(2*z) * μ(i)
总式子太长了,懒得打上了。。。
总之这个题最后的式子没法用一般的数论dark技巧优化,只能暴力求。。
也就是外层d枚举,d,n/d,m/d 确定了之后内层也是暴力做。。
(所以我也不知道我怎么没有TLE)
(而且感觉我内层的预处理既浪费了运行时间又显得很蠢hhh)
/************************************************************** Problem: 3561 User: JYYHH Language: C++ Result: Accepted Time:7860 ms Memory:11936 kb ****************************************************************/ #include<bits/stdc++.h> #define ll long long #define ha 1000000007 #define maxn 500005 using namespace std; ll n,m,k,ans=0,ci[maxn],qz[maxn]; int miu[maxn],t=0,zs[maxn/5]; bool v[maxn]; inline ll ksm(ll x,ll y){ ll an=1; for(;y;y>>=1,x=x*x%ha) if(y&1) an=an*x%ha; return an; } inline void init(){ miu[1]=1; for(int i=2;i<=500000;i++){ if(!v[i]) zs[++t]=i,miu[i]=-1; for(int j=1,u;j<=t&&(u=zs[j]*i)<=500000;j++){ v[u]=1; if(!(i%zs[j])) break; miu[u]=-miu[i]; } } // for(int i=1;i<=500000;i++) miu[i]+=miu[i-1]; } inline ll mul(ll x,ll y){ x*=y; if(x>=ha) x-=x/ha*ha; return x; } inline ll f(ll x,ll y,ll z){ // printf("round %lld %lld %lld:\n"); ll an=0,now,a1,a2; qz[1]=1,ci[1]=1; for(int i=2;i<=z;i++){ ci[i]=ksm(i,y); qz[i]=ci[i]*ci[i]; if(qz[i]>=ha) qz[i]-=qz[i]/ha*ha; qz[i]*=miu[i]; if(qz[i]<0) qz[i]+=ha; ci[i]+=ci[i-1]; if(ci[i]>=ha) ci[i]-=ha; qz[i]+=qz[i-1]; if(qz[i]>=ha) qz[i]-=ha; } for(int i=1,j;i<=x;i=j+1){ a1=x/i,a2=z/i; j=min(x/a1,z/a2); now=qz[j]-qz[i-1]+ha; if(now>=ha) now-=ha; an+=mul(mul(ci[a1],ci[a2]),now); if(an>=ha) an-=ha; } return an; } int main(){ init(); scanf("%lld%lld",&n,&m); if(n>m) swap(n,m); for(int i=1;i<=n;i++){ ans+=ksm(i,i)*f(n/i,i,m/i); if(ans>=ha) ans-=ans/ha*ha; } printf("%lld\n",ans); return 0; }
我爱学习,学习使我快乐