Skip to content

Commit 65e66e7

Browse files
committed
xor convolution + test
1 parent b6d899b commit 65e66e7

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

cp-algo/math/xor_convolution.hpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#ifndef CP_ALGO_MATH_XOR_CONVOLUTION_HPP
2+
#define CP_ALGO_MATH_XOR_CONVOLUTION_HPP
3+
#include "../number_theory/modint.hpp"
4+
#include "../util/bit.hpp"
5+
#include "../util/checkpoint.hpp"
6+
#include <cassert>
7+
#include <algorithm>
8+
#include <vector>
9+
10+
namespace cp_algo::math {
11+
// Recursive FWHT (XOR) transform for size N (power of two)
12+
template<auto N>
13+
void xor_transform(auto &&a) {
14+
if constexpr (N == 1) {
15+
return;
16+
} else {
17+
constexpr auto half = N / 2;
18+
xor_transform<half>(&a[0]);
19+
xor_transform<half>(&a[half]);
20+
for (uint32_t i = 0; i < half; i++) {
21+
auto x = a[i] + a[i + half];
22+
auto y = a[i] - a[i + half];
23+
a[i] = x;
24+
a[i + half] = y;
25+
}
26+
}
27+
}
28+
29+
// FWHT wrapper that deduces N at compile time via with_bit_floor
30+
inline void xor_transform(auto &&a, auto n) {
31+
with_bit_floor(n, [&]<auto NN>() {
32+
assert(NN == n);
33+
xor_transform<NN>(a);
34+
});
35+
}
36+
37+
inline void xor_transform(auto &&a) {
38+
xor_transform(a, std::size(a));
39+
}
40+
41+
// In-place XOR convolution on sequences of equal length (power of two)
42+
void xor_convolution_inplace(auto &a, auto &b) {
43+
auto N = static_cast<uint32_t>(std::size(a));
44+
xor_transform(a);
45+
xor_transform(b);
46+
checkpoint("transform");
47+
for (uint32_t i = 0; i < N; i++) {
48+
a[i] *= b[i];
49+
}
50+
checkpoint("dot");
51+
xor_transform(a);
52+
checkpoint("transform");
53+
using base = std::decay_t<decltype(a[0])>;
54+
base ni = base(N).inv();
55+
for (auto &it : a) {
56+
it *= ni;
57+
}
58+
checkpoint("mul_inv");
59+
}
60+
61+
// Returns XOR convolution of a and b; pads to next power of two
62+
auto xor_convolution(auto a, auto b) {
63+
auto n = std::bit_ceil(std::max(std::size(a), std::size(b)));
64+
a.resize(n);
65+
b.resize(n);
66+
xor_convolution_inplace(a, b);
67+
return a;
68+
}
69+
}
70+
#endif // CP_ALGO_MATH_XOR_CONVOLUTION_HPP
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// @brief Bitwise Xor Convolution
2+
#define PROBLEM "https://judge.yosupo.jp/problem/bitwise_xor_convolution"
3+
#pragma GCC optimize("O3,unroll-loops")
4+
#include <bits/allocator.h>
5+
#pragma GCC target("avx2")
6+
#include <iostream>
7+
#include "blazingio/blazingio.min.hpp"
8+
#define CP_ALGO_CHECKPOINT
9+
#include "cp-algo/number_theory/modint.hpp"
10+
#include "cp-algo/util/big_alloc.hpp"
11+
#include "cp-algo/util/checkpoint.hpp"
12+
#include "cp-algo/math/xor_convolution.hpp"
13+
#include <bits/stdc++.h>
14+
15+
using namespace std;
16+
17+
const int mod = 998244353;
18+
using base = cp_algo::math::modint<mod>;
19+
20+
void solve() {
21+
uint32_t n;
22+
cin >> n;
23+
uint32_t N = 1u << n;
24+
cp_algo::big_vector<base> a(N), b(N);
25+
for (auto &it : a) {cin >> it;}
26+
for (auto &it : b) {cin >> it;}
27+
cp_algo::checkpoint("read");
28+
cp_algo::math::xor_convolution_inplace(a, b);
29+
for (auto it : a) {cout << it << ' ';}
30+
cp_algo::checkpoint("write");
31+
cp_algo::checkpoint<1>();
32+
}
33+
34+
signed main() {
35+
ios::sync_with_stdio(0);
36+
cin.tie(0);
37+
int t = 1;
38+
while (t--) {
39+
solve();
40+
}
41+
}

0 commit comments

Comments
 (0)