题解 三元组
卡常题：用带 lazy tag 的线段树会 TLE，得分和暴力一样。
- 教训：如果只需要区间加减 + 单点查询，能用树状数组维护差分就别写带 lazy tag 的线段树。
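发现如果从小到大枚举 \(c\)（设 \(c = i\)），可行的 \(b\) 每次只多出一个，即 \(b = i\)，而这个新的 \(b\) 只能与 \(a \in [1, i]\) 配对。
所以每次只需考虑新加入的 \(b\) 的贡献：值 \((i^2 + a) \bmod mod\)（\(a = 1, \dots, i\)）构成值域 \([0, mod-1]\) 上一段（可能绕回 0 的）连续区间，对其做区间加，再单点查询 \(c^3 \bmod mod\) 处的计数并累加进答案。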
Code:
#include <bits/stdc++.h>
using namespace std;
// Global constants and aliases (typed, scoped names instead of macros).
constexpr int INF = 0x3f3f3f3f;  // conventional "infinity" sentinel; not used by the visible code
constexpr int N = 100010;        // capacity of the fixed-size arrays; presumably max input + slack — confirm against constraints
using ll = long long;
// Uncomment to widen every `int` to long long (main is declared `signed` so it survives this).
//#define int long long
// Fast input: stdin is pulled into a 2 MiB buffer with fread and handed out
// one byte at a time. The macro yields EOF once stdin is exhausted; bytes are
// returned as unsigned char so a 0xFF input byte is not mistaken for EOF.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)

// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; each '-' among them
// flips the sign. If input ends before any digit appears, returns 0 instead
// of looping forever on EOF. Digits are tested with explicit range checks
// rather than isdigit(), which is undefined for negative char values.
inline int read() {
    int ans = 0, f = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// Parameters of the current test case, assigned by main() before each solve():
// n bounds the triple (1 <= a <= b <= c <= n) and mod is the modulus.
int n;
int mod;
namespace force{
// Brute-force reference: directly counts triples 1 <= a <= b <= c <= n
// with (a + b^2) % mod == c^3 % mod. O(n^3), so only usable for small n
// or for cross-checking the faster solutions.
ll solve() {
    ll count = 0;
    for (ll c = 1; c <= n; ++c) {
        const ll target = c * c % mod * c % mod;
        for (ll b = 1; b <= c; ++b) {
            const ll b2 = b * b;
            for (ll a = 1; a <= b; ++a)
                count += ((a + b2) % mod == target) ? 1 : 0;
        }
    }
    return count;
}
}
namespace task1{
// O(n^2) solution: enumerate c = cur in increasing order. When c reaches cur,
// the new pairs (a, b = cur) with a in [1, cur] add their values
// (cur^2 + a) % mod to a bucket array; then the number of valid (a, b) for
// this c is the bucket count at cur^3 % mod.
int buc[N];
ll solve() {
    std::fill(buc, buc + N, 0);
    ll total = 0;
    for (ll cur = 1; cur <= n; ++cur) {
        const ll sq = cur * cur % mod;
        for (ll a = 1; a <= cur; ++a) buc[(sq + a) % mod]++;
        total += buc[sq * cur % mod];
    }
    return total;
}
}
namespace task2{
#if 0
int tl[N<<2], tr[N<<2];
int dat[N<<2], tag[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define dat(p) dat[p]
#define tag(p) tag[p]
#define pushup(p) dat(p)=dat(p<<1)+dat(p<<1|1)
void spread(int p) {
if (!tag(p)) return ;
dat(p<<1)+=tag(p); tag(p<<1)+=tag(p);
dat(p<<1|1)+=tag(p); tag(p<<1|1)+=tag(p);
tag(p)=0;
}
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r; dat(p)=0; tag(p)=0;
if (l==r) return ;
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
}
void upd(int p, int l, int r, int val) {
if (l<=tl(p) && r>=tr(p)) {dat(p)+=val; tag(p)+=val; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) upd(p<<1, l, r, val);
if (r>mid) upd(p<<1|1, l, r, val);
pushup(p);
}
int query(int p, int pos) {
if (tl(p)==tr(p)) return dat(p);
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) return query(p<<1, pos);
else return query(p<<1|1, pos);
}
#else
ll bin[N];
inline void upd(int i, ll dat) {++i; for (; i<=mod+10; i+=i&-i) bin[i]+=dat;}
inline ll query(int i) {++i; ll ans=0; for (; i; i-=i&-i) ans+=bin[i]; return ans;}
#endif
ll solve() {
ll ans=0;
// memset(tl, 0, sizeof(tl));
// memset(tr, 0, sizeof(tr));
// memset(dat, 0, sizeof(dat));
// memset(tag, 0, sizeof(tag));
// build(1, 0, mod-1);
memset(bin, 0, sizeof(bin));
for (int i=1; i<=n; ++i) {
int t=1ll*i*i%mod;
if (t+i<=mod-1) {
// upd(1, t+1, t+i, 1);
upd(t+1, 1);
upd(t+i+1, -1);
}
else {
if (t+1<=mod-1) {
// upd(1, t+1, mod-1, 1);
upd(t+1, 1);
// upd(1, mod, -1);
}
int len=(t+i)-mod+1;
// upd(1, 0, mod-1, len/mod);
upd(0, len/mod);
if (len%mod) {
// upd(1, 0, len%mod-1, 1);
upd(0, 1);
upd(len%mod, -1);
}
}
ans+=query(1ll*i*i%mod*i%mod);
}
return ans;
}
}
signed main()
{
freopen("exclaim.in", "r", stdin);
freopen("exclaim.out", "w", stdout);
int T=read();
for (int i=1; i<=T; ++i) {
n=read(); mod=read();
// printf("Case %d: %lld\n", i, force::solve());
// printf("Case %d: %lld\n", i, task1::solve());
printf("Case %d: %lld\n", i, task2::solve());
}
return 0;
}