AT_abc317_f 题解

调了一小时结果发现爆 long long 了。

考虑数位 dp,具体来说,设计状态 \(dp_{i,r_1,r_2,r_3,mx_1,mx_2,mx3_,c_1,c_2,c_3}\) 表示当前考虑到第 \(i\) 位,\(x_1,x_2,x_3\)\(a_1,a_2,a_3\) 等于 \(r_1,r_2,r_3\) 三个数是否达到 \(n\) 的上界以及是否全部是 \(0\)

然后从高到低枚举三个数这一位填什么(满足异或为 \(0\) 一限制)往下转移即可。

为了写起来方便这里写的是记忆化搜索。

#include<bits/stdc++.h>
#define int long long
//#define lowbit(x) (x&-(x))
using namespace std;
//const int maxn =
const int mod = 998244353;
vector<int> dight;
int dp[65][10][10][10][2][2][2][2][2][2];
int a1,a2,a3;
int solve(int p,int r1,int r2,int r3,bool mx1,bool mx2,bool mx3,bool c1,bool c2,bool c3){//当前考虑到第 p 位且 x1,x2,x3 模 a1,a2,a3 为 r1,r2,r3 是否到达上界 是否全部填 0
    if(dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]!=-1) return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3];
    if(p==0){ 
        if(r1==0&&r2==0&&r3==0&&c1==false&&c2==false&&c3==false){
            return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=1;
        }
        else{
            return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=0;
        }
    }
    int res=0;
    for(int i=0;i<=(mx1==true?dight[p-1]:1);i++){
        for(int j=0;j<=(mx2==true?dight[p-1]:1);j++){
            for(int k=0;k<=(mx3==true?dight[p-1]:1);k++){
                if((i^j)^k==0){
                    res+=solve(p-1,(a1+r1+((1ll<<(p-1))*i)%a1)%a1,(a2+r2+((1ll<<(p-1))*j)%a2)%a2,(a3+r3+((1ll<<(p-1))*k)%a3)%a3,mx1&(i==dight[p-1]),mx2&(j==dight[p-1]),mx3&(k==dight[p-1]),c1&&(i==0),c2&&(j==0),c3&&(k==0));
                    res%=mod;
                }
            }
        }
    }
    return dp[p][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=res;
}
int n;       
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    for(int i=0;i<65;i++)
        for(int r1=0;r1<10;r1++)
            for(int r2=0;r2<10;r2++)
                for(int r3=0;r3<10;r3++)
                    for(int mx1=0;mx1<2;mx1++)
                        for(int mx2=0;mx2<2;mx2++)
                            for(int mx3=0;mx3<2;mx3++) 
                                for(int c1=0;c1<2;c1++)
                                    for(int c2=0;c2<2;c2++)
                                        for(int c3=0;c3<2;c3++) dp[i][r1][r2][r3][mx1][mx2][mx3][c1][c2][c3]=-1;
    cin>>n;
    cin>>a1>>a2>>a3;
    while(n>0) dight.push_back(n%2),n/=2;
    cout<<solve(dight.size(),0,0,0,true,true,true,true,true,true)<<'\n';
    return 0;
}
posted @ 2024-02-27 18:13  ChiFAN鸭  阅读(12)  评论(0编辑  收藏  举报