AT_abc317_f 题解
调了一小时结果发现爆 long long 了。
考虑数位 dp,具体来说,设计状态 \(dp_{i,r_1,r_2,r_3,mx_1,mx_2,mx3_,c_1,c_2,c_3}\) 表示当前考虑到第 \(i\) 位,\(x_1,x_2,x_3\) 模 \(a_1,a_2,a_3\) 等于 \(r_1,r_2,r_3\) 三个数是否达到 \(n\) 的上界以及是否全部是 \(0\)。
然后从高到低枚举三个数这一位填什么(满足异或为 \(0\) 一限制)往下转移即可。
为了写起来方便这里写的是记忆化搜索。
#include<bits/stdc++.h>
#define int long long
//#define lowbit(x) (x&-(x))
using namespace std;
//const int maxn =
const int mod = 998244353;
vector<int> dight;
int dp[65][10][10][10][2][2][2][2][2][2];
int a1,a2,a3;
int solve(int p,int r1,int r2,int r3,bool mx1,bool mx2,bool mx3,bool c1,bool c2,bool c3){//当前考虑到第 p 位且 x1,x2,x3 模 a1,a2,a3 为 r1,r2,r3 是否到达上界 是否全部填 0
if(dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]!=-1) return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3];
if(p==0){
if(r1==0&&r2==0&&r3==0&&c1==false&&c2==false&&c3==false){
return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=1;
}
else{
return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=0;
}
}
int res=0;
for(int i=0;i<=(mx1==true?dight[p-1]:1);i++){
for(int j=0;j<=(mx2==true?dight[p-1]:1);j++){
for(int k=0;k<=(mx3==true?dight[p-1]:1);k++){
if((i^j)^k==0){
res+=solve(p-1,(a1+r1+((1ll<<(p-1))*i)%a1)%a1,(a2+r2+((1ll<<(p-1))*j)%a2)%a2,(a3+r3+((1ll<<(p-1))*k)%a3)%a3,mx1&(i==dight[p-1]),mx2&(j==dight[p-1]),mx3&(k==dight[p-1]),c1&&(i==0),c2&&(j==0),c3&&(k==0));
res%=mod;
}
}
}
}
return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=res;
}
int n;
signed main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
for(int i=0;i<65;i++)
for(int r1=0;r1<10;r1++)
for(int r2=0;r2<10;r2++)
for(int r3=0;r3<10;r3++)
for(int mx1=0;mx1<2;mx1++)
for(int mx2=0;mx2<2;mx2++)
for(int mx3=0;mx3<2;mx3++)
for(int c1=0;c1<2;c1++)
for(int c2=0;c2<2;c2++)
for(int c3=0;c3<2;c3++) dp[i][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=-1;
cin>>n;
cin>>a1>>a2>>a3;
while(n>0) dight.push_back(n%2),n/=2;
cout<<solve(dight.size(),0,0,0,true,true,true,true,true,true)<<'\n';
return 0;
}