题解 方程

传送门

前三个小时都有点神游导致没发现这题是水题

首先化式子
令 \(x=\frac{a+b}{c}\),则要求 \(-\frac{x^3+1}{x}-1\equiv t \pmod p\)(\(t\) 为给定的某个值)
这个可以枚举 \(x\) 开桶判断
然后 \(a+b\) 可以 ntt 卷积预处理出来
接下来赛时就只会对每个 \(x\) 枚举一遍 \(c\) check 了
然而移项发现 \(a+b\equiv x\times c\),所以这是个乘法卷积

  • 关于乘法/除法卷积:原根转化一下就变成了加法/减法卷积

于是就直接做了,复杂度 \(O(n+m+p\log p)\)(NTT 的长度约为 \(2p\))

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills `buf` from stdin in 2 MiB chunks
// and yields EOF once fread returns nothing. Bytes come back as unsigned
// values, so they are always valid arguments for isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Every '-' seen before the first digit flips the sign.
// Returns 0 if the input ends before any digit is found, instead of waiting
// forever for a digit that never comes.
inline int read() {
	int ans=0, f=1, c=getchar();  // c is int-sized so EOF stays distinguishable
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

int p, n, m;     // problem modulus p (inverses below use Fermat, so p is treated as prime), |s|, |t|
int s[N], t[N];  // 1-indexed input sequences
const ll mod=998244353, rt=3, phi=mod-1;  // NTT modulus, primitive root 3 used by the NTT, and mod-1
// Returns a^b modulo `mod` (defaults to the problem modulus p).
// The base is normalised into [0, mod) first, so negative or oversized
// bases neither overflow nor produce negative results.
// b must be non-negative.
inline ll qpow(ll a, ll b, ll mod=p) {
	ll ans=1%mod;          // 1 % mod makes qpow(x, 0, 1) == 0
	a%=mod;
	if (a<0) a+=mod;
	for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod;
	return ans;
}

namespace force{
	ll ans;
	unordered_map<int, bool> mp;
	inline int calc(int a, int b, int c) {
		return (( ( (a+c)*(a-c)+b*(2*a+b) )%p*qpow((a*c+b*c)%p, p-2)%p - ( (c+a+b)*(a+b)+c*c )*qpow(c*c%p, p-2)%p )%p+p)%p;
	}
	void solve() {
		for (int i=1; i<=m; ++i) mp[t[i]]=1;
		for (int a=1; a<=n; ++a)
			for (int b=1; b<=n; ++b)
				for (int c=1; c<=n; ++c)
					if ((s[a]*s[c]+s[b]*s[c])%p!=0 && s[c]*s[c]%p!=0 && mp.find(calc(s[a], s[b], s[c]))!=mp.end())
						++ans;
		cout<<ans<<endl;
	}
}

namespace task1{
	ll f[N], ans;
	unordered_map<int, bool> mp;
	void solve() {
		for (int i=1; i<=m; ++i) mp[t[i]]=1;
		for (int i=1; i<=n; ++i)
			for (int j=1; j<=n; ++j)
				++f[(s[i]+s[j])%p];
		const int neg_one=(-1%p+p)%p;
		for (int x=1; x<p; ++x) {
			// for (int i=1; i<=m; ++i) if ( (( x*x%p*x%p+(t[i]+1)*x%p )%p+p)%p == (-1%p+p)%p) goto jump;
			// for (int i=1; i<=m; ++i) if ( (x*x*x+(t[i]+1)*x)%p == neg_one) goto jump;
			ll tem = ((-(x*x*x%p+1)*qpow(x, p-2)%p-1)%p+p)%p ;
			if (mp.find(tem)==mp.end()) continue;
			// cout<<"x: "<<x<<endl;
			for (int c=1; c<=n; ++c) if (s[c]*x%p) {
				// cout<<"c: "<<c<<' '<<f[s[c]*x%p]<<endl;
				ans=(ans+f[s[c]*x%p])%mod;
			}
		}
		cout<<ans<<endl;
	}
}

namespace task2{
	ll f[N], g[N], h[N], rt, ans;
	int rev[N], div[N], tr[N], dcnt, bln, bct;
	unordered_map<int, bool> mp;
	void ntt(ll* a, int len, int op) {
		for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
		ll w, wn, t;
		for (int i=1; i<len; i<<=1) {
			wn=qpow(3, (op*phi/(i<<1)+phi)%phi, mod);
			for (int j=0,step=i<<1; j<len; j+=step) {
				w=1;
				for (int k=j; k<j+i; ++k,w=w*wn%mod) {
					t=w*a[k+i]%mod;
					a[k+i]=(a[k]-t+mod)%mod;
					a[k]=(a[k]+t)%mod;
				}
			}
		}
		if (op==-1) {
			ll inv=qpow(len, mod-2, mod);
			for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
		}
	}
	void divide(int n) {
		int m=n;
		for (int i=2; i*i<=m; ++i) if (n%i==0) {
			div[++dcnt]=i;
			do {n/=i;} while (n%i==0);
		}
	}
	bool isrt(int t) {
		for (int i=1; i<=dcnt; ++i) if (qpow(t, (p-1)/div[i])==1) return 0;
		return 1;
	}
	int getrt() {for (int i=1; ; ++i) if (__gcd(i, p)==1&&isrt(i)) return i;}
	void solve() {
		for (int i=1; i<=m; ++i) mp[t[i]]=1;
		divide(p-1); rt=getrt();
		for (int i=0,t=1; i<p-1; ++i,t=t*rt%p) tr[t]=i;
		for (int i=1; i<=n; ++i) ++f[s[i]];
		for (bln=1; bln<=p*2; bln<<=1,++bct) ;
		for (int i=0; i<bln; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
		ntt(f, bln, 1);
		for (int i=0; i<bln; ++i) f[i]=f[i]*f[i]%mod;
		ntt(f, bln, -1);
		for (int i=p; i<bln; ++i) f[i%p]=(f[i%p]+f[i])%mod;
		const int neg_one=(-1%p+p)%p;
		for (int x=1; x<p; ++x) {
			ll tem = ((-(x*x*x%p+1)*qpow(x, p-2)%p-1)%p+p)%p ;
			if (mp.find(tem)!=mp.end()) g[tr[x]]=1; //, cerr<<"x: "<<x<<endl;
		}
		// cout<<"tr: "; for (int i=0; i<p; ++i) cout<<tr[i]<<' '; cout<<endl;
		for (int i=1; i<=n; ++i) if (s[i]) ++h[tr[s[i]]];
		// cout<<"i: "; for (int i=0; i<bln; ++i) cout<<setw(2)<<i<<' '; cout<<endl;
		// cout<<"g: "; for (int i=0; i<bln; ++i) cout<<setw(2)<<g[i]<<' '; cout<<endl;
		// cout<<"h: "; for (int i=0; i<bln; ++i) cout<<setw(2)<<h[i]<<' '; cout<<endl;
		ntt(g, bln, 1); ntt(h, bln, 1);
		for (int i=0; i<bln; ++i) g[i]=g[i]*h[i]%mod;
		ntt(g, bln, -1);
		// cout<<"g: "; for (int i=0; i<bln; ++i) cout<<setw(2)<<g[i]<<' '; cout<<endl;
		for (int i=p-1; i<bln; ++i) g[i%(p-1)]=(g[i%(p-1)]+g[i])%mod;
		for (int i=1; i<p; ++i) ans=(ans+f[i]*g[tr[i]])%mod;
		cout<<ans<<endl;
	}
}

signed main()
{
	// Contest-style file I/O.
	freopen("equation.in", "r", stdin);
	freopen("equation.out", "w", stdout);

	// Input layout: p n m, then s[1..n], then t[1..m].
	p=read(), n=read(), m=read();
	for (int k=1; k<=n; ++k) s[k]=read();
	for (int k=1; k<=m; ++k) t[k]=read();

	// force::solve() (O(n^3)) and task1::solve() (O(n^2 + p*n)) remain
	// available as slower cross-checks; the NTT-based solver is used unconditionally.
	task2::solve();
	return 0;
}
posted @ 2022-03-17 18:12  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报