NTT - YessineJallouli/Competitive-Programming GitHub Wiki

#include <bits/stdc++.h>
using namespace  std;
using ll=long long;

// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
constexpr long long mod = 998244353;

/**
 * @abstract Modular exponentiation by repeated squaring
 * @param a The base. Any value is accepted; it is reduced into [0, mod) first
 * @param b The exponent. A non-positive exponent yields 1
 * @result a^b modulo mod, always in the range [0, mod)
 * */
long long fast_pow(long long a, long long b) {
    // Reduce the base up front: without this, a*a overflows long long for
    // |a| >= ~3.04e9 and a negative base would leak a negative result.
    a %= mod;
    if (a < 0)
        a += mod;

    long long ans = 1;
    while (b > 0) {
        // Multiply in the current power of the base when this bit of b is set.
        if (b & 1)
            ans = (ans * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return ans;
}

/**
 * @abstract Modular inverse computed as a^(mod-2)
 * @param a The value to invert
 * @result fast_pow(a, mod-2); this is the inverse of a when mod is prime and
 *         a is not a multiple of mod (Fermat's little theorem)
 * */
long long modinv(long long a)
{
    const long long exponent = mod - 2;
    return fast_pow(a, exponent);
}

// Largest supported log2 of a transform length: the root tables hold L+1
// entries, so arrays may have at most 2^L elements. 2^L must divide mod-1.
constexpr int L = 20;

/**
 * @abstract Find the smallest generator (primitive root) of the multiplicative group modulo mod
 * @result The smallest g in [2, mod) whose order is exactly mod-1
 * @throws std::runtime_error if no such g exists
 *
 * g is a generator iff g^((mod-1)/p) != 1 for every prime p dividing mod-1.
 * Checking only p = 2 (Euler's criterion) would accept any quadratic
 * non-residue, which is not necessarily a generator.
 * NOTE(review): assumes mod is prime, so that the group order is mod-1.
 * */
ll findGen()
{
    constexpr ll phi = mod - 1;

    // Collect the distinct prime factors of phi by trial division.
    std::vector<ll> prime_factors;
    ll rest = phi;
    for (ll p = 2; p * p <= rest; ++p)
    {
        if (rest % p != 0) continue;
        prime_factors.push_back(p);
        while (rest % p == 0) rest /= p;
    }
    if (rest > 1) prime_factors.push_back(rest);

    for (ll g = 2; g < mod; ++g)
    {
        bool is_generator = true;
        for (ll p : prime_factors)
        {
            // Order of g divides phi/p  =>  g does not generate the whole group.
            if (fast_pow(g, phi / p) == 1)
            {
                is_generator = false;
                break;
            }
        }
        if (is_generator) return g;
    }
    throw std::runtime_error("No Generator found");
}

// Root-of-unity tables filled by init_ntt: entry i holds a primitive 2^i-th
// root of unity, for the forward transform (G1) and the inverse one (G2).
ll G1[L + 1];
ll G2[L + 1];

// Generator of the multiplicative group (3 for mod = 998244353) and its
// modular inverse. Gr is initialized from Gl, so the order matters.
ll Gl = findGen();
ll Gr = modinv(Gl);


/**
 * @abstract Calculate a primitive nth root of unity
 * @param G a generator of the multiplicative group
 * @param n The order of the primitive root. Must be positive and divide mod-1
 * @result G^((mod-1)/n) modulo mod
 * @throws std::invalid_argument if n is not positive or does not divide mod-1
 * */
ll RoU(ll G, ll n)
{
    constexpr ll phi = mod - 1;
    // Use / and % directly: the member order of the struct returned by
    // std::div is unspecified, so binding it positionally is not portable.
    // Rejecting n <= 0 also rules out the division by zero.
    if (n <= 0 || phi % n != 0)
        throw std::invalid_argument("size must divide phi(m)");
    return fast_pow(G, phi / n);
}

/**
 * @abstract Construct the primitive roots of unity for NTT. SHOULD BE CALLED BEFORE ANY NTT
 * @param L An integer such that all arrays have no more than 2^L elements
 * */
void init_ntt(unsigned L)
{
    for(int i=1;i<=L;i++)
    {
        G1[i] = RoU(Gl, 1<<i);
        G2[i] = RoU(Gr, 1<<i); // Root of unity for inverse
    }
}

// Normalization policy for inplace_ntt2's norm argument.
enum ntt_norm
{
    AUTO = -1, // replaced by the value of the inv flag
    NO = 0,    // leave the result unscaled
    YES = 1    // multiply the result by the inverse of the length
};

/**
 * @abstract Inplace NTT of an array a
 * @param a the array. Its size must be zero or a power of two no larger than
 *          2^k, where k is the highest index of the G1/G2 tables (which
 *          init_ntt must have filled). Entries are reduced into [0, mod) first
 * @param inv Calculate NTT, or inverse NTT, depending on inv.
 * @param norm Whether to multiply the result by the inverse of the length.
 *             AUTO (the default) normalizes exactly when inv is true. For advanced use
 * @throws std::invalid_argument if the size is not a power of two or is too
 *         large for the root tables
 * */
void inplace_ntt2(vector<ll> & a, bool inv=false, int norm= AUTO)
{
    if (norm == AUTO) norm = inv;

    const std::size_t size = a.size();
    if (size == 0) return; // nothing to transform

    // A non-power-of-two size would index past the end of a in the butterfly
    // loops; a size above the table capacity would read past G1/G2.
    if ((size & (size - 1)) != 0)
        throw std::invalid_argument("NTT size must be a power of two");
    constexpr std::size_t max_level = sizeof(G1) / sizeof(G1[0]) - 1;
    if (size > (std::size_t{1} << max_level))
        throw std::invalid_argument("NTT size exceeds the precomputed roots of unity");
    const int n = static_cast<int>(size);

    // Canonicalize the input so the (u - v + mod) % mod step below stays
    // non-negative and the products cannot overflow.
    for (ll &x : a)
    {
        x %= mod;
        if (x < 0) x += mod;
    }

    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for (int i = 1, j = 0; i < n; ++i)
    {
        int bit = n >> 1;
        while (j & bit)
        {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }

    // Butterflies: level r combines blocks of length 2^r using a primitive
    // 2^r-th root of unity (inverse root for the inverse transform).
    for (int len = 2, r = 1; len <= n; len <<= 1, ++r)
    {
        const ll wlen = inv ? G2[r] : G1[r];
        const int half = len / 2;
        for (int start = 0; start < n; start += len)
        {
            ll w = 1;
            for (int j = 0; j < half; ++j)
            {
                const ll u = a[start + j];
                const ll v = a[start + j + half] * w % mod;
                a[start + j] = (u + v) % mod;
                a[start + j + half] = (u - v + mod) % mod;
                w = w * wlen % mod;
            }
        }
    }

    if (norm)
    {
        const ll inv_n = modinv(n);
        for (ll &x : a)
            x = (x * inv_n) % mod;
    }
}

/**
 * @abstract Smallest power of two that is greater than or equal to r
 * @param r The lower bound. Any value <= 1 yields 1
 * @result The smallest n = 2^k with n >= r
 * @throws std::overflow_error if no such power of two fits in long long
 * */
long long bit_ceil(long long r) {
    // 2^62 is the largest power of two representable in long long; doubling
    // past it would be signed overflow.
    constexpr long long max_pow2 = 1LL << 62;
    if (r > max_pow2)
        throw std::overflow_error("bit_ceil: result does not fit in long long");

    long long n = 1;
    while (n < r)
        n <<= 1;
    return n;
}

/**
 * @abstract Multiply two polynomials modulo mod using the NTT
 * @param a The first polynomial, as a coefficient vector (taken by value and used as scratch space)
 * @param b The second polynomial, as a coefficient vector (taken by value and used as scratch space)
 * @result The coefficient vector of the product, of size a.size()+b.size()-1.
 *         If either input is empty the result is empty
 * */
vector<ll> multiply(vector<ll> a, vector<ll>  b)
{
    // An empty operand would make a.size()+b.size()-1 wrap around (both
    // empty) or produce a meaningless block of zeros (one empty).
    if (a.empty() || b.empty())
        return {};

    const size_t r = a.size() + b.size() - 1;
    // Pad both operands to a power of two that can hold the full product.
    const size_t m = static_cast<size_t>(bit_ceil(static_cast<ll>(r)));
    a.resize(m);
    b.resize(m);

    inplace_ntt2(a);
    inplace_ntt2(b);
    // Pointwise product in the transformed domain.
    for (size_t i = 0; i < m; ++i)
        a[i] = (a[i] * b[i]) % mod;
    inplace_ntt2(a, true);

    a.resize(r); // drop the zero padding
    return a;
}

/**
 * @abstract Divide and conquer expansion of the product of (X[i] + t) for i in [l, r]
 * @param X The values X[i]
 * @param l First index of the range (inclusive)
 * @param r Last index of the range (inclusive)
 * @result Coefficients of the product, lowest power of t first. Read in
 *         reverse, the same vector is the product of (1 + X[i]*t).
 *         An empty range (l > r) yields {1}, the empty product
 * */
vector<ll> fast_expansion(const vector<ll> &X, int l, int r) {
    // Without this guard an empty range recurses forever or reads X out of bounds.
    if (l > r)
        return {1};
    if (l == r)
        return {X[l], 1};
    const int mid_index = l + (r - l) / 2;
    return multiply(fast_expansion(X, l, mid_index), fast_expansion(X, mid_index + 1, r));
}

int main()
{
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    init_ntt(L);
    int n,m; cin >> n >> m;
    vector<ll> A(n);
    vector<ll> B(m);
    for (int i = 0; i < n; i++) {
        cin >> A[i];
    }
    for (auto &i : B) {
        cin >> i;
    }
    auto C = multiply(A,B);
    for (ll i : C) {
        cout << i << ' ';
    }
}