【SCOI2018】—Numazu 的蜜柑(二次剩余)
题面
Numazu 是一座坐落在太平洋西岸的美丽海滨小镇,蜜柑是这里的特产, Chika 最喜欢蜜柑了!这天她来到一棵蜜柑树下,发现这棵树由 个结点组成(1 号结 点为根节点),每个结点都生长了一些蜜柑,结点 的蜜柑数量记为 。她想问 聪明的你这样一个问题,有多少顶点对满足:
她答应送给你好多好多蜜柑作为回礼
考虑化方程
现在就是要解在意义下的二次剩余
用解一下
在树上跑一遍就完了
特判一下为的情况
#include<bits/stdc++.h>
#include<tr1/unordered_map>
using namespace std;
#define gc getchar
#define int long long
inline int read(){
char ch=gc();
int res=0,f=1;
while(!isdigit(ch))f^=ch=='-',ch=gc();
while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
return f?res:-res;
}
#define re register
#define pb push_back
#define cs const
#define pii pair<int,int>
#define fi first
#define se second
#define ll long long
int mod;
inline int add(int a,int b){return (a+=b)>=mod?a-mod:a;}
inline void Add(int &a,int b){(a+=b)>=mod?(a-=mod):0;}
inline int dec(int a,int b){return (a-=b)<0?a+mod:a;}
inline void Dec(int &a,int b){(a-=b)<0?(a+=mod):0;}
inline int mul(int a,int b){return (a*b-(ll)((long double)a/mod*b)*mod+mod)%mod;}
inline void Mul(int &a,int b){a=mul(a,b);}
inline int ksm(int a,int b,int res=1){
for(;b;b>>=1,a=mul(a,a))(b&1)&&(res=mul(res,a));return res;
}
cs int N=100005;
int n,A,B,ans,k1,k2;
vector<int> e[N];
int a[N];
tr1::unordered_map<ll,int> buc;
namespace Cipolla{
int w;
struct plx{
int x,y;
plx(int _x=1,int _y=0):x(_x),y(_y){}
friend inline plx operator *(cs plx &a,cs plx &b){
return plx(add(mul(a.x,b.x),mul(mul(a.y,b.y),w)),add(mul(a.x,b.y),mul(a.y,b.x)));
}
};
inline plx ksm(plx a,int b){
plx res=plx();
for(;b;b>>=1,a=a*a)if(b&1)res=res*a;
return res;
}
inline int ksm(int a,int b,int res=1){
for(;b;b>>=1,a=mul(a,a))(b&1)&&(res=mul(res,a));return res;
}
inline int solve(int a){
if(a==0)return 0;
if(ksm(a,(mod-1)/2)==mod-1)return -1;
int b;
while(1){
b=rand();
w=dec(mul(b,b),a);
if(ksm(w,(mod-1)/2)==mod-1)break;
}
return ksm(plx(b,1),(mod+1)/2).x;
}
}
void dfs1(int u){
ans+=buc[a[u]];
int f1=mul(a[u],k1),f2=mul(a[u],k2);
if(f1==f2)buc[f1]++;
else buc[f1]++,buc[f2]++;
for(int &v:e[u])
dfs1(v);
if(f1==f2)buc[f1]--;
else buc[f1]--,buc[f2]--;
}
int tot;
void dfs2(int u){
if(a[u]==0)ans+=tot,tot++;
for(int &v:e[u])dfs2(v);
if(a[u]==0)tot--;
}
signed main(){
srand(time(NULL));
n=read(),mod=read(),A=read(),B=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=2;i<=n;i++){
int u=read();e[u].pb(i);
}
int k=Cipolla::solve(dec(mul(A,A),mul(4,B))),inv2=ksm(2,mod-2);
if(k==-1)dfs2(1);
else{
k1=mul(dec(k,A),inv2),k2=mul(mod-add(k,A),inv2);
dfs1(1);
cout<<ans;
}