C++ Library for Competitive Programming
#include "emthrm/math/convolution/mod_convolution.hpp"
剰余環 $\mathbb{Z} / m\mathbb{Z}$ 上で離散フーリエ変換を高速に行うアルゴリズムである。
特に $2^x \geq n$ を満たす $x, k \in \mathbb{N}$ を用いて表される素数 $p = 2^x k + 1$ は、$p$ の原始根 $\omega$ に対して
\[\omega^{p - 1} \equiv 1 \pmod{p}\]すなわち
\[(\omega^k)^{2^x} \equiv 1 \pmod{p}\]が成り立つので、条件を満たす。
$O(N\log{N})$
template <unsigned int T>
struct NumberTheoreticTransform;
名前 | 効果・戻り値 | 備考 |
---|---|---|
NumberTheoreticTransform(); |
コンストラクタ | |
template <typename U> std::vector<ModInt> dft(const std::vector<U>& a);
|
整数列 $A$ に対して数論変換を行ったもの | |
void idft(std::vector<ModInt>* a); |
$A$ に対して数論変換の逆変換を行う。 | |
template <typename U> std::vector<ModInt> convolution(const std::vector<U>& a, const std::vector<U>& b);
|
整数列 $A$ と $B$ の畳み込み | $\max_i{C_i} \leq (\max_i{A_i})(\max_i{B_i})(\min \lbrace \lvert A \rvert, \lvert B \rvert \rbrace)$ |
名前 | 説明 |
---|---|
ModInt |
MInt<T> |
名前 | 戻り値 | 要件 | 備考 |
---|---|---|---|
template <int PRECISION = 15, unsigned int T> std::vector<MInt<T>> mod_convolution(const std::vector<MInt<T>>& a, const std::vector<MInt<T>>& b);
|
$A$ と $B$ の畳み込み | $(\text{精度}) \geq \log_2{\sqrt{m}}$ でなければならない。 |
PRECISION は精度を表す。 |
e.g. $(\text{精度}) = 15$ のとき $m \leq 2^{30} = 1073741824$
任意の法の下での畳み込み
#ifndef EMTHRM_MATH_CONVOLUTION_MOD_CONVOLUTION_HPP_
#define EMTHRM_MATH_CONVOLUTION_MOD_CONVOLUTION_HPP_
#include <algorithm>
#include <bit>
#include <cmath>
#include <vector>
#include "emthrm/math/convolution/fast_fourier_transform.hpp"
#include "emthrm/math/modint.hpp"
namespace emthrm {
// Convolves two sequences of MInt<T> with double-precision FFT, so that the
// modulus does not have to be NTT-friendly.
//
// Every residue v is split into a low part (v & (2^PRECISION - 1)) and a high
// part (v >> PRECISION). Both parts are packed into the real and imaginary
// components of one complex sequence, so one forward transform per input is
// enough. The low*low, high*high and cross products are rebuilt from two
// inverse transforms and recombined as
//   low*low + 2^PRECISION * cross + 2^(2 * PRECISION) * high*high.
//
// Returns a vector of a.size() + b.size() - 1 elements, or an empty vector
// when that size is not positive.
template <int PRECISION = 15, unsigned int T>
std::vector<MInt<T>> mod_convolution(const std::vector<MInt<T>>& a,
                                     const std::vector<MInt<T>>& b) {
  using ModInt = MInt<T>;
  using fast_fourier_transform::Complex;
  const int a_size = a.size(), b_size = b.size();
  const int c_size = a_size + b_size - 1;
  // With two empty inputs c_size is -1: converted to unsigned it would make
  // std::bit_ceil's result unrepresentable (undefined behavior) and would
  // request an enormous result vector. Nothing to compute in that case.
  if (c_size <= 0) return {};
  const int n = std::max(std::bit_ceil(static_cast<unsigned int>(c_size)), 2U);
  constexpr int mask = (1 << PRECISION) - 1;
  // Packs one residue as (low PRECISION bits) + i * (remaining high bits).
  const auto split = [mask](const ModInt& value) -> Complex {
    return Complex(value.v & mask, value.v >> PRECISION);
  };
  std::vector<Complex> x(n), y(n);
  std::transform(a.begin(), a.end(), x.begin(), split);
  fast_fourier_transform::dft(&x);
  std::transform(b.begin(), b.end(), y.begin(), split);
  fast_fourier_transform::dft(&y);
  const int half = n >> 1;
  // Bins 0 and n/2 of a real sequence's transform are real, so here the
  // real/imaginary components already are the low/high spectra.
  Complex tmp_a = x.front(), tmp_b = y.front();
  x.front() = Complex(tmp_a.re * tmp_b.re, tmp_a.im * tmp_b.im);
  y.front() = Complex(tmp_a.re * tmp_b.im + tmp_a.im * tmp_b.re, 0);
  for (int i = 1; i < half; ++i) {
    const int j = n - i;
    // Separate the packed spectra: *_l_* is the transform of the low parts,
    // *_h_* the transform of the high parts, at bins i and j = n - i.
    const Complex a_l_i = (x[i] + x[j].conj()).mul_real(0.5);
    const Complex a_h_i = (x[j].conj() - x[i]).mul_pin(0.5);
    const Complex b_l_i = (y[i] + y[j].conj()).mul_real(0.5);
    const Complex b_h_i = (y[j].conj() - y[i]).mul_pin(0.5);
    const Complex a_l_j = (x[j] + x[i].conj()).mul_real(0.5);
    const Complex a_h_j = (x[i].conj() - x[j]).mul_pin(0.5);
    const Complex b_l_j = (y[j] + y[i].conj()).mul_real(0.5);
    const Complex b_h_j = (y[i].conj() - y[j]).mul_pin(0.5);
    // x carries low*low (real) and high*high (imaginary); y carries the
    // cross terms low*high + high*low.
    x[i] = a_l_i * b_l_i + (a_h_i * b_h_i).mul_pin(1);
    y[i] = a_l_i * b_h_i + a_h_i * b_l_i;
    x[j] = a_l_j * b_l_j + (a_h_j * b_h_j).mul_pin(1);
    y[j] = a_l_j * b_h_j + a_h_j * b_l_j;
  }
  tmp_a = x[half];
  tmp_b = y[half];
  x[half] = Complex(tmp_a.re * tmp_b.re, tmp_a.im * tmp_b.im);
  y[half] = Complex(tmp_a.re * tmp_b.im + tmp_a.im * tmp_b.re, 0);
  fast_fourier_transform::idft(&x);
  fast_fourier_transform::idft(&y);
  std::vector<ModInt> res(c_size);
  const ModInt tmp1 = 1 << PRECISION, tmp2 = 1LL << (PRECISION << 1);
  for (int i = 0; i < c_size; ++i) {
    // Round each real-valued partial product to the nearest integer and
    // recombine the three parts modulo T.
    res[i] = tmp1 * std::llround(y[i].re) + tmp2 * std::llround(x[i].im)
             + std::llround(x[i].re);
  }
  return res;
}
} // namespace emthrm
#endif // EMTHRM_MATH_CONVOLUTION_MOD_CONVOLUTION_HPP_
#line 1 "include/emthrm/math/convolution/mod_convolution.hpp"
#include <algorithm>
#include <bit>
#include <cmath>
#include <vector>
#line 1 "include/emthrm/math/convolution/fast_fourier_transform.hpp"
#line 6 "include/emthrm/math/convolution/fast_fourier_transform.hpp"
#include <cassert>
#line 8 "include/emthrm/math/convolution/fast_fourier_transform.hpp"
#include <iterator>
#include <utility>
#line 11 "include/emthrm/math/convolution/fast_fourier_transform.hpp"
namespace emthrm {
namespace fast_fourier_transform {
using Real = double;

// Minimal complex number with public real (`re`) and imaginary (`im`) parts,
// offering only the arithmetic the FFT routines use.
struct Complex {
  Real re, im;

  // Builds re + i * im; both parts default to zero.
  explicit Complex(const Real real_part = 0, const Real imag_part = 0)
      : re(real_part), im(imag_part) {}

  // Component-wise sum.
  Complex operator+(const Complex& rhs) const {
    return Complex(re + rhs.re, im + rhs.im);
  }

  // Component-wise difference.
  Complex operator-(const Complex& rhs) const {
    return Complex(re - rhs.re, im - rhs.im);
  }

  // Complex product.
  Complex operator*(const Complex& rhs) const {
    return Complex(re * rhs.re - im * rhs.im, re * rhs.im + im * rhs.re);
  }

  // Product with the real number r.
  Complex mul_real(const Real r) const { return Complex(re * r, im * r); }

  // Product with the purely imaginary number i * r.
  Complex mul_pin(const Real r) const { return Complex(-im * r, re * r); }

  // Complex conjugate.
  Complex conj() const { return Complex(re, -im); }
};
// Shared lookup tables, grown on demand by init().
// They are `inline` because they are defined in a header: without it every
// translation unit including the header would emit its own definition.
//
// butterfly[i]: bit-reversal permutation index of i for the current table
// length (butterfly.size()).
inline std::vector<int> butterfly{0};
// zeta[d][j]: twiddle factor exp(-2*pi*i * j / 2^(d + 1)) used by the
// butterfly stage whose block length is 2^d.
inline std::vector<std::vector<Complex>> zeta{{Complex(1, 0)}};
// Grows the shared tables `butterfly` and `zeta` so that a transform of
// length n can run. Does nothing when the tables already cover n.
//
// n is expected to be a power of two: its bit length is taken with
// std::countr_zero. dft() asserts this before calling.
//
// Declared inline because the definition lives in a header.
inline void init(const int n) {
  const int prev_n = butterfly.size();
  if (n <= prev_n) return;
  butterfly.resize(n);
  const int prev_lg = zeta.size();
  const int lg = std::countr_zero(static_cast<unsigned int>(n));
  // Existing entries are bit reversals over fewer bits; shifting them left
  // turns them into reversals over lg bits.
  for (int i = 1; i < prev_n; ++i) {
    butterfly[i] <<= lg - prev_lg;
  }
  // New entries follow from the reversal of i / 2 plus the lowest bit of i
  // moved to the top position.
  for (int i = prev_n; i < n; ++i) {
    butterfly[i] = (butterfly[i >> 1] >> 1) | ((i & 1) << (lg - 1));
  }
  zeta.resize(lg);
  for (int i = prev_lg; i < lg; ++i) {
    zeta[i].resize(1 << i);
    const Real angle = -3.14159265358979323846 * 2 / (1 << (i + 1));
    for (int j = 0; j < (1 << (i - 1)); ++j) {
      // Even indices reuse the previous level; only odd ones need cos/sin.
      zeta[i][j << 1] = zeta[i - 1][j];
      const Real theta = angle * ((j << 1) + 1);
      zeta[i][(j << 1) + 1] = Complex(std::cos(theta), std::sin(theta));
    }
  }
}
// In-place forward discrete Fourier transform of *a.
//
// a->size() must be a power of two (asserted). The shared tables are grown
// through init() when needed.
//
// Declared inline because the definition lives in a header.
inline void dft(std::vector<Complex>* a) {
  assert(std::has_single_bit(a->size()));
  const int n = a->size();
  init(n);
  // The table may have been built for a longer transform; dropping the extra
  // low bits yields the bit reversal for length n.
  const int shift =
      std::countr_zero(butterfly.size()) - std::countr_zero(a->size());
  for (int i = 0; i < n; ++i) {
    const int j = butterfly[i] >> shift;
    // Swap each pair once.
    if (i < j) std::swap((*a)[i], (*a)[j]);
  }
  // Butterfly stages with block lengths 1, 2, 4, ...; `den` indexes the
  // matching twiddle table.
  for (int block = 1, den = 0; block < n; block <<= 1, ++den) {
    for (int i = 0; i < n; i += (block << 1)) {
      for (int j = 0; j < block; ++j) {
        const Complex tmp = (*a)[i + j + block] * zeta[den][j];
        (*a)[i + j + block] = (*a)[i + j] - tmp;
        (*a)[i + j] = (*a)[i + j] + tmp;
      }
    }
  }
}
template <typename T>
std::vector<Complex> real_dft(const std::vector<T>& a) {
const int n = a.size();
std::vector<Complex> c(std::bit_ceil(a.size()));
for (int i = 0; i < n; ++i) {
c[i].re = a[i];
}
dft(&c);
return c;
}
// In-place inverse discrete Fourier transform of *a.
//
// Implemented as a forward dft(), followed by reversing elements
// 1 .. n - 1 and scaling every element by 1 / n. The size requirement of
// dft() applies.
//
// Declared inline because the definition lives in a header.
inline void idft(std::vector<Complex>* a) {
  const int n = a->size();
  dft(a);
  std::reverse(std::next(a->begin()), a->end());
  const Real r = 1. / n;
  std::transform(a->begin(), a->end(), a->begin(),
                 [r](const Complex& c) -> Complex { return c.mul_real(r); });
}
// Convolution of two real-valued sequences using one forward transform of
// length n and one inverse transform of length n / 2.
//
// `a` is stored in the real parts and `b` in the imaginary parts of a single
// complex buffer. After the forward transform the product spectrum is
// derived from it and folded into n / 2 complex values, whose inverse
// transform holds the even-indexed results in its real parts and the
// odd-indexed results in its imaginary parts.
//
// Returns a.size() + b.size() - 1 values, or an empty vector when that size
// is not positive. Results are raw floating-point values (not rounded).
template <typename T>
std::vector<Real> convolution(const std::vector<T>& a,
                              const std::vector<T>& b) {
  const int a_size = a.size(), b_size = b.size();
  const int c_size = a_size + b_size - 1;
  // With two empty inputs c_size is -1: converted to unsigned it would make
  // std::bit_ceil's result unrepresentable (undefined behavior) and would
  // request an enormous result vector. Nothing to compute in that case.
  if (c_size <= 0) return {};
  const int n = std::max(std::bit_ceil(static_cast<unsigned int>(c_size)), 2U);
  const int hlf = n >> 1, qtr = hlf >> 1;
  std::vector<Complex> c(n);
  for (int i = 0; i < a_size; ++i) {
    c[i].re = a[i];
  }
  for (int i = 0; i < b_size; ++i) {
    c[i].im = b[i];
  }
  dft(&c);
  // Bins 0 and n / 2: both spectra are real there, so the product is simply
  // re * im.
  c.front() = Complex(c.front().re * c.front().im, 0);
  for (int i = 1; i < hlf; ++i) {
    // Product spectrum at bins i and n - i, obtained from the squares of the
    // packed spectrum.
    const Complex i_square = c[i] * c[i], j_square = c[n - i] * c[n - i];
    c[i] = (j_square.conj() - i_square).mul_pin(0.25);
    c[n - i] = (i_square.conj() - j_square).mul_pin(0.25);
  }
  c[hlf] = Complex(c[hlf].re * c[hlf].im, 0);
  // Fold the length-n product spectrum into hlf complex values.
  c.front() = (c.front() + c[hlf]
               + (c.front() - c[hlf]).mul_pin(1)).mul_real(0.5);
  const int den = std::countr_zero(static_cast<unsigned int>(hlf));
  for (int i = 1; i < qtr; ++i) {
    const int j = hlf - i;
    const Complex tmp1 = c[i] + c[j].conj();
    const Complex tmp2 = ((c[i] - c[j].conj()) * zeta[den][j]).mul_pin(1);
    c[i] = (tmp1 - tmp2).mul_real(0.5);
    c[j] = (tmp1 + tmp2).mul_real(0.5).conj();
  }
  if (qtr > 0) c[qtr] = c[qtr].conj();
  c.resize(hlf);
  idft(&c);
  std::vector<Real> res(c_size);
  // Even result indices come from the real parts, odd ones from the
  // imaginary parts.
  for (int i = 0; i < c_size; i += 2) {
    res[i] = c[i >> 1].re;
  }
  for (int i = 1; i < c_size; i += 2) {
    res[i] = c[i >> 1].im;
  }
  return res;
}
} // namespace fast_fourier_transform
} // namespace emthrm
#line 1 "include/emthrm/math/modint.hpp"
#ifndef ARBITRARY_MODINT
#line 6 "include/emthrm/math/modint.hpp"
#endif
#include <compare>
#include <iostream>
// #include <numeric>
#line 12 "include/emthrm/math/modint.hpp"
namespace emthrm {
#ifndef ARBITRARY_MODINT
// Integer modulo the compile-time constant M.
//
// The residue is stored in `v`. Every constructor and operator keeps it in
// [0, M); only raw() stores its argument unchecked. operator+= and
// operator-= form sums up to 2M - 1 in an unsigned int, so they need
// M <= 2^31 to avoid wrap-around.
template <unsigned int M>
struct MInt {
  unsigned int v;

  constexpr MInt() : v(0) {}

  // Reduces x into [0, M). The remainder of a negative x lies in (-M, 0], so
  // M is added only when it is strictly negative; a negative multiple of M
  // therefore maps to 0 rather than to M.
  constexpr MInt(const long long x) : v(0) {
    const long long r = x % static_cast<long long>(M);
    v = static_cast<unsigned int>(r < 0 ? r + static_cast<long long>(M) : r);
  }

  // Wraps x without any reduction; the caller guarantees 0 <= x < M.
  static constexpr MInt raw(const int x) {
    MInt x_;
    x_.v = x;
    return x_;
  }

  static constexpr int get_mod() { return M; }

  // The modulus is fixed at compile time; this only checks that the
  // requested value matches it.
  static constexpr void set_mod(const int divisor) {
    assert(std::cmp_equal(divisor, M));
  }

  // Precomputes inverses, factorials and inverse factorials up to x.
  static void init(const int x) {
    inv<true>(x);
    fact(x);
    fact_inv(x);
  }

  // Modular inverse of n.
  //
  // With MEMOIZES the inverses of all values up to n are tabulated through
  // the recurrence inverse[i] = -inverse[M % i] * (M / i); this requires n!
  // and M to be coprime. Otherwise the inverse is computed with the extended
  // Euclidean algorithm. Each instantiation (MEMOIZES true / false) owns its
  // own `inverse` table.
  template <bool MEMOIZES = false>
  static MInt inv(const int n) {
    // assert(0 <= n && n < M && std::gcd(n, M) == 1);
    static std::vector<MInt> inverse{0, 1};
    const int prev = inverse.size();
    if (n < prev) return inverse[n];
    if constexpr (MEMOIZES) {
      // "n!" and "M" must be coprime.
      inverse.resize(n + 1);
      for (int i = prev; i <= n; ++i) {
        inverse[i] = -inverse[M % i] * raw(M / i);
      }
      return inverse[n];
    }
    int u = 1, v = 0;
    for (unsigned int a = n, b = M; b;) {
      const unsigned int q = a / b;
      a -= q * b;
      std::swap(a, b);
      u -= q * v;
      std::swap(u, v);
    }
    return u;
  }

  // n!, memoized in a table that grows on demand.
  static MInt fact(const int n) {
    static std::vector<MInt> factorial{1};
    if (const int prev = factorial.size(); n >= prev) {
      factorial.resize(n + 1);
      for (int i = prev; i <= n; ++i) {
        factorial[i] = factorial[i - 1] * i;
      }
    }
    return factorial[n];
  }

  // 1 / n!, memoized. A new range is filled from a single inversion of n!
  // and then extended downwards by multiplication.
  static MInt fact_inv(const int n) {
    static std::vector<MInt> f_inv{1};
    if (const int prev = f_inv.size(); n >= prev) {
      f_inv.resize(n + 1);
      f_inv[n] = inv(fact(n).v);
      for (int i = n; i > prev; --i) {
        f_inv[i - 1] = f_inv[i] * i;
      }
    }
    return f_inv[n];
  }

  // Binomial coefficient C(n, k); 0 for arguments outside 0 <= k <= n.
  static MInt nCk(const int n, const int k) {
    if (n < 0 || n < k || k < 0) [[unlikely]] return MInt();
    // Request the larger inverse factorial first.
    return fact(n) * (n - k < k ? fact_inv(k) * fact_inv(n - k) :
                                  fact_inv(n - k) * fact_inv(k));
  }

  // Number of k-permutations of n; 0 for arguments outside 0 <= k <= n.
  static MInt nPk(const int n, const int k) {
    return n < 0 || n < k || k < 0 ? MInt() : fact(n) * fact_inv(n - k);
  }

  // Number of k-multisets over n kinds, C(n + k - 1, k); 1 when k == 0.
  static MInt nHk(const int n, const int k) {
    return n < 0 || k < 0 ? MInt() : (k == 0 ? 1 : nCk(n + k - 1, k));
  }

  // C(n, k) for large n, as the product of (n - i + 1) / i for i = 1 .. k.
  static MInt large_nCk(long long n, const int k) {
    if (n < 0 || n < k || k < 0) [[unlikely]] return MInt();
    inv<true>(k);
    MInt res = 1;
    for (int i = 1; i <= k; ++i) {
      res *= inv(i) * n--;
    }
    return res;
  }

  // Exponentiation by squaring; a non-positive exponent yields 1.
  constexpr MInt pow(long long exponent) const {
    MInt res = 1, tmp = *this;
    for (; exponent > 0; exponent >>= 1) {
      if (exponent & 1) res *= tmp;
      tmp *= tmp;
    }
    return res;
  }

  constexpr MInt& operator+=(const MInt& x) {
    if ((v += x.v) >= M) v -= M;
    return *this;
  }
  constexpr MInt& operator-=(const MInt& x) {
    if ((v += M - x.v) >= M) v -= M;
    return *this;
  }
  constexpr MInt& operator*=(const MInt& x) {
    // Widen before multiplying so the product cannot overflow.
    v = static_cast<unsigned long long>(v) * x.v % M;
    return *this;
  }
  MInt& operator/=(const MInt& x) { return *this *= inv(x.v); }

  constexpr auto operator<=>(const MInt& x) const = default;

  constexpr MInt& operator++() {
    if (++v == M) [[unlikely]] v = 0;
    return *this;
  }
  constexpr MInt operator++(int) {
    const MInt res = *this;
    ++*this;
    return res;
  }
  constexpr MInt& operator--() {
    v = (v == 0 ? M - 1 : v - 1);
    return *this;
  }
  constexpr MInt operator--(int) {
    const MInt res = *this;
    --*this;
    return res;
  }

  constexpr MInt operator+() const { return *this; }
  constexpr MInt operator-() const { return raw(v ? M - v : 0); }

  constexpr MInt operator+(const MInt& x) const { return MInt(*this) += x; }
  constexpr MInt operator-(const MInt& x) const { return MInt(*this) -= x; }
  constexpr MInt operator*(const MInt& x) const { return MInt(*this) *= x; }
  MInt operator/(const MInt& x) const { return MInt(*this) /= x; }

  friend std::ostream& operator<<(std::ostream& os, const MInt& x) {
    return os << x.v;
  }
  // Reads a (possibly negative) integer and reduces it modulo M.
  friend std::istream& operator>>(std::istream& is, MInt& x) {
    long long v;
    is >> v;
    x = MInt(v);
    return is;
  }
};
#else // ARBITRARY_MODINT
// Integer modulo a modulus chosen at run time with set_mod(). ID only
// distinguishes independent instantiations, each with its own modulus.
//
// The residue is stored in `v`. Every constructor and operator keeps it in
// [0, mod()); only raw() stores its argument unchecked. The modulus starts
// at 0, so set_mod() must be called before any value is reduced.
//
// NOTE: the memo tables inside inv(), fact() and fact_inv() are function-local
// statics and are not reset by set_mod().
template <int ID>
struct MInt {
  unsigned int v;

  constexpr MInt() : v(0) {}

  // Reduces x into [0, mod()). The remainder of a negative x lies in
  // (-mod(), 0], so the modulus is added only when it is strictly negative; a
  // negative multiple of the modulus therefore maps to 0 rather than to
  // mod().
  MInt(const long long x) : v(0) {
    const long long m = mod();
    const long long r = x % m;
    v = static_cast<unsigned int>(r < 0 ? r + m : r);
  }

  // Wraps x without any reduction; the caller guarantees 0 <= x < mod().
  static constexpr MInt raw(const int x) {
    MInt x_;
    x_.v = x;
    return x_;
  }

  static int get_mod() { return mod(); }
  static void set_mod(const unsigned int divisor) { mod() = divisor; }

  // Precomputes inverses, factorials and inverse factorials up to x.
  static void init(const int x) {
    inv<true>(x);
    fact(x);
    fact_inv(x);
  }

  // Modular inverse of n.
  //
  // With MEMOIZES the inverses of all values up to n are tabulated through
  // the recurrence inverse[i] = -inverse[mod() % i] * (mod() / i); this
  // requires n! and the modulus to be coprime. Otherwise the inverse is
  // computed with the extended Euclidean algorithm against the current
  // modulus. Each instantiation (MEMOIZES true / false) owns its own
  // `inverse` table.
  template <bool MEMOIZES = false>
  static MInt inv(const int n) {
    // assert(0 <= n && n < mod() && std::gcd(x, mod()) == 1);
    static std::vector<MInt> inverse{0, 1};
    const int prev = inverse.size();
    if (n < prev) return inverse[n];
    if constexpr (MEMOIZES) {
      // "n!" and "M" must be coprime.
      inverse.resize(n + 1);
      for (int i = prev; i <= n; ++i) {
        inverse[i] = -inverse[mod() % i] * raw(mod() / i);
      }
      return inverse[n];
    }
    int u = 1, v = 0;
    for (unsigned int a = n, b = mod(); b;) {
      const unsigned int q = a / b;
      a -= q * b;
      std::swap(a, b);
      u -= q * v;
      std::swap(u, v);
    }
    return u;
  }

  // n!, memoized in a table that grows on demand.
  static MInt fact(const int n) {
    static std::vector<MInt> factorial{1};
    if (const int prev = factorial.size(); n >= prev) {
      factorial.resize(n + 1);
      for (int i = prev; i <= n; ++i) {
        factorial[i] = factorial[i - 1] * i;
      }
    }
    return factorial[n];
  }

  // 1 / n!, memoized. A new range is filled from a single inversion of n!
  // and then extended downwards by multiplication.
  static MInt fact_inv(const int n) {
    static std::vector<MInt> f_inv{1};
    if (const int prev = f_inv.size(); n >= prev) {
      f_inv.resize(n + 1);
      f_inv[n] = inv(fact(n).v);
      for (int i = n; i > prev; --i) {
        f_inv[i - 1] = f_inv[i] * i;
      }
    }
    return f_inv[n];
  }

  // Binomial coefficient C(n, k); 0 for arguments outside 0 <= k <= n.
  static MInt nCk(const int n, const int k) {
    if (n < 0 || n < k || k < 0) [[unlikely]] return MInt();
    // Request the larger inverse factorial first.
    return fact(n) * (n - k < k ? fact_inv(k) * fact_inv(n - k) :
                                  fact_inv(n - k) * fact_inv(k));
  }

  // Number of k-permutations of n; 0 for arguments outside 0 <= k <= n.
  static MInt nPk(const int n, const int k) {
    return n < 0 || n < k || k < 0 ? MInt() : fact(n) * fact_inv(n - k);
  }

  // Number of k-multisets over n kinds, C(n + k - 1, k); 1 when k == 0.
  static MInt nHk(const int n, const int k) {
    return n < 0 || k < 0 ? MInt() : (k == 0 ? 1 : nCk(n + k - 1, k));
  }

  // C(n, k) for large n, as the product of (n - i + 1) / i for i = 1 .. k.
  static MInt large_nCk(long long n, const int k) {
    if (n < 0 || n < k || k < 0) [[unlikely]] return MInt();
    inv<true>(k);
    MInt res = 1;
    for (int i = 1; i <= k; ++i) {
      res *= inv(i) * n--;
    }
    return res;
  }

  // Exponentiation by squaring; a non-positive exponent yields 1.
  MInt pow(long long exponent) const {
    MInt res = 1, tmp = *this;
    for (; exponent > 0; exponent >>= 1) {
      if (exponent & 1) res *= tmp;
      tmp *= tmp;
    }
    return res;
  }

  MInt& operator+=(const MInt& x) {
    if ((v += x.v) >= mod()) v -= mod();
    return *this;
  }
  MInt& operator-=(const MInt& x) {
    if ((v += mod() - x.v) >= mod()) v -= mod();
    return *this;
  }
  MInt& operator*=(const MInt& x) {
    // Widen before multiplying so the product cannot overflow.
    v = static_cast<unsigned long long>(v) * x.v % mod();
    return *this;
  }
  MInt& operator/=(const MInt& x) { return *this *= inv(x.v); }

  auto operator<=>(const MInt& x) const = default;

  MInt& operator++() {
    if (++v == mod()) [[unlikely]] v = 0;
    return *this;
  }
  MInt operator++(int) {
    const MInt res = *this;
    ++*this;
    return res;
  }
  MInt& operator--() {
    v = (v == 0 ? mod() - 1 : v - 1);
    return *this;
  }
  MInt operator--(int) {
    const MInt res = *this;
    --*this;
    return res;
  }

  MInt operator+() const { return *this; }
  MInt operator-() const { return raw(v ? mod() - v : 0); }

  MInt operator+(const MInt& x) const { return MInt(*this) += x; }
  MInt operator-(const MInt& x) const { return MInt(*this) -= x; }
  MInt operator*(const MInt& x) const { return MInt(*this) *= x; }
  MInt operator/(const MInt& x) const { return MInt(*this) /= x; }

  friend std::ostream& operator<<(std::ostream& os, const MInt& x) {
    return os << x.v;
  }
  // Reads a (possibly negative) integer and reduces it modulo mod().
  friend std::istream& operator>>(std::istream& is, MInt& x) {
    long long v;
    is >> v;
    x = MInt(v);
    return is;
  }

 private:
  // Storage for the run-time modulus of this ID; 0 until set_mod() is called.
  static unsigned int& mod() {
    static unsigned int divisor = 0;
    return divisor;
  }
};
#endif // ARBITRARY_MODINT
} // namespace emthrm
#line 11 "include/emthrm/math/convolution/mod_convolution.hpp"
namespace emthrm {
// Convolves two sequences of MInt<T> with double-precision FFT, so that the
// modulus does not have to be NTT-friendly.
//
// Every residue v is split into a low part (v & (2^PRECISION - 1)) and a high
// part (v >> PRECISION). Both parts are packed into the real and imaginary
// components of one complex sequence, so one forward transform per input is
// enough. The low*low, high*high and cross products are rebuilt from two
// inverse transforms and recombined as
//   low*low + 2^PRECISION * cross + 2^(2 * PRECISION) * high*high.
//
// Returns a vector of a.size() + b.size() - 1 elements, or an empty vector
// when that size is not positive.
template <int PRECISION = 15, unsigned int T>
std::vector<MInt<T>> mod_convolution(const std::vector<MInt<T>>& a,
                                     const std::vector<MInt<T>>& b) {
  using ModInt = MInt<T>;
  using fast_fourier_transform::Complex;
  const int a_size = a.size(), b_size = b.size();
  const int c_size = a_size + b_size - 1;
  // With two empty inputs c_size is -1: converted to unsigned it would make
  // std::bit_ceil's result unrepresentable (undefined behavior) and would
  // request an enormous result vector. Nothing to compute in that case.
  if (c_size <= 0) return {};
  const int n = std::max(std::bit_ceil(static_cast<unsigned int>(c_size)), 2U);
  constexpr int mask = (1 << PRECISION) - 1;
  // Packs one residue as (low PRECISION bits) + i * (remaining high bits).
  const auto split = [mask](const ModInt& value) -> Complex {
    return Complex(value.v & mask, value.v >> PRECISION);
  };
  std::vector<Complex> x(n), y(n);
  std::transform(a.begin(), a.end(), x.begin(), split);
  fast_fourier_transform::dft(&x);
  std::transform(b.begin(), b.end(), y.begin(), split);
  fast_fourier_transform::dft(&y);
  const int half = n >> 1;
  // Bins 0 and n/2 of a real sequence's transform are real, so here the
  // real/imaginary components already are the low/high spectra.
  Complex tmp_a = x.front(), tmp_b = y.front();
  x.front() = Complex(tmp_a.re * tmp_b.re, tmp_a.im * tmp_b.im);
  y.front() = Complex(tmp_a.re * tmp_b.im + tmp_a.im * tmp_b.re, 0);
  for (int i = 1; i < half; ++i) {
    const int j = n - i;
    // Separate the packed spectra: *_l_* is the transform of the low parts,
    // *_h_* the transform of the high parts, at bins i and j = n - i.
    const Complex a_l_i = (x[i] + x[j].conj()).mul_real(0.5);
    const Complex a_h_i = (x[j].conj() - x[i]).mul_pin(0.5);
    const Complex b_l_i = (y[i] + y[j].conj()).mul_real(0.5);
    const Complex b_h_i = (y[j].conj() - y[i]).mul_pin(0.5);
    const Complex a_l_j = (x[j] + x[i].conj()).mul_real(0.5);
    const Complex a_h_j = (x[i].conj() - x[j]).mul_pin(0.5);
    const Complex b_l_j = (y[j] + y[i].conj()).mul_real(0.5);
    const Complex b_h_j = (y[i].conj() - y[j]).mul_pin(0.5);
    // x carries low*low (real) and high*high (imaginary); y carries the
    // cross terms low*high + high*low.
    x[i] = a_l_i * b_l_i + (a_h_i * b_h_i).mul_pin(1);
    y[i] = a_l_i * b_h_i + a_h_i * b_l_i;
    x[j] = a_l_j * b_l_j + (a_h_j * b_h_j).mul_pin(1);
    y[j] = a_l_j * b_h_j + a_h_j * b_l_j;
  }
  tmp_a = x[half];
  tmp_b = y[half];
  x[half] = Complex(tmp_a.re * tmp_b.re, tmp_a.im * tmp_b.im);
  y[half] = Complex(tmp_a.re * tmp_b.im + tmp_a.im * tmp_b.re, 0);
  fast_fourier_transform::idft(&x);
  fast_fourier_transform::idft(&y);
  std::vector<ModInt> res(c_size);
  const ModInt tmp1 = 1 << PRECISION, tmp2 = 1LL << (PRECISION << 1);
  for (int i = 0; i < c_size; ++i) {
    // Round each real-valued partial product to the nearest integer and
    // recombine the three parts modulo T.
    res[i] = tmp1 * std::llround(y[i].re) + tmp2 * std::llround(x[i].im)
             + std::llround(x[i].re);
  }
  return res;
}
} // namespace emthrm