题解 方程
前三个小时都有点神游导致没发现这题是水题
首先化式子
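令 \(x=\frac{a+b}{c}\)(需 \(c\not\equiv 0\) 且 \(a+b\not\equiv 0\pmod p\)),原式可化简为 \(-\frac{x^3+1}{x}-1\),因此要求存在某个给定的 \(t\) 使得 \(-\frac{x^3+1}{x}-1\equiv t\pmod p\)(等价于 \(x^3+(t+1)x\equiv -1\))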
这一条件只与 \(x\) 有关,可以枚举 \(x\in[1,p)\),把所有 \(t\) 放进桶(哈希表)里判断 \(x\) 是否合法
然后 \((a+b)\bmod p\) 每种取值的出现次数,可以把 \(s\) 的值域桶与自身做 NTT 卷积预处理出来
接下来赛时只想到对每个合法的 \(x\) 再枚举一遍 \(c\) 来 check,复杂度 \(O(pn)\)
然而移项发现 \(a+b\equiv x\times c\),所以这是个乘法卷积
- 关于乘法/除法卷积:取模 \(p\) 的一个原根 \(g\),把每个非零数 \(v=g^k\) 换成它的离散对数 \(k\),乘法/除法卷积就变成了指数上模 \(p-1\) 的加法/减法卷积
于是直接用 NTT 做即可,复杂度 \(O(p\log p+n+m)\)(NTT 长度约为 \(2p\))
代码如下:
#include <bits/stdc++.h>
using namespace std;
// INF: large sentinel value; not referenced in the code shown here.
#define INF 0x3f3f3f3f
// N: capacity of every global array (s, t and the per-task buffers below).
#define N 1000010
#define ll long long
// Competitive-programming shortcut: every later `int` (locals, parameters,
// return types, globals) is textually replaced by `long long`, which is why
// the entry point has to be spelled `signed main()`.
#define int long long
// Fast input: stdin is pulled into buf in 2 MiB chunks; [p1, p2) is the unread window.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf when the window is empty and
// yields EOF once fread returns nothing. Bytes are returned as unsigned char so
// a 0xFF byte can never be mistaken for EOF and every value is a valid
// argument for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads one decimal integer, skipping any non-digit characters before it; each
// '-' seen while skipping flips the sign. Returns 0 if the input ends before a
// digit is found instead of spinning forever on EOF.
inline int read() {
    int ans=0, f=1, c=getchar();
    while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
    while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
    return ans*f;
}
// p: modulus of the problem (read first from input; qpow(x, p-2) below uses
// Fermat inverses, so it is assumed prime). n = |s|, m = |t|.
int p, n, m;
// s[1..n]: input values for a, b, c; t[1..m]: target residues (both 1-indexed).
int s[N], t[N];
// NTT modulus 998244353 = 119 * 2^23 + 1 with primitive root 3; phi = mod - 1.
const ll mod=998244353, rt=3, phi=mod-1;
// Returns a^b mod md by binary exponentiation; md defaults to the current p
// (default arguments are evaluated at each call). b must be non-negative.
// The base is reduced into [0, md) first so a*a cannot overflow for an
// unreduced or negative base, and 1 % md gives the right answer for md == 1.
inline ll qpow(ll a, ll b, ll md=p) {
    a %= md;
    if (a < 0) a += md;
    ll ans = 1 % md;
    for (; b; a = a*a%md, b >>= 1)
        if (b & 1) ans = ans*a%md;
    return ans;
}
namespace force{
ll ans;
unordered_map<int, bool> mp;
inline int calc(int a, int b, int c) {
return (( ( (a+c)*(a-c)+b*(2*a+b) )%p*qpow((a*c+b*c)%p, p-2)%p - ( (c+a+b)*(a+b)+c*c )*qpow(c*c%p, p-2)%p )%p+p)%p;
}
void solve() {
for (int i=1; i<=m; ++i) mp[t[i]]=1;
for (int a=1; a<=n; ++a)
for (int b=1; b<=n; ++b)
for (int c=1; c<=n; ++c)
if ((s[a]*s[c]+s[b]*s[c])%p!=0 && s[c]*s[c]%p!=0 && mp.find(calc(s[a], s[b], s[c]))!=mp.end())
++ans;
cout<<ans<<endl;
}
}
namespace task1{
// O(n^2 + p*n) approach: tabulate (a+b) mod p directly, then for every
// admissible x enumerate c. The answer is printed reduced mod 998244353.
// f[r]: number of ordered pairs (i, j) in [1, n]^2 with (s[i]+s[j]) % p == r
// (presumably 0 <= s[i] < p so r is a valid index -- confirm with constraints).
ll f[N], ans;
// Target residues; only key presence matters.
unordered_map<int, bool> mp;
void solve() {
    for (int i=1; i<=m; ++i) mp[t[i]]=1;
    for (int i=1; i<=n; ++i)
        for (int j=1; j<=n; ++j)
            ++f[(s[i]+s[j])%p];
    for (int x=1; x<p; ++x) {
        // x = (a+b)/c is admissible iff -(x^3+1)/x - 1 == t (mod p) for some
        // target t, i.e. x^3 + (t+1)x == -1. x^3 is reduced stepwise so the
        // intermediate product cannot overflow.
        ll tem = ((-(x*x%p*x%p+1)*qpow(x, p-2)%p-1)%p+p)%p;
        if (mp.find(tem)==mp.end()) continue;
        // a+b == x*s[c]. Since x != 0 and p is prime, x*s[c] == 0 exactly when
        // s[c] == 0, which is excluded (c must be invertible, a+b nonzero).
        for (int c=1; c<=n; ++c) if (s[c]*x%p) {
            ans=(ans+f[s[c]*x%p])%mod;
        }
    }
    cout<<ans<<endl;
}
}
namespace task2{
// O(p log p + n + m) solution. With x = (a+b)/c the condition depends only on
// x, and a+b == x*c (mod p) makes the count a multiplicative convolution.
// Taking discrete logs to a primitive root of p turns it into an additive
// (cyclic, length p-1) convolution, done with NTT.
//
// f:  histogram of s; after self-convolution and folding,
//     f[r] = #ordered (a, b) with (s[a]+s[b]) % p == r   (mod 998244353).
// g:  indicator of admissible x in log space; after convolving with h and
//     folding, g[k] = #(x, c) with tr[x] + tr[s[c]] == k (mod p-1).
// h:  histogram of tr[s[c]] over c with s[c] != 0.
// rt: primitive root of p found by getrt() (distinct from the global NTT
//     root; ntt() uses the literal 3).
ll f[N], g[N], h[N], rt, ans;
// rev: bit-reversal permutation for length bln = 2^bct.
// div[1..dcnt]: distinct prime factors of p-1.
// tr[v]: discrete log of v (1 <= v < p) to base rt.
int rev[N], div[N], tr[N], dcnt, bln, bct;
// Target residues; only key presence matters.
unordered_map<int, bool> mp;
// In-place iterative NTT modulo 998244353 with generator 3. len must be a power
// of two and rev must already hold its bit-reversal permutation.
// op = 1: forward transform; op = -1: inverse transform (scaled by 1/len).
void ntt(ll* a, int len, int op) {
    for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
    ll w, wn, t;
    for (int i=1; i<len; i<<=1) {
        // Root of unity of order 2i (its inverse when op == -1).
        wn=qpow(3, (op*phi/(i<<1)+phi)%phi, mod);
        for (int j=0,step=i<<1; j<len; j+=step) {
            w=1;
            for (int k=j; k<j+i; ++k,w=w*wn%mod) {
                t=w*a[k+i]%mod;
                a[k+i]=(a[k]-t+mod)%mod;
                a[k]=(a[k]+t)%mod;
            }
        }
    }
    if (op==-1) {
        ll inv=qpow(len, mod-2, mod);
        for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
    }
}
// Appends the distinct prime factors of n to div[1..dcnt] by trial division.
// The cofactor left after trial division up to sqrt(n) is itself prime when
// > 1 and must be recorded too; otherwise isrt() skips that check (e.g. for
// p = 3, p-1 = 2 would yield no factors and getrt() would accept 1).
void divide(int n) {
    for (int i=2; i*i<=n; ++i) if (n%i==0) {
        div[++dcnt]=i;
        do {n/=i;} while (n%i==0);
    }
    if (n>1) div[++dcnt]=n;
}
// True iff t is a primitive root mod p: t^((p-1)/q) != 1 for every prime
// q | p-1. Requires divide(p-1) to have run.
bool isrt(int t) {
    for (int i=1; i<=dcnt; ++i) if (qpow(t, (p-1)/div[i])==1) return 0;
    return 1;
}
// Smallest primitive root of p.
int getrt() {for (int i=1; ; ++i) if (__gcd(i, p)==1&&isrt(i)) return i;}
void solve() {
    for (int i=1; i<=m; ++i) mp[t[i]]=1;
    // Discrete-log table: tr[rt^i mod p] = i for 0 <= i < p-1.
    divide(p-1); rt=getrt();
    for (int i=0,t=1; i<p-1; ++i,t=t*rt%p) tr[t]=i;
    // NTT length: smallest power of two > 2p.
    // NOTE(review): every buffer holds N entries, so this assumes bln <= N,
    // i.e. p < 2^18 -- confirm against the problem's constraints.
    for (bln=1; bln<=p*2; bln<<=1,++bct) ;
    for (int i=0; i<bln; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
    // f := (histogram of s)^2 gives the distribution of a+b (as integers).
    for (int i=1; i<=n; ++i) ++f[s[i]];
    ntt(f, bln, 1);
    for (int i=0; i<bln; ++i) f[i]=f[i]*f[i]%mod;
    ntt(f, bln, -1);
    // Fold sums >= p so f is indexed by (a+b) mod p.
    for (int i=p; i<bln; ++i) f[i%p]=(f[i%p]+f[i])%mod;
    // Mark admissible x in log space: -(x^3+1)/x - 1 must be a target residue.
    for (int x=1; x<p; ++x) {
        ll tem = ((-(x*x%p*x%p+1)*qpow(x, p-2)%p-1)%p+p)%p;
        if (mp.find(tem)!=mp.end()) g[tr[x]]=1;
    }
    for (int i=1; i<=n; ++i) if (s[i]) ++h[tr[s[i]]];
    // x*c mod p == additive convolution of logs, folded mod p-1.
    ntt(g, bln, 1); ntt(h, bln, 1);
    for (int i=0; i<bln; ++i) g[i]=g[i]*h[i]%mod;
    ntt(g, bln, -1);
    for (int i=p-1; i<bln; ++i) g[i%(p-1)]=(g[i%(p-1)]+g[i])%mod;
    // Sum over nonzero v of #(a, b: a+b == v) * #(x, c: x*c == v).
    for (int i=1; i<p; ++i) ans=(ans+f[i]*g[tr[i]])%mod;
    cout<<ans<<endl;
}
}
signed main()
{
    // Judge-style file I/O. freopen closes the standard stream even when the
    // open fails, so reading/writing through it afterwards would be undefined;
    // report the problem and stop instead.
    if (!freopen("equation.in", "r", stdin)) {perror("equation.in"); return 1;}
    if (!freopen("equation.out", "w", stdout)) {perror("equation.out"); return 1;}
    // Input: p n m, then s[1..n], then t[1..m].
    // NOTE(review): assumes n, m < N -- confirm against the constraints.
    p=read(); n=read(); m=read();
    for (int i=1; i<=n; ++i) s[i]=read();
    for (int i=1; i<=m; ++i) t[i]=read();
    // Subtask dispatch kept for reference; currently only task2 is used.
    // if (n<=300) force::solve();
    // else if (n<=4000) task1::solve();
    // else task2::solve();
    task2::solve();
    return 0;
}