// NTT - YessineJallouli/Competitive-Programming GitHub Wiki
#include <bits/stdc++.h>
using namespace std;
// Shorthand for long long, the integer type used for the modular arithmetic in this file.
using ll=long long;
// Prime modulus for the transform: mod - 1 = 2^23 * 7 * 17, so roots of unity of
// power-of-two order exist up to order 2^23.
constexpr ll mod = 998244353;
/**
 * @abstract Modular exponentiation by repeated squaring
 * @param a The base; any value is accepted, it is reduced into [0, mod) first
 * @param b The exponent; any b <= 0 yields 1
 * @result a^b modulo mod
 * */
long long fast_pow(long long a, long long b) {
    // Reduce the base up front: without this a*a overflows for large |a|, and a negative
    // base would leak negative residues (C++ % keeps the sign of the dividend).
    a %= mod;
    if (a < 0)
        a += mod;
    long long ans = 1;
    while (b > 0) {
        if (b & 1)
            ans = ans * a % mod;   // fold in the current bit of the exponent
        a = a * a % mod;           // both factors are < mod < 2^30, so the product fits
        b >>= 1;
    }
    return ans;
}
/**
 * @abstract Modular inverse via Fermat's little theorem
 * @param a The value to invert
 * @result a^(mod-2) modulo mod, which is the inverse of a when mod is prime and
 *         a is not a multiple of mod
 * */
long long modinv(long long a)
{
    const long long exponent = mod - 2;
    return fast_pow(a, exponent);
}
// Largest supported log2 of a transform size: the root tables G1/G2 hold L+1 entries
// each, and main passes this value to init_ntt.
constexpr int L=20; // must be >= log2(N)
/**
 * @abstract Find the smallest generator (primitive root) of the multiplicative group modulo mod
 * @result The smallest g >= 2 whose order is exactly mod-1
 * @throws std::runtime_error if no candidate in [2, mod) qualifies
 *
 * Assumes mod is prime, so the group order is mod-1.
 * */
ll findGen()
{
    constexpr long long phi = mod - 1;

    // Distinct prime factors of phi, by trial division (phi < 2^30, so this is cheap).
    std::vector<long long> prime_factors;
    long long rest = phi;
    for (long long p = 2; p * p <= rest; p++)
    {
        if (rest % p != 0) continue;
        prime_factors.push_back(p);
        while (rest % p == 0) rest /= p;
    }
    if (rest > 1) prime_factors.push_back(rest);

    for (long long g = 2; g < mod; g++)
    {
        // g generates the group iff g^(phi/q) != 1 for every prime q dividing phi.
        // Testing only the exponent phi/2 would accept any quadratic non-residue,
        // which is not necessarily a generator.
        bool is_generator = true;
        for (long long q : prime_factors)
        {
            if (fast_pow(g, phi / q) == 1)
            {
                is_generator = false;
                break;
            }
        }
        if (is_generator) return g;
    }
    throw std::runtime_error("No Generator found");
}
// Root tables filled by init_ntt: G1[i] is the root of unity of order 2^i used by the
// forward transform, G2[i] the corresponding root for the inverse transform.
// Index 0 is never written or read.
ll G1[L+1],G2[L+1];
// Gl = 3 for M=998244353
// Initialized dynamically at program start-up, before main runs.
ll Gl=findGen();
// Modular inverse of Gl; init_ntt derives the inverse-transform roots from it.
ll Gr=modinv(Gl);
/**
 * @abstract Calculate a primitive nth root of unity
 * @param G a generator of the multiplicative group
 * @param n The order of the primitive root; must be positive and divide mod-1
 * @result G^((mod-1)/n) modulo mod
 * @throws std::invalid_argument if n is not positive or does not divide mod-1
 * */
ll RoU(ll G, ll n)
{
    constexpr auto phi = mod - 1;
    // Reject n <= 0 explicitly: n == 0 would be a division by zero below.
    if (n <= 0) throw std::invalid_argument("order must be positive");
    // Plain / and % rather than std::div with a structured binding: the order of the
    // quot/rem members of lldiv_t is unspecified, so unpacking it positionally is not portable.
    if (phi % n != 0) throw std::invalid_argument("size must divide phi(m)");
    return fast_pow(G, phi / n);
}
/**
* @abstract Construct the primitive roots of unity for NTT. SHOULD BE CALLED BEFORE ANY NTT
* @param L An integer such that all arrays have no more than 2^L elements
* */
void init_ntt(unsigned L)
{
for(int i=1;i<=L;i++)
{
G1[i] = RoU(Gl, 1<<i);
G2[i] = RoU(Gr, 1<<i); // Root of unity for inverse
}
}
// Normalization mode accepted by inplace_ntt2: AUTO (-1) normalizes exactly when the
// inverse transform is requested, NO (0) never normalizes, YES (1) always does.
enum ntt_norm {AUTO=-1,NO,YES};
/**
 * @abstract Inplace NTT of an array a
 * @param a the array; its size must be a power of two (0 and 1 are accepted) that fits
 *          the root tables, and its values are expected to lie in [0, mod)
 * @param inv Calculate NTT, or inverse NTT, depending on inv.
 * @param norm Whether to normalize the NTT (multiply every entry by the inverse of the
 *             size). AUTO normalizes only when inv is true. For advanced use
 * @throws std::invalid_argument if the size is not a power of two or exceeds the root tables
 * @throws std::logic_error if the root tables were not initialized far enough by init_ntt
 * */
void inplace_ntt2(vector<ll> & a, bool inv=false, int norm= AUTO)
{
    if (norm == AUTO) norm = inv;

    // Validate everything before touching a, so a failed call leaves the array untouched.
    const std::size_t size = a.size();
    if ((size & (size - 1)) != 0)
        throw std::invalid_argument("NTT size must be a power of two");
    constexpr std::size_t max_log = sizeof(G1) / sizeof(G1[0]) - 1;
    if (size > (std::size_t{1} << max_log))
        throw std::invalid_argument("NTT size exceeds the precomputed root tables");

    const ll *roots = inv ? G2 : G1;
    int stages = 0;
    while ((std::size_t{1} << stages) < size) stages++;
    // A root of unity is never 0, so a zero entry means init_ntt has not filled this level.
    if (stages > 0 && roots[stages] == 0)
        throw std::logic_error("init_ntt must be called with a large enough L before inplace_ntt2");

    const int n = static_cast<int>(size);

    // Bit-reversal permutation: j tracks the bit-reversed value of i incrementally.
    for (int i = 1, j = 0; i < n; i++)
    {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j)
            std::swap(a[i], a[j]);
    }

    // Iterative butterflies: stage r merges blocks of length len = 2^r.
    for (int len = 2, r = 1; len <= n; len <<= 1, r++)
    {
        const ll wlen = roots[r];
        const int half = len / 2;
        for (int i = 0; i < n; i += len)
        {
            ll w = 1;
            for (int j = 0; j < half; j++)
            {
                const ll u = a[i + j];
                const ll v = a[i + j + half] * w % mod;
                a[i + j] = (u + v) % mod;
                a[i + j + half] = (u - v + mod) % mod;   // + mod keeps the difference non-negative
                w = w * wlen % mod;
            }
        }
    }

    if (norm)
    {
        const ll n_inv = modinv(n);
        for (ll &x : a)
            x = x * n_inv % mod;
    }
}
/**
 * @abstract Smallest power of two that is not less than r
 * @param r The lower bound
 * @result 1 when r <= 1, otherwise the smallest 2^k with 2^k >= r
 * */
ll bit_ceil(ll r) {
    ll n = 1;
    while (n < r)
        n *= 2;
    return n;
}
/**
 * @abstract Multiply two polynomials
 * @param a The first polynomial (coefficient i belongs to x^i); any integer values
 * @param b The second polynomial; any integer values
 * @result The result of polynomial multiplication, coefficients reduced into [0, mod).
 *         Has a.size()+b.size()-1 entries, or is empty if either input is empty.
 *
 * The root tables must have been set up by init_ntt, and the padded size must fit them.
 * */
vector<ll> multiply(vector<ll> a, vector<ll> b)
{
    // Without this guard a.size()+b.size()-1 wraps around when both inputs are empty.
    if (a.empty() || b.empty())
        return {};

    // Bring every coefficient into [0, mod): the transform multiplies them by roots of
    // unity, which overflows for large inputs and yields negative residues for negative ones.
    for (ll &x : a) { x %= mod; if (x < 0) x += mod; }
    for (ll &x : b) { x %= mod; if (x < 0) x += mod; }

    const size_t r = a.size() + b.size() - 1;
    const size_t m = static_cast<size_t>(bit_ceil(static_cast<ll>(r)));
    a.resize(m);   // zero-pad to the transform size
    b.resize(m);
    inplace_ntt2(a);
    inplace_ntt2(b);
    for (size_t i = 0; i < m; i++)
        a[i] = a[i] * b[i] % mod;   // pointwise product in the transformed domain
    inplace_ntt2(a, true);
    a.resize(r);
    return a;
}
// Divide and conquer expansion of prod_{i=l..r} (x + X[i]).
// The result lists coefficients from the constant term upwards; read in reverse it is
// the expansion of prod (1 + X[i]*x).
// An empty range (l > r) yields the empty product {1} instead of recursing forever.
vector<ll> fast_expansion(const vector<ll> &X, int l, int r) {
    if (l > r)
        return {1};
    if (l == r)
        return {X[l], 1};   // the single factor X[l] + x
    const int m = l + (r - l) / 2;
    return multiply(fast_expansion(X, l, m), fast_expansion(X, m + 1, r));
}
int main()
{
ios_base::sync_with_stdio(false); cin.tie(nullptr);
init_ntt(L);
int n,m; cin >> n >> m;
vector<ll> A(n);
vector<ll> B(m);
for (int i = 0; i < n; i++) {
cin >> A[i];
}
for (auto &i : B) {
cin >> i;
}
auto C = multiply(A,B);
for (ll i : C) {
cout << i << ' ';
}
}