多项式完全模板

  • 多项式 $+$, $-$

  • 多项式乘法

  • 多项式求逆

  • 多项式开根

  • 多项式求幂

  • 多项式 $\ln$

  • 多项式 $\exp$

  • 多项式整除/取模

  • 多项式立方根(upd@11.28),支持任意常数项(平方根同)

    常数项二次/三次剩余 BSGS 求解复杂度 $O(\sqrt{p})$

#include <bits/stdc++.h>
// Inclusive ascending loop: i takes the values a, a+1, ..., b.
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
// Inclusive descending loop: i takes the values a, a-1, ..., b.
#define per(i, a, b) for (int i = (a); i >= (b); i--)
// Debug helper: prints "<expression text> : <value>" to cout.
#define D(x) cout << #x << " : " << x << endl
using namespace std;

// Size bound used to dimension the static work arrays in Poly::ntt.
constexpr int N = 1e5 + 10;
// Modulus for all coefficient arithmetic in this file.
constexpr int mod = 998244353;
using ll = long long;

// Modular exponentiation by repeated squaring: returns x^y modulo `mod`.
//   x: base; any int is accepted and is first reduced into [0, mod).
//   y: exponent; must be non-negative.
// Returns a value in [0, mod); x^0 is 1.
int qpow(int x, int y) {
    assert(y >= 0);  // with an arithmetic right shift a negative exponent never reaches 0
    x %= mod;
    if (x < 0) x += mod;  // keep the base, and therefore the result, non-negative
    int ret = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) ret = ll(ret) * x % mod;
        x = ll(x) * x % mod;
    }
    return ret;
}

namespace Poly {

// Half-open ascending loop: i takes the values a, a+1, ..., b-1.
#define Rep(i, a, b) for (int i = (a); i < (b); i++)
// A polynomial is stored as its coefficient vector; index i holds the coefficient of x^i.
using poly = vector<int>;

// Modular addition: for x, y in [0, mod) yields x + y reduced back into [0, mod).
constexpr int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
// Modular subtraction: for x, y in [0, mod) yields x - y lifted back into [0, mod).
constexpr int sub(int x, int y) { return x - y >= 0 ? x - y : x - y + mod; }

void print(const poly& x) {
    for (int i : x) printf("%d ", i);
    printf("\n");
}

// poly +
// Coefficient-wise addition modulo `mod`. x grows to y's length if it is shorter.
// Coefficients of both operands are expected to lie in [0, mod).
poly& operator+=(poly& x, const poly& y) {
    if (x.size() < y.size()) x.resize(y.size());
    for (size_t i = 0; i < y.size(); ++i) x[i] = add(x[i], y[i]);
    return x;
}
// Sum of two polynomials; the result has the length of the longer operand.
poly operator+(poly x, const poly& y) {
    x += y;
    return x;
}
// Adds the scalar y to the constant term. Any int is accepted for y (it is
// reduced into [0, mod) first), and an empty x is treated as the zero
// polynomial, so the result then has exactly one coefficient.
poly& operator+=(poly& x, int y) {
    if (x.empty()) x.resize(1);  // x[0] must exist before it is written
    y %= mod;
    if (y < 0) y += mod;  // add() requires both arguments in [0, mod)
    x[0] = add(x[0], y);
    return x;
}
// polynomial + scalar.
poly operator+(poly x, int y) {
    x += y;
    return x;
}
// scalar + polynomial.
poly operator+(int x, poly y) {
    y += x;
    return y;
}

// poly -
// Coefficient-wise subtraction modulo `mod`. x grows to y's length if it is shorter.
// Coefficients of both operands are expected to lie in [0, mod).
poly& operator-=(poly& x, const poly& y) {
    if (x.size() < y.size()) x.resize(y.size());
    for (size_t i = 0; i < y.size(); ++i) x[i] = sub(x[i], y[i]);
    return x;
}
// Difference of two polynomials; the result has the length of the longer operand.
poly operator-(poly x, const poly& y) {
    x -= y;
    return x;
}
// Subtracts the scalar y from the constant term. Any int is accepted for y
// (it is reduced into [0, mod) first), and an empty x is treated as the zero
// polynomial, so the result then has exactly one coefficient.
poly& operator-=(poly& x, int y) {
    if (x.empty()) x.resize(1);  // x[0] must exist before it is written
    y %= mod;
    if (y < 0) y += mod;  // sub() requires both arguments in [0, mod)
    x[0] = sub(x[0], y);
    return x;
}
// polynomial - scalar.
poly operator-(poly x, int y) {
    x -= y;
    return x;
}
// scalar - polynomial: the scalar is promoted to a constant polynomial first.
poly operator-(int x, const poly& y) {
    x %= mod;
    if (x < 0) x += mod;
    poly lhs(1, x);
    lhs -= y;
    return lhs;
}

// ntt
// In-place number-theoretic transform of x modulo `mod`, using powers of 3 as
// the roots of unity.
//   x : coefficients in [0, mod). x.size() must be a power of two that
//       divides mod - 1; lengths 0 and 1 are left untouched.
//   op: -1 selects the inverse transform (output reversed and scaled by 1/n);
//       any other value performs the forward transform.
// The bit-reversal and twiddle tables are function-local statics that are
// rebuilt only when the length changes, so the function is not thread-safe.
void ntt(poly& x, int op) {
    const int n = x.size();
    if (n <= 1) return;  // the transform of length 0 or 1 is the identity
    assert((n & (n - 1)) == 0 && (mod - 1) % n == 0);

    static poly rev, gi;
    static int ready = 0;  // length the cached tables were built for
    if (ready != n) {
        rev.assign(n, 0);
        for (int i = 1; i < n; ++i) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= n >> 1;
        }
        // gi[k] = w^k for w = 3^((mod-1)/n). The butterflies below only ever
        // index k < n/2, so exactly n/2 entries are stored.
        gi.assign(n / 2, 1);
        const int w = qpow(3, (mod - 1) / n);
        for (int i = 1; i < n / 2; ++i) gi[i] = ll(gi[i - 1]) * w % mod;
        ready = n;
    }

    for (int i = 0; i < n; ++i) {
        if (i < rev[i]) std::swap(x[i], x[rev[i]]);
    }

    for (int h = 1, p = n / 2; h < n; h *= 2, p /= 2) {
        for (int j = 0; j < n; j += h * 2) {
            for (int k = j, g = 0; k < j + h; ++k, g += p) {
                const int t = ll(gi[g]) * x[k + h] % mod;
                x[k + h] = sub(x[k], t);
                x[k] = add(x[k], t);
            }
        }
    }

    if (op == -1) {
        // Reversing positions 1..n-1 turns the forward transform into the
        // inverse one; iterators avoid forming the out-of-range element x[n].
        std::reverse(x.begin() + 1, x.end());
        const int n_inv = qpow(n, mod - 2);
        for (int& c : x) c = ll(c) * n_inv % mod;
    }
}

// Combines x and y point by point in the transform domain and stores the
// result in x.
//   x      : left operand; overwritten with the result, whose length is the
//            power of two chosen below.
//   y      : right operand; left unmodified (a scratch copy is transformed).
//   func   : applied to each pair of transformed values (x_i, y_i).
//   getdeg : given (x.size(), y.size()), returns the minimum transform length;
//            defaults to the length of a plain product, n + m - 1.
// Returns a reference to x.
poly& calc_eq(
    poly& x, const poly& y, function<int(int, int)> func,
    function<int(int, int)> getdeg = [](int n, int m) { return n + m - 1; }) {
    static poly rhs;  // scratch buffer reused across calls
    rhs = y;
    const int lenx = x.size();
    const int leny = y.size();
    int len = 1;
    while (len < getdeg(lenx, leny)) {
        len <<= 1;
    }
    x.resize(len);
    rhs.resize(len);
    ntt(x, 1);
    ntt(rhs, 1);
    for (int idx = 0; idx != len; ++idx) {
        x[idx] = func(x[idx], rhs[idx]);
    }
    ntt(x, -1);
    return x;
}

// poly *
// Polynomial product modulo `mod`, computed through calc_eq. The result has
// exactly x.size() + y.size() - 1 coefficients (the transform padding is
// dropped); if either operand is empty the product is the empty polynomial.
poly& operator*=(poly& x, const poly& y) {
    if (x.empty() || y.empty()) return x.clear(), x;
    // Taken before calc_eq resizes x, because y may alias x (x *= x).
    const size_t len = x.size() + y.size() - 1;
    calc_eq(x, y, [](int X, int Y) -> int { return ll(X) * Y % mod; });
    x.resize(len);
    return x;
}
// Product of two polynomials.
poly operator*(poly x, const poly& y) {
    x *= y;
    return x;
}
// Scales every coefficient by the scalar y; y may be any int, negative included.
poly& operator*=(poly& x, const int y) {
    const ll k = ((y % mod) + mod) % mod;  // scalar reduced into [0, mod)
    for (int& c : x) c = (c * k % mod + mod) % mod;
    return x;
}
// polynomial * scalar.
poly operator*(poly x, int y) {
    x *= y;
    return x;
}
// scalar * polynomial.
poly operator*(int x, poly y) {
    y *= x;
    return y;
}

// poly % x^n
// Returns x with exactly n coefficients: longer inputs are truncated,
// shorter ones are padded with zeros.
poly operator%(poly x, int n) {
    x.resize(n);
    return x;
}

// inv, poly inv
// Scalar inverse modulo `mod`, obtained as x raised to the power mod - 2.
int inv(int x) {
    const int exponent = mod - 2;
    return qpow(x, exponent);
}
void inv(const poly& h, poly& x, int n) {
    if (n == 1) return x = {inv(h[0])}, void();
    inv(h, x, (n + 1) / 2);
    calc_eq(
        x, h % n,
        [](int X, int Y) -> int { return (ll)X * sub(2, ll(X) * Y % mod) % mod; },
        [](int n, int m) { return n * 2 + m - 1; });
    x.resize(n);
}
// Power-series inverse of h, computed to h.size() coefficients.
// An empty h yields an empty result; otherwise the constant term of h must be
// invertible modulo `mod`.
poly inv(const poly& h) {
    poly t;
    if (h.empty()) return t;  // a requested length of 0 would otherwise recurse without end
    inv(h, t, h.size());
    return t;
}

// poly / (mod x^n)
// Power-series quotient x / y, truncated to n = max(x.size(), y.size())
// coefficients. The inverse of y is computed to the same n terms, so every
// returned coefficient is exact even when x is longer than y.
// y must be non-empty with a constant term invertible modulo `mod`.
poly operator/(poly x, poly y) {
    assert(!y.empty());
    const int len = std::max(x.size(), y.size());
    poly y_inv;
    inv(y, y_inv, len);
    x *= y_inv;
    x.resize(len);  // coefficients past x^(len-1) are not determined by the inputs
    return x;
}
// Divides every coefficient by the scalar y, i.e. multiplies by y's modular inverse.
poly operator/(poly x, int y) {
    x *= inv(y);
    return x;
}

// poly derivative, integrate
// Formal derivative: coefficient i of the input becomes i * x[i] at index i - 1.
// The result is one coefficient shorter than the input; an empty input is
// returned unchanged.
poly derivative(poly x) {
    if (x.empty()) return x;  // pop_back() on an empty vector is undefined
    const int n = x.size();
    for (int i = 1; i < n; ++i) x[i - 1] = ll(x[i]) * i % mod;
    x.pop_back();
    return x;
}
// Formal antiderivative with constant term 0: coefficient i of the input
// becomes x[i] / (i + 1) at index i + 1. The result is one coefficient longer
// than the input.
poly integrate(poly x) {
    const int n = x.size();
    assert(n < mod);  // the inverse recurrence below needs every k to be below mod
    // invs[k] = k^{-1} mod `mod` for 1 <= k <= n, built with the linear
    // recurrence inv[k] = -(mod / k) * inv[mod % k]; invs[0] is unused.
    poly invs(n + 1, 1);
    for (int k = 2; k <= n; ++k) invs[k] = ll(mod - mod / k) * invs[mod % k] % mod;
    x.push_back(0);
    // Walk downwards so that each x[i] is read before x[i + 1] overwrites its neighbour.
    for (int i = n - 1; i >= 0; --i) x[i + 1] = ll(x[i]) * invs[i + 1] % mod;
    x[0] = 0;
    return x;
}

// poly log
// Power-series logarithm, computed as the integral of x' / x and returned with
// x.size() coefficients. x must be non-empty with constant term exactly 1.
poly ln(poly x) {
    assert(!x.empty() && x[0] == 1);
    const int n = x.size();
    poly quotient = derivative(x) / x;
    quotient.resize(n);
    poly result = integrate(quotient);
    result.resize(n);
    return result;
}

// poly exp
void exp(const poly& h, poly& x, int n) {
    if (n == 1) return x = {1}, void();
    exp(h, x, (n + 1) / 2);
    x *= (1 - ln(x % n) + h % n);  // note `ln(x % n)`
    x.resize(n);
}
// Power-series exponential of x, computed to x.size() coefficients.
// An empty x yields an empty result; otherwise the constant term of x must be 0.
poly exp(const poly& x) {
    poly t;
    if (x.empty()) return t;  // a requested length of 0 would otherwise recurse without end
    assert(x[0] == 0);
    exp(x, t, x.size());
    return t;
}

// poly qpow
// Returns x^y truncated to x.size() coefficients, with coefficients in [0, mod).
//   x: base polynomial; an empty x is returned unchanged.
//   y: exponent, the actual non-negative integer power (not reduced by any modulus).
// Method: write x = c * t^k * u with u having constant term 1; then
// x^y = c^y * t^(k*y) * exp(y * ln(u)).
// y == 0 yields the constant polynomial 1, also for an all-zero x.
poly pow(poly x, int y) {
    assert(y >= 0);
    const int n = x.size();
    if (n == 0) return x;
    if (y == 0) {
        poly one(n, 0);
        one[0] = 1;
        return one;
    }
    if (x[0] == 1) return exp(ln(x) * y);

    int low = 0;  // index of the lowest non-zero coefficient
    while (low < n && x[low] == 0) ++low;
    if (low == n) return poly(n, 0);  // zero polynomial to a positive power

    // The result starts at degree low * y. This is a plain integer product and
    // must not be reduced modulo `mod`; at n or beyond, nothing survives truncation.
    const ll shift_ll = ll(low) * y;
    if (shift_ll >= n) return poly(n, 0);
    const int shift = int(shift_ll);

    const int lead = x[low];
    const ll lead_inv = inv(lead);
    // Shift down by `low` and scale so that the constant term becomes 1.
    for (int j = low; j < n; ++j) x[j - low] = lead_inv * x[j] % mod;
    for (int j = n - low; j < n; ++j) x[j] = 0;

    x = exp(ln(x) * y);

    const ll scale = qpow(lead, y);
    // Shift up by `shift`, going downwards so that no source slot is overwritten early.
    for (int i = n - 1 - shift; i >= 0; --i) x[i + shift] = x[i] * scale % mod;
    std::fill(x.begin(), x.begin() + shift, 0);
    return x;
}

// poly / poly
// Quotient of the Euclidean division of x by y, obtained by reversing both
// coefficient vectors and multiplying by a power-series inverse.
//   x: dividend.
//   y: divisor; must be non-empty and its highest coefficient y.back() non-zero.
// Returns x.size() - y.size() + 1 coefficients, or an empty polynomial when x
// is shorter than y (the quotient is then zero).
poly div(poly x, poly y) {
    assert(!y.empty() && y.back() != 0);
    if (x.size() < y.size()) return poly();  // the length below would not be positive
    const int l = x.size() - y.size() + 1;
    std::reverse(x.begin(), x.end());
    std::reverse(y.begin(), y.end());
    x.resize(l);  // only the l leading terms of x influence the l quotient terms
    x *= inv(y % l);
    x.resize(l);
    std::reverse(x.begin(), x.end());
    return x;
}

// poly mod poly
// Remainder of the Euclidean division of x by y, i.e. x - div(x, y) * y.
//   y: divisor; must be non-empty (div additionally needs y.back() != 0).
// When x is shorter than y it is already the remainder and is returned as is;
// otherwise the result has exactly y.size() - 1 coefficients.
poly pmod(poly x, poly y) {
    assert(!y.empty());
    if (x.size() < y.size()) return x;
    const size_t keep = y.size() - 1;
    x -= div(x, y) * y;
    x.resize(keep);  // every coefficient from x^(y.size()-1) upwards cancels to zero
    return x;
}

// Baby-step giant-step search for an exponent t with a^t == b (mod `mod`).
// It stores b * a^j for j = 1..bond in a hash map, then scans a^(bond * i) for
// i = 1..bond and returns i * bond - j on the first match, or -1 if none is
// found. bond = floor(sqrt(mod)) + 1, so the work is O(sqrt(mod)) map operations.
// NOTE(review): a match only implies a^t == b when a is invertible modulo `mod`.
int BSGS(int a, int b) {
    static unordered_map<int, int> mp;
    mp.clear();
    const int bond = int(::sqrt(mod)) + 1, c = qpow(a, bond);

    // Baby steps: a later j overwrites an earlier one that has the same value.
    int baby = ll(b) * a % mod;
    for (int j = 1; j <= bond; ++j) {
        mp[baby] = j;
        baby = ll(baby) * a % mod;
    }

    // Giant steps.
    int giant = c % mod;
    for (int i = 1; i <= bond; ++i) {
        const auto hit = mp.find(giant);
        if (hit != mp.end()) return ll(i) * bond - hit->second;
        giant = ll(giant) * c % mod;
    }
    return -1;
}
// x ^ a = b (mod p) (O(log p))
// 2 possible solutions when a is an even.
int root(int a, int b) {
    const int pr = 3;
    int t = BSGS(qpow(pr, a), b);
    if (t == -1) return -1;
    return qpow(pr, t);
}

// poly sqrt
void sqrt(const poly& h, poly& x, int n) {
    if (n == 1) return x = {root(2, h[0])}, void();
    sqrt(h, x, (n + 1) / 2);
    x = (x + h % n / (x % n)) / 2 % n;
}
// Power-series square root of h, computed to h.size() coefficients.
// An empty h yields an empty result.
poly sqrt(const poly& h) {
    poly t;
    if (h.empty()) return t;  // a requested length of 0 would otherwise recurse without end
    sqrt(h, t, h.size());
    return t;
}

// poly sqrt3
void sqrt3(const poly& h, poly& x, int n) {
    if (n == 1) return x = {root(3, h[0])}, void();
    sqrt3(h, x, (n + 1) / 2);
    x = (x * 2 + h % n / (x * x % n)) / 3 % n;
}
// Power-series cube root of h, computed to h.size() coefficients.
// An empty h yields an empty result.
poly sqrt3(const poly& h) {
    poly t;
    if (h.empty()) return t;  // a requested length of 0 would otherwise recurse without end
    sqrt3(h, t, h.size());
    return t;
}

}  // namespace Poly

多项式完全模板
https://yzzzf.xyz/2021/11/27/多项式模板/
Author
Zifan Ying
Posted on
November 27, 2021
Licensed under