// Linear Algebra — YessineJallouli/Competitive-Programming GitHub Wiki

#include <bits/stdc++.h>
#define TEMPLATE template<typename R>

// Shorthand for the 64-bit signed integer type used throughout the file.
typedef long long ll;
// Prime modulus of the field Z/mod implemented by Cyc (998244353 = 119 * 2^23 + 1).
constexpr ll mod = 998244353;
using namespace std;

// Binary exponentiation: returns a^b with O(log b) multiplications.
// R must be constructible from the integer 1 and support `*=` and binary `*`.
// For b <= 0 the loop never runs, so the result is R(1).
template<typename R>
R fast_pow(R a, long long b) {
    R result = 1;
    for (; b > 0; b /= 2) {
        if (b & 1) result *= a; // current bit set: fold in a^(2^step)
        a = a * a;              // advance to the next power-of-two exponent
    }
    return result;
}

struct Cyc
{
    ll val;
    Cyc(ll v=0):val((v%mod+mod)%mod){}
    Cyc& operator+=(const Cyc & o)
    {
        val+=o.val;
        val%=mod;
        return *this;
    }

    Cyc& operator-=(const Cyc & o)
    {
        val+=mod-o.val;
        val%=mod;
        return *this;
    }

    Cyc& operator*=(const Cyc & o)
    {
        val*=o.val;
        val%=mod;
        return *this;
    }

    Cyc operator-() const
    {
        return mod-val;
    }

    Cyc inv() const {
        return fast_pow<Cyc>(val,mod-2);
    }


    Cyc& operator/=(const Cyc &o)
    {
        return *this *= o.inv();
    }
};

// Sum a + b of two field elements.
Cyc operator+(const Cyc &a,const Cyc &b)
{
    Cyc sum(a);
    sum += b;
    return sum;
}

// Product a * b of two field elements.
Cyc operator*(const Cyc &a,const Cyc &b)
{
    Cyc product(a);
    product *= b;
    return product;
}

// Difference a - b of two field elements.
Cyc operator-(const Cyc &a,const Cyc &b)
{
    Cyc diff(a);
    diff -= b;
    return diff;
}

// Quotient a / b = a * b^(-1). Dividing by zero silently yields 0, because
// Cyc::inv() of 0 is 0.
Cyc operator/(const Cyc &a, const Cyc &b)
{
    Cyc quotient(a);
    quotient /= b;
    return quotient;
}

// Tolerance-based zero test for floating-point entries: true iff |x| < eps.
// std::fabs is spelled out so the result does not depend on a
// `using namespace std` bringing in the double overload of abs; the C
// ::abs(int) overload would truncate e.g. 0.5 to 0 and call it zero.
// eps defaults to the historical threshold 1e-6; NaN is never zero.
bool is_zero(double x, double eps = 1e-6)
{
    return std::fabs(x) < eps;
}

// Exact zero test for field elements; no tolerance is needed, unlike doubles.
bool is_zero(Cyc x) {
    return !x.val;
}

// Row operation u <- u - k*v, applied element-wise.
// Throws std::logic_error if the two rows have different lengths.
template<typename R>
void shift_by(std::vector<R> &u, R k, const std::vector<R> &v)
{
    if (u.size() != v.size())
        throw std::logic_error("u.size()!=v.size()");
    for (std::size_t idx = 0; idx < u.size(); ++idx)
        u[idx] -= k * v[idx];
}

// Scale a row in place: u <- u / k, element-wise.
template<typename R>
void divide_by(std::vector<R> &u,R k) {
    for (R &entry : u)
        entry /= k;
}

// Dense matrix over a field-like scalar R (instantiated in this file with Cyc
// and double), stored row-major as a vector of rows. A matrix with zero rows
// reports cols() == 0, since its column count cannot be recovered.
template<typename R>
struct Mat : std::vector<std::vector<R>>
{
    using Par = std::vector<std::vector<R>>;
    using Vec = std::vector<R>;
    using Par::Par; // To support vector constructors
    using Par::size;
    using Par::data;
    Mat(size_t rows,size_t cols,R val={}): Par(rows,Vec(cols,val)){}

    // n x n identity matrix.
    static Mat Id(size_t n) {
        Mat I(n,n);
        for (size_t i=0;i<n;i++)
            I[i][i]=1;
        return I;
    }

    // Square diagonal matrix with v on its main diagonal.
    static Mat Diag(const std::vector<R>& v) {
        Mat D(v.size(),v.size());
        for (size_t i=0;i<v.size();i++) D[i][i]=v[i];
        return D;
    }

    int rows() const {
        return size();
    }

    // Column count, read from the first row (0 when there are no rows).
    int cols() const {
        return size()?data()->size():0;
    }

    // Matrix product A*B. Throws std::invalid_argument when A.cols() != B.rows()
    // (checked only if A has rows; an empty A has no known column count).
    Mat operator*(const Mat& B) const
    {
        const Mat& A=*this;
        const int n=A.rows(),r=A.cols(),m=B.cols();
        if (n>0 && r!=B.rows())
            throw std::invalid_argument("Mat::operator*: A.cols() != B.rows()");
        Mat C(n,m);
        for (int i=0;i<n;i++) for (int k=0;k<r;k++) for (int j=0;j<m;j++) C[i][j]+=A[i][k]*B[k][j];
        return C;
    }

    // Matrix-vector product A*U. Throws std::invalid_argument when
    // U.size() != A.cols() (checked only if A has rows).
    Vec operator*(const Vec& U) const
    {
        const Mat& A=*this;
        const int n=A.rows(),m=A.cols();
        if (n>0 && static_cast<int>(U.size())!=m)
            throw std::invalid_argument("Mat::operator*: A.cols() != U.size()");
        Vec V(n);
        for (int i=0;i<n;i++) for (int j=0;j<m;j++) V[i]+=A[i][j]*U[j];
        return V;
    }

    Mat& operator*=(const Mat& B) {
        return *this = *this * B;
    }

    // Trace: sum of the first min(rows, cols) diagonal entries. For a square
    // matrix this is the sum of its eigenvalues (with multiplicity).
    R tr() const {
        R w{};
        const int k=std::min(rows(),cols());
        for (int i=0;i<k;i++)
            w+=(*this)[i][i];
        return w;
    }

    // Transpose of a matrix.
    Mat T() const {
        Mat Zt(cols(),rows());
        for (int i=0;i<rows();i++) for (int j=0;j<cols();j++)
            Zt[j][i]=(*this)[i][j];
        return Zt;
    }

    // Output of row_echelon(Y): the echelon form of A, the right-hand side
    // after the same row operations, and the bookkeeping needed to finish
    // solving A X = Y.
    struct MatSolve {
        Mat E,Y,X; // E: echelon form of A; Y: transformed right-hand side; X: solution, filled by solve_*
        size_t rank; // Number of pivots = rank of E (and of A)
        bool dir; // true iff row_echelon made an even number of row swaps, i.e. det(E) == det(A)
        std::vector<std::pair<int,int>> mapper; // (row, column) of each pivot, in increasing row order

        // Initializers follow declaration order, so X may safely read E and Y.
        MatSolve(Mat E_,Mat Y_,size_t rank,bool dir,std::vector<std::pair<int,int>> mapper):
                E(std::move(E_)),Y(std::move(Y_)),X(E.cols(),Y.cols()),rank(rank),dir(dir),mapper(std::move(mapper))
        {
        }

        // Solve E X = Y in general. On success X (E.cols() x Y.cols()) holds
        // one particular solution with every free variable set to 0. If some
        // zero row of E meets a nonzero entry of Y, X is cleared to signal
        // "no solution". Modifies E and Y in place.
        void solve_general()
        {
            const int n=E.rows(),l=Y.cols();
            for (int i=rank;i<n;i++) for (int j=0;j<l;j++) if (!is_zero(Y[i][j])) {
                X.clear(); // No solution
                return;
            }
            // Back-substitution, last pivot first. Later steps read only pivot
            // columns of E, so only those entries need updating.
            for (int k=static_cast<int>(mapper.size())-1;k>=0;k--)
            {
                auto [r,i] = mapper[k];
                for (int j=0;j<r;j++)
                {
                    auto w = E[j][i]/E[r][i];
                    E[j][i]=0;
                    shift_by(Y[j],w,Y[r]);
                }
            }
            for (auto [r,i]: mapper)
                divide_by(X[i]=Y[r],E[r][i]); // Intentional: copy row r of Y into X[i], then divide by the pivot
        }

        // Solve the system, assuming E is square and det E != 0.
        // Throws std::invalid_argument if a zero diagonal pivot is found.
        void solve_invertible()
        {
            const int n=E.rows();
            X=Y;
            for (int i=n-1;i>=0;i--)
            {
                if (is_zero(E[i][i])) throw std::invalid_argument("Matrix is not invertible");
                for (int j=0;j<i;j++) {
                    auto w = E[j][i]/E[i][i];
                    E[j][i]=0;
                    shift_by(Y[j],w,Y[i]);
                }
            }
            for (int i=0;i<n;i++)
                divide_by(X[i]=Y[i],E[i][i]); // Intentional: copy, then divide by the pivot
        }
    };

    // Pivot selection used by row_echelon. Given the partially reduced matrix,
    // the current rank rnk and a column col, it returns the pivot row (>= rnk),
    // or X.rows() when column col has no usable pivot.
    // Default: the first row at or below rnk whose entry is nonzero.
    // Replaceable per scalar type, e.g. Mat<double>::pivot_rule = partial_pivot;
    inline static std::function<int(const Mat&,int,int)> pivot_rule= [](const Mat& X,int rnk,int col) {
        int r=rnk;
        const int n=X.rows();
        while (r<n && is_zero(X[r][col])) r++;
        return r;
    };

    // Gaussian elimination to row echelon form, applying the same row
    // operations to Y. Throws std::invalid_argument if Y.rows() != rows().
    MatSolve row_echelon(Mat Y) const
    {
        if (Y.size()!=size())
            throw std::invalid_argument("row_echelon: Y must have the same number of rows as the matrix");
        size_t rnk=0;
        auto E=*this;
        const int n=rows(),m=cols();
        bool dir=true;
        std::vector<std::pair<int,int>> mapper;
        for (int i=0;i<m && rnk < static_cast<size_t>(n);i++)
        {
            const int r=pivot_rule(E,rnk,i);
            if (r==n) continue; // no pivot in this column
            mapper.emplace_back(rnk,i);
            if (r!=static_cast<int>(rnk))
            {
                std::swap(E[rnk],E[r]);
                std::swap(Y[rnk],Y[r]);
                dir=!dir;
            }
            for (int j=rnk+1;j<n;j++)
            {
                auto w = E[j][i]/E[rnk][i];
                E[j][i]=0;
                for (int k=i+1;k<m;k++) E[j][k]-=w*E[rnk][k];
                shift_by(Y[j],w,Y[rnk]);
            }
            rnk++;
        }
        return {std::move(E),std::move(Y),rnk,dir,std::move(mapper)};
    }

    // Solve A X = O for X.
    // Returns a cols() x O.cols() matrix holding one particular solution
    // (free variables set to 0), or an empty matrix if the system is
    // inconsistent. With invertible=true and a square A, the square
    // back-substitution path is used; it throws std::invalid_argument when A
    // is singular. Also throws std::invalid_argument if O.rows() != rows().
    Mat solve(const Mat& O, bool invertible=false) const {
        invertible = invertible && rows() == cols();
        auto dec=row_echelon(O);
        if (invertible)
            dec.solve_invertible();
        else dec.solve_general();
        return std::move(dec.X);
    }

    // Solve A x = V for a single right-hand side V (V.size() must equal rows()).
    // Pass invertible=true only when A is known to be square and non-singular;
    // that path throws std::invalid_argument if A turns out to be singular.
    // Returns one particular solution, or an empty vector if there is none.
    Vec solve(const Vec &V , bool invertible =false) const {
        Mat Y(V.size(),1);
        for (size_t i=0;i<V.size();i++) Y[i][0]=V[i];
        auto X=solve(Y,invertible);
        if (X.empty()) return {}; // If no solutions, return the empty vector
        Vec U(cols());
        for (int i=0;i<cols();i++) U[i]=X[i][0];
        return U;
    }

    // Basis (e_1,..,e_r) of the kernel {x : A x = 0}, as the rows of a Mat
    // (each of length cols()). Empty if the kernel is trivial.
    // Method: row-reduce [A^T | I]. Each echelon row equals c^T [A^T | I]; when
    // its A^T part vanishes, its I part is a vector c with A c = 0.
    Mat null_basis() const {
        const int n=rows(),m=cols();
        Mat Z(*this);
        for (int i=0;i<m;i++) {
            Z.emplace_back(m);
            Z[n+i][i]=1;
        }
        auto C = Z.T().row_echelon(Mat(m,0)).E;
        Mat B;
        for (int i=0;i<m;i++)
            if (std::all_of(C[i].begin(),C[i].begin()+n,[](const auto& x) {return is_zero(x);}))
                B.emplace_back(C[i].begin()+n,C[i].end());
        return B;
    }

    // Basis of the image (column space) {A x} of the matrix. The rows of the
    // result are rank() linearly independent vectors of length rows().
    // Computed as the nonzero rows of the echelon form of A^T.
    Mat image_basis() const {
        auto dec = T().row_echelon(Mat(cols(),0));
        dec.E.resize(dec.rank);
        return std::move(dec.E);
    }

    // Inverse of a square matrix; throws std::invalid_argument if it is
    // singular. For a non-square matrix this falls back to solve_general and
    // returns some X with A X = I, or an empty matrix if none exists.
    Mat inv() const
    {
        return solve(Mat::Id(rows()),true);
    }

    // Determinant of a square matrix: product of the echelon diagonal, negated
    // for an odd number of row swaps. Throws std::invalid_argument if the
    // matrix is not square.
    R det() const
    {
        if (rows()!=cols())
            throw std::invalid_argument("det() requires a square matrix");
        auto dec = row_echelon(Mat(rows(),0));
        R w=dec.dir?1:-1;
        for (int i=0;i<rows();i++) w*=dec.E[i][i];
        return w;
    }

    // Rank of a matrix:
    // 1. The number of independent columns/rows in that matrix
    // 2. The dimension of the image (vector space spanned by its columns)
    // 3. The size of a basis returned by image_basis()
    size_t rank() const {
        return row_echelon(Mat(rows(),0)).rank;
    }

    // Nullity of the matrix: the dimension of the kernel {x : A x = 0}, i.e.
    // the size of null_basis(). By rank-nullity it equals cols() - rank().
    size_t nullity() const {
        return cols() - rank();
    }
};

TEMPLATE
Mat<R> matrix_pow(Mat<R> A,size_t m)
{
    int n=A.rows();
    auto ans = Mat<R>::Id(n);
    while (m > 0) {
        if (m & 1)
            ans = ans*A;
        A = A*A;
        m/=2;
    }
    return ans;
}

// Calculate det(xI - A) in O(n⁴)
TEMPLATE
vector<R> faddev_lerrier_characteristic_polynomial(const Mat<R>&A)
{
    int n=A.rows();
    std::vector<R> S(n + 1);
    S[n] = 1;
    Mat<R> C(n,n);
    for (int i = n - 1; i >= 0; i--)
    {
        for (int j = 0; j < n; j++) C[j][j] += S[i + 1];
        C = A * C;
        S[i] = -C.tr() / (n - i);
    }
    return S;
}

// Pivot rule to improve stability for numerical systems:
// Use it only for floating point types!
// Do not forget to add: Mat<double>::pivot_rule = partial_pivot;
int partial_pivot(const Mat<double>&X,int rnk,int col) {
    int n=X.rows();
    int r=rnk;
    for (int s=rnk;s<n;s++) if (abs(X[s][col]) > abs(X[r][col]))
            r=s;
    return r;
}

int main()
{
    ios_base::sync_with_stdio(false);
    int n,m;
    cin >> n >> m;
    Mat<Cyc> A(n,m);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            cin >> A[i][j].val;
        }
    }
    Mat<Cyc> B(n,1);
    for (int i = 0; i < n; i++) {
        cin >> B[i][0].val;
    }
    Mat<Cyc> X = A.solve(B);
    if (X.empty()) {
        cout << -1 << '\n';
    }
    else {
        auto kernel = A.null_basis();
        cout << kernel.size() << '\n';
        for (int i = 0; i < m; i++) {
            cout << X[i][0].val << ' ';
        }
        cout << '\n';
        for (auto e : kernel) {
            for (int i = 0; i < m; i++) {
                cout << e[i].val << ' ';
            }
            cout << '\n';
        }
    }
}
// ⚠️ GitHub.com Fallback ⚠️ (page-export footer, not part of the code)