【洛谷P3934】炸脖龙 I
题目
题目链接:https://www.luogu.com.cn/problem/P3934
给一个长为 \(n\) 的序列,\(m\) 次操作,每次操作:
- 区间 \([l,r]\) 加 \(x\)。
- 对于区间 \([l,r]\),查询:\(a[l]^{a[l+1]^{a[l+2]^{\dots ^{a[r]}}}} \mod p\)。
\(n,m\leq 5\times 10^5\),\(p\leq 2\times 10^7\)。
思路
扩展欧拉定理:
\[a^b\equiv \left\{\begin{align*}a^b\space (b<\varphi(p))\\ a^{b\text{ mod }\varphi(p) +\varphi(p)}\space(b\geq\varphi(p)) \end{align*}\right.\pmod p
\]
设 \(g(l,r)=a[l]^{a[l+1]^{a[l+2]^{\dots ^{a[r]}}}}\space\space\space\),\(f(l,r,p)=g(l,r)\bmod p\)。
那么有
\[f(l,r,p)=a[l]^{f(l+1,r,\varphi(p))+[g(l+1,r)\geq \varphi(p)]\times \varphi(p)}\bmod p
\]
而因为 \(\varphi(\varphi(p))\)这样一直递归下去,\(O(\log p)\) 次后就会变成 \(1\),而当模数 \(p=1\) 时,显然答案为 \(0\)。所以说我们计算 \(f\) 只会递归 \(O(\log p)\) 层。
那么显然只需要支持区间加单点查询,树状数组即可。
为了降低代码复杂度,可以建一个类,第一维表示值,第二维表示是否超过当前的模数。然后重载一下 *
运算符。
注意指数要讨论是否需要 \(+\varphi(p)\) 这些细节。
时间复杂度 \(O(m\log p\log n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=500010,M=20000010;
int n,m,Q,prm[M],phi[M];
bool v[M];
int read()
{
int d=0; char ch=getchar();
while (!isdigit(ch)) ch=getchar();
while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
return d;
}
void findprm(int n)
{
for (int i=2;i<=n;i++)
{
if (!v[i]) prm[++m]=i,phi[i]=i-1;
for (int j=1;j<=m;j++)
{
if (i>n/prm[j]) break;
v[i*prm[j]]=1; phi[i*prm[j]]=phi[i]*(prm[j]-1);
if (!(i%prm[j])) { phi[i*prm[j]]=phi[i]*prm[j]; break; }
}
}
}
struct node
{
ll a,b;
friend node mul(node x,node y,ll p)
{
return (node){x.a*y.a%p,max((ll)(x.a*y.a>=p),max(x.b,y.b))};
}
};
node fpow(node x,ll k,ll mod)
{
node ans=(node){1,0};
if (x.a>=mod) x.a%=mod,ans.b=1;
for (;k;k>>=1,x=mul(x,x,mod))
if (k&1) ans=mul(ans,x,mod);
return ans;
}
struct BIT
{
ll c[N];
void add(int x,int v)
{
for (int i=x;i<=n;i+=i&-i)
c[i]+=v;
}
ll query(int x)
{
ll ans=0;
for (int i=x;i;i-=i&-i)
ans+=c[i];
return ans;
}
}bit;
node query(int l,int r,ll p)
{
ll x=bit.query(l);
if (p==1) return (node){0,1};
if (l==r) return (node){x%p,(ll)(x>=p)};
node res=query(l+1,r,phi[p]);
return fpow((node){x,0},res.a+res.b*phi[p],p);
}
int main()
{
findprm(M-1);
n=read(); Q=read();
for (int i=1;i<=n;i++)
{
int x=read();
bit.add(i,x); bit.add(i+1,-x);
}
while (Q--)
{
int opt=read(),l=read(),r=read(),p=read();
if (opt==1)
bit.add(l,p),bit.add(r+1,-p);
else
cout<<query(l,r,p).a<<"\n";
}
return 0;
}