AtCoder Nim
題意
給定 \(N (10^{18})\), \(A (10)\), \(B (10)\), \(C (10)\),
求有多少個 \(tuple(a, b, c)\) 符合
- \(a, b, c\) 分別為 \(A, B, C\) 的倍數。
- \(a, b, c\)是正整數且不超過\(N\), \(1 <= a, b, c <= N\)
- 三數xor值為零,\(a \otimes b \otimes c = 0\)
分析
看到ABC的大小大概知道要用數位dp,
但想了20分鐘還是無從下手,感覺有做過 \(xor\) 類的,有做過倍數類的,但都很不熟😵。\
\(=>\) 題解局
這波數位dp需要知道
- Bit 數,因為要 \(xor\),因此是使用二進位而非十進位。
- 倍數 -> 用餘數為零來處理。
- 是否抵達上限 \(N\),as usual,數位dp老招。
三數都不可為零,因此需要特判製作完一個數字時,此數是否為零。
再加三維來判斷。
至於此時的三數 \(xor\) 是否為零,則是簡單的貪心。
每次製作新的 bit 時,都保證三數 \(xor\) 為零,則最後必為零。
因此製作方法有
- 全零,\(bit_a = 0, bit_b = 0, bit_c = 0\)
- AB當前位數設一,\(bit_a = 1, bit_b = 1, bit_c = 0\)
- AC當前位數設一,\(bit_a = 1, bit_b = 0, bit_c = 1\)
- BC當前位數設一,\(bit_a = 0, bit_b = 1, bit_c = 1\)
實作
在遞迴函式中,檢查製作完成時是否
- 各自是倍數
- 均為正整數
1
2
3
4
5
6
7
8
9
10
11
12
13
14
int digit_dfs(int cur_bit, int amod, int bmod, int cmod,
int azero, int bzero, int czero, int alim, int blim, int clim) {
if (cur_bit == -1) {
return !amod and !bmod and !cmod and !azero and !bzero and !czero;
}
int& res = dp[cur_bit][amod][bmod][cmod][azero][bzero][czero][alim][blim][clim];
if (res != -1) return res;
/*
recursion part...
*/
return res;
}
遞迴的 part ,我們就任填 \(01\) 給 \(a, b, c\),
- \(lim\) 表示截至目前的製作,是否都一直貼和上界。
如果是,則不可填超過當前的上界,否則可隨意填。 - \(mod\) 表示各自的餘數。當前填的數,會讓下次遞迴的餘數狀態改變。
- \(zero\) 表示當前的數是否還是全填 \(0\)。
如此便可完成此題
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int kMod = 998244353;
int dp[64][10][10][10][2][2][2][2][2][2];
int n, a, b, c;
int digit_dfs(int cur_bit, int amod, int bmod, int cmod,
int azero, int bzero, int czero, int alim, int blim, int clim) {
if (cur_bit == -1) {
return !amod and !bmod and !cmod and !azero and !bzero and !czero;
}
int& res = dp[cur_bit][amod][bmod][cmod][azero][bzero][czero][alim][blim][clim];
if (res != -1) return res;
res = 0;
int lim = (n >> cur_bit) & 1;
for (int abit = 0; abit <= 1; abit++) {
for (int bbit = 0; bbit <= 1; bbit++) {
for (int cbit = 0; cbit <= 1; cbit++) {
if ((alim and abit) > lim) continue;
if ((blim and bbit) > lim) continue;
if ((clim and cbit) > lim) continue;
if (abit ^ bbit ^ cbit) continue;
int pamod = (amod + (abit << cur_bit)) % a;
int pbmod = (bmod + (bbit << cur_bit)) % b;
int pcmod = (cmod + (cbit << cur_bit)) % c;
int pazero = azero and !abit;
int pbzero = bzero and !bbit;
int pczero = czero and !cbit;
int palim = alim and abit == lim;
int pblim = blim and bbit == lim;
int pclim = clim and cbit == lim;
res += digit_dfs(cur_bit - 1, pamod, pbmod, pcmod,
pazero, pbzero, pczero, palim, pblim, pclim);
res %= kMod;
}
}
}
return res;
}
int32_t main() {
cin.tie(nullptr)->sync_with_stdio(false);
memset(dp, -1LL, sizeof(dp));
cin >> n >> a >> b >> c;
cout << digit_dfs(63, 0, 0, 0, 1, 1, 1, 1, 1, 1) << '\n';
return 0;
}
This post is licensed under CC BY 4.0 by the author.