AtCoder Zero or One
題意
給定 \(2<=N<=10^{18}\) 。求共有多少個不同的 \(base\),
使得將 \(N\) 轉換為 \(base\) 進位時,每個位數均為 \(0\) 或 \(1\) ?
分析
一樣想20分鐘想不出來🥲,想了暴搜或是遞迴的方法,都會大超時。
-> 題解局
對於不同 level 的 \(base\) 使用不同的方式檢查。(算是數論分塊?)
- 若 \(base\) 很小,因為產生的位數很多(二進位最多63位),因此暴力檢查
\(N\) 轉換為 \(base\) 進位是否都是 \(0, 1\) 。 - 若 \(base\) 很大,因為產生的位數很少(一千進位最多7位),因此使用 \(bitmask\) 這七位數
再二分搜索是否有合理的 \(base\) 可以產生 \(N\)。
實作
檢查 \(N\) 轉換為 \(base\) 時,是否都是 \(0, 1\)。
Time complexity: \(O(log(n))\)
1
2
3
4
5
6
7
bool check_base_match(int n, int base) {
while (n) {
if (n % base > 1) return false;
n /= base;
}
return true;
}
根據 \(mask\) 和 猜測的 \(base\) 轉換為十進位數。
因為 \(mask\) 最多只有二進位七位,因此此function可當 \(O(1)\)。
1
2
3
4
5
6
7
8
9
10
int transform(int base, int mask) {
__uint128_t res = 0, p = 1;
while (mask) {
res += p * (mask % 2);
if (res > 1e18) return -1;
mask /= 2;
p *= base;
}
return res;
}
猜測 \(N\) 是由多大的 \(base\) 製作。
Time complexity: \(O(lg(n))\)
1
2
3
4
5
6
7
8
9
10
11
bool check_mask_match(int n, int mask) {
int left = 1001, right = 1e18, mid;
while (left <= right) {
mid = left + (right - left) / 2;
int res = transform(mid, mask);
if (res == -1 || res > n) right = mid - 1;
else if (res < n) left = mid + 1;
else return true;
}
return false;
}
解題函式,測 1000 以下的 \(base\) 用暴搜。 測 1000 以上的 \(base\) 用 bitmask + binary search。
Time complexity: \(O(1000 * lg(n)) + O(1^7 * lg(n))\)
1
2
3
4
5
6
7
8
9
10
int zuha(int n) {
int res = 0;
for (int i = 2; i <= 1000; i++) {
if (check_base_match(n, i)) res++;
}
for (int mask = 1; mask < (1 << 6); mask++) {
if (check_mask_match(n, mask)) res++;
}
return res;
}
最多有 1000 筆測資,因此約為
\(O(1000 * 2000 * lg (1^{18})) = O(128 * 10^{6})\)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <bits/stdc++.h>
#define int long long
using namespace std;
bool check_base_match(int n, int base) {
while (n) {
if (n % base > 1) return false;
n /= base;
}
return true;
}
int transform(int base, int mask) {
__uint128_t res = 0, p = 1;
while (mask) {
res += p * (mask % 2);
if (res > 1e18) return -1;
mask /= 2;
p *= base;
}
return res;
}
bool check_mask_match(int n, int mask) {
int left = 1001, right = 1e18, mid;
while (left <= right) {
mid = left + (right - left) / 2;
int res = transform(mid, mask);
if (res == -1 || res > n) right = mid - 1;
else if (res < n) left = mid + 1;
else return true;
}
return false;
}
int zuha(int n) {
int res = 0;
for (int i = 2; i <= 1000; i++) {
if (check_base_match(n, i)) res++;
}
for (int mask = 1; mask < (1 << 6); mask++) {
if (check_mask_match(n, mask)) res++;
}
return res;
}
int32_t main() {
cin.tie(nullptr)->sync_with_stdio(false);
int t, n;
cin >> t;
while (t--) {
cin >> n;
cout << zuha(n) << '\n';
}
return 0;
}
This post is licensed under CC BY 4.0 by the author.