HDU 6189 Law of Commutation
题意:
给出\(n,a\)解满足\(a^b\equiv b^a(mod~~2^n)\)的\(b的整数解个数\)
解题思路:
分奇偶讨论\(a\),可以发现在该模条件下\(a\)与\(b\)同奇偶,当\(a\)为奇数的时候,暴力跑了下大概只有在\(b=a\)的情况下等式成立
那么讨论在\(a,b\)为偶数的情况
当\(a\geq n\)时,显然\(b^a\)为\(2^n\)的倍数,那么\(a^b\equiv b^a\equiv 0(mod~~2^n)\)
再分类讨论一波,当\(b\geq n\)时,所有的偶数\(b\)必然满足条件,对于其他的\(b<n\)暴力跑就好了,毕竟\(n\)只有30
当\(a<n\)时,再次讨论\(b\)的情况,当\(b\geq n\)时,要求的是\(b^a\equiv0(mod~2^n)\)的方案数
令\(b^a=k2^n,k\in Z^+\)
那么每个\(b\)至少包含\(\lceil\frac{n}{a}\rceil\)个\(2\)的因子,即\(b\)为\(2^{\lceil\frac{n}{a}\rceil}\)的倍数求在区间\(n\leq b\leq2^n\)的方案数即可,\(b<n\)仍旧暴力
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
// freopen("k.in", "r", stdin);
// freopen("k.out", "w", stdout);
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
mt19937 rnd(time(NULL));
#define de(a) cout << #a << " = " << a << endl
#define rep(i, a, n) for (int i = a; i <= n; i++)
#define per(i, a, n) for (int i = n; i >= a; i--)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
typedef pair<char, char> PCC;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
const int MAXN = 1e6 + 7;
const int MAXM = 4e5 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-7;
const double pi = acos(-1.0);
ll quick_pow(ll a, ll b, ll mod)
{
ll ans = 1;
while (b)
{
if (b & 1)
ans = (1LL * ans * a) % mod;
a = (1LL * a * a) % mod;
b >>= 1;
}
return ans;
}
int main()
{
ll n, a;
while (~scanf("%lld%lld", &n, &a))
{
ll ans = 0;
ll mod = 1 << n;
if (a & 1)
printf("1\n");
else if (a >= n)
{
ans += mod / 2 - (n - 1) / 2;
for (int i = 2; i < n; i += 2)
if (quick_pow(a, i, mod) == quick_pow(i, a, mod))
ans++;
printf("%lld\n", ans);
}
else
{
ll temp = quick_pow(2, ceil(1.0 * n / a), mod);
ans += mod / temp - (n - 1) / temp;
for (int i = 2; i < n; i += 2)
if (quick_pow(a, i, mod) == quick_pow(i, a, mod))
ans++;
printf("%lld\n", ans);
}
}
return 0;
}