【洛谷P3934】炸脖龙 I

题目

题目链接:https://www.luogu.com.cn/problem/P3934
给一个长为 \(n\) 的序列,\(m\) 次操作,每次操作:

  1. 区间 \([l,r]\)\(x\)
  2. 对于区间 \([l,r]\),查询:\(a[l]^{a[l+1]^{a[l+2]^{\dots ^{a[r]}}}} \mod p\)

\(n,m\leq 5\times 10^5\)\(p\leq 2\times 10^7\)

思路

扩展欧拉定理:

\[a^b\equiv \left\{\begin{align*}a^b\space (b<\varphi(p))\\ a^{b\text{ mod }\varphi(p) +\varphi(p)}\space(b\geq\varphi(p)) \end{align*}\right.\pmod p \]

\(g(l,r)=a[l]^{a[l+1]^{a[l+2]^{\dots ^{a[r]}}}}\space\space\space\)\(f(l,r,p)=g(l,r)\bmod p\)
那么有

\[f(l,r,p)=a[l]^{f(l+1,r,\varphi(p))+[g(l+1,r)\geq \varphi(p)]\times \varphi(p)}\bmod p \]

而因为 \(\varphi(\varphi(p))\)这样一直递归下去,\(O(\log p)\) 次后就会变成 \(1\),而当模数 \(p=1\) 时,显然答案为 \(0\)。所以说我们计算 \(f\) 只会递归 \(O(\log p)\) 层。
那么显然只需要支持区间加单点查询,树状数组即可。
为了降低代码复杂度,可以建一个类,第一维表示值,第二维表示是否超过当前的模数。然后重载一下 * 运算符。
注意指数要讨论是否需要 \(+\varphi(p)\) 这些细节。
时间复杂度 \(O(m\log p\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=500010,M=20000010;
int n,m,Q,prm[M],phi[M];
bool v[M];

int read()
{
	int d=0; char ch=getchar();
	while (!isdigit(ch)) ch=getchar();
	while (isdigit(ch)) d=(d<<3)+(d<<1)+ch-48,ch=getchar();
	return d;
}

void findprm(int n)
{
	for (int i=2;i<=n;i++)
	{
		if (!v[i]) prm[++m]=i,phi[i]=i-1;
		for (int j=1;j<=m;j++)
		{
			if (i>n/prm[j]) break;
			v[i*prm[j]]=1; phi[i*prm[j]]=phi[i]*(prm[j]-1);
			if (!(i%prm[j])) { phi[i*prm[j]]=phi[i]*prm[j]; break; }
		}
	}
}

struct node
{
	ll a,b;
	
	friend node mul(node x,node y,ll p)
	{
		return (node){x.a*y.a%p,max((ll)(x.a*y.a>=p),max(x.b,y.b))};
	}
};

node fpow(node x,ll k,ll mod)
{
	node ans=(node){1,0};
	if (x.a>=mod) x.a%=mod,ans.b=1;
	for (;k;k>>=1,x=mul(x,x,mod))
		if (k&1) ans=mul(ans,x,mod);
	return ans;
}

struct BIT
{
	ll c[N];
	
	void add(int x,int v)
	{
		for (int i=x;i<=n;i+=i&-i)
			c[i]+=v;
	}
	
	ll query(int x)
	{
		ll ans=0;
		for (int i=x;i;i-=i&-i)
			ans+=c[i];
		return ans;
	}
}bit;

node query(int l,int r,ll p)
{
	ll x=bit.query(l);
	if (p==1) return (node){0,1};
	if (l==r) return (node){x%p,(ll)(x>=p)};
	node res=query(l+1,r,phi[p]);
	return fpow((node){x,0},res.a+res.b*phi[p],p);
}

int main()
{
	findprm(M-1);
	n=read(); Q=read();
	for (int i=1;i<=n;i++) 
	{
		int x=read();
		bit.add(i,x); bit.add(i+1,-x);
	}
	while (Q--)
	{
		int opt=read(),l=read(),r=read(),p=read();
		if (opt==1)
			bit.add(l,p),bit.add(r+1,-p);
		else
			cout<<query(l,r,p).a<<"\n";
	}
	return 0;
}
posted @ 2021-09-12 21:46  stoorz  阅读(52)  评论(0编辑  收藏  举报