HLD (Heavy-Light Decomposition): Non-Commutative Merge - YessineJallouli/Competitive-Programming GitHub Wiki

#include <bits/stdc++.h>
#define ll long long
#define SaveTime ios_base::sync_with_stdio(false), cin.tie(0);
using namespace std;

// Capacity of the per-vertex arrays; main() stores vertices at indices 1..n.
const int N = 1e5+7;
// c[v]: current value attached to vertex v (read in main(), changed by type-0 operations).
int c[N];
// Per-vertex data filled by dfs() and decompose():
//   sz[v]    - size of the subtree rooted at v
//   heavy[v] - child of v with the largest subtree, 0 when v has no child
//   depth[v] - distance from the root (vertex 1)
//   top[v]   - head of the chain that contains v
//   pos[v]   - index of v in segment_ids, used as its segment-tree position
//   par[v]   - parent of v (the root is its own parent)
int sz[N], heavy[N], depth[N], top[N], pos[N], par[N];
// Adjacency lists of the tree.
vector<vector<int>> tree(N);
// Vertices in the order decompose() visits them, so segment_ids[pos[v]] == v.
vector<int> segment_ids;
// Number of vertices, read in main(); also the default right bound of get()/upd().
int n;

// Summary of a sequence of values, as maintained by mrg():
//   ans   - length of the longest non-decreasing contiguous run (0 = empty sequence)
//   cntL  - length of the non-decreasing run starting at the first element
//   cntR  - length of the non-decreasing run ending at the last element
//   first - first value of the sequence
//   last  - last value of the sequence
//   cr    - true when the whole sequence is one non-decreasing run
struct node {
    int ans, cntL, cntR, first, last;
    bool cr;

    // Empty sequence: every field is zero / false.
    node() : node(0, 0, 0, false, 0, 0) {}

    // Field-by-field construction.
    node(int _ans, int _cntL, int _cntR, bool _cr, int _first, int _last)
        : ans(_ans), cntL(_cntL), cntR(_cntR), first(_first), last(_last), cr(_cr) {}

    // Single-element sequence holding `val`. Intentionally not explicit:
    // callers pass plain ints where a node is expected.
    node(int val) : node(1, 1, 1, true, val, val) {}
};

// Two summaries of the same range: slot [0] is merged in increasing position
// order, slot [1] is the same range merged in the opposite order.
using node_pair = array<node,2>;

// Segment tree storage; the root is id 0 and node id has children 2*id+1 and 2*id+2.
node_pair segTree[4*N];

// Concatenates two sequence summaries (a followed by b).
// A summary with ans == 0 is the empty sequence and acts as the identity.
// The order of the operands matters: mrg(a, b) != mrg(b, a) in general.
node mrg(const node &a, const node &b) {
    if (a.ans == 0)
        return b;
    if (b.ans == 0)
        return a;

    // Start from the "runs do not join" result, then extend it if they do.
    node out;
    out.first = a.first;
    out.last = b.last;
    out.ans = std::max(a.ans, b.ans);
    out.cntL = a.cntL;
    out.cntR = b.cntR;
    out.cr = false;

    const bool descends = a.last > b.first;
    if (descends)
        return out;

    // The suffix run of a continues into the prefix run of b.
    out.ans = std::max(out.ans, a.cntR + b.cntL);
    out.cr = a.cr && b.cr;
    if (a.cr) {
        // a is a single run, so the prefix run reaches into b.
        out.cntL += b.cntL;
        out.ans = std::max(out.ans, out.cntL);
    }
    if (b.cr) {
        // b is a single run, so the suffix run reaches back into a.
        out.cntR += a.cntR;
        out.ans = std::max(out.ans, out.cntR);
    }
    return out;
}

// Concatenates two ranges (a before b) in both reading directions at once:
// slot 0 reads a then b, slot 1 reads b then a.
node_pair mrg(const node_pair & a, const node_pair & b)
{
    node_pair merged;
    merged[0] = mrg(a[0], b[0]);
    merged[1] = mrg(b[1], a[1]);
    return merged;
}

// Returns the two-direction summary of positions [qs, qe].
// id / ns / ne describe the current segment-tree node and the range it covers;
// callers normally leave them at their defaults (root, whole array).
node_pair get(int qs, int qe, int id = 0, int ns = 0, int ne = n-1) {
    const bool disjoint = qe < ns || ne < qs;
    if (disjoint) {
        return node_pair{};  // empty summary, identity for mrg
    }
    const bool covered = ns >= qs && ne <= qe;
    if (covered) {
        return segTree[id];
    }
    const int mid = (ns + ne) / 2;
    const node_pair leftPart = get(qs, qe, 2 * id + 1, ns, mid);
    const node_pair rightPart = get(qs, qe, 2 * id + 2, mid + 1, ne);
    return mrg(leftPart, rightPart);
}

// Point assignment: stores `val` at position ps and rebuilds the summaries on
// the path back to the root. id / ns / ne describe the current segment-tree
// node; callers normally leave them at their defaults.
void upd(int ps, node val, int id = 0, int ns = 0, int ne = n-1) {
    if (ps < ns || ne < ps)
        return;
    if (ns == ne) {
        // A single element reads the same in both directions.
        segTree[id] = {val, val};
        return;
    }
    const int mid = (ns + ne) / 2;
    const int left = 2 * id + 1;
    const int right = left + 1;
    // Only the child whose range contains ps can change.
    if (ps <= mid)
        upd(ps, val, left, ns, mid);
    else
        upd(ps, val, right, mid + 1, ne);
    segTree[id] = mrg(segTree[left], segTree[right]);
}

// First pass over the tree: fills par, depth and sz for the subtree of `node`
// and records in heavy[] the child with the largest subtree (first one wins
// on ties). `pr` is the vertex we arrived from, -1 for the root.
void dfs(int node, int pr = -1) {
    int largest = 0;
    sz[node] = 1;
    for (int child : tree[node]) {
        if (child != pr) {
            depth[child] = depth[node] + 1;
            par[child] = node;
            dfs(child, node);
            const int childSize = sz[child];
            sz[node] += childSize;
            if (childSize > largest) {
                heavy[node] = child;
                largest = childSize;
            }
        }
    }
}

// Second pass: assigns each vertex its chain head (top) and its position in
// segment_ids. The heavy child is visited first so that every chain occupies
// consecutive positions; each other child starts a new chain of its own.
void decompose(int node, int head, int pr = -1) {
    top[node] = head;
    pos[node] = static_cast<int>(segment_ids.size());
    segment_ids.push_back(node);

    const int heavyChild = heavy[node];
    if (heavyChild) {
        decompose(heavyChild, head, node);
    }
    for (int child : tree[node]) {
        if (child != pr && child != heavyChild) {
            decompose(child, child, node);
        }
    }
}

// Lowest common ancestor of a and b, found by jumping chain by chain.
int lcaHLD(int a, int b) {
    while (top[a] != top[b]) {
        // Always lift the vertex whose chain head is deeper.
        if (depth[top[a]] > depth[top[b]])
            std::swap(a, b);
        b = par[top[b]];
    }
    // Same chain now: the shallower vertex is the ancestor.
    return depth[a] > depth[b] ? b : a;
}

// Summary of the vertical path between a and its ancestor lca (both included).
//   direct == true  : values are read from a up to lca
//   direct == false : values are read from lca down to a
node queryWithLCA(int a, int lca, bool direct) {
    node_pair acc;
    // Climb one chain at a time; each new piece lies above what is already
    // accumulated, so it goes on the left of the merge.
    while (top[a] != top[lca]) {
        acc = mrg(get(pos[top[a]], pos[a]), acc);
        a = par[top[a]];
    }
    if (!(depth[a] < depth[lca]))
        acc = mrg(get(pos[lca], pos[a]), acc);
    // Slot 1 is the bottom-up reading, slot 0 the top-down reading.
    return direct ? acc[1] : acc[0];
}

int query(int a, int b) {
    int lca = lcaHLD(a,b);
    node atoLCA = queryWithLCA(a, lca, true);
    upd(pos[lca], node());
    node LCAtob = queryWithLCA(b, lca, false);
    upd(pos[lca], c[lca]);
    node res = mrg(atoLCA, LCAtob);
    return res.ans;
}

// Reads the tree and the operations from stdin and answers them on stdout.
// Operation format: "k a b"
//   k == 0 : set the value of vertex a to b
//   k == 1 : print query(a, b)
//   k == 2 : print query(b, a)
//   k == 3 : print the larger of query(a, b) and query(b, a)
int main() {
    SaveTime;
    int q;
    std::cin >> n >> q;
    for (int v = 1; v <= n; ++v) {
        std::cin >> c[v];
    }
    for (int e = 1; e < n; ++e) {
        int u, v;
        std::cin >> u >> v;
        tree[u].push_back(v);
        tree[v].push_back(u);
    }

    // Root the tree at vertex 1; the root is its own parent.
    depth[1] = 0;
    par[1] = 1;
    dfs(1);
    decompose(1, 1);

    // Load the initial values into the segment tree.
    for (int v = 1; v <= n; ++v) {
        upd(pos[v], node(c[v]));
    }

    while (q--) {
        int k, a, b;
        std::cin >> k >> a >> b;
        switch (k) {
            case 0:
                upd(pos[a], node(b));
                c[a] = b;
                break;
            case 1:
                std::cout << query(a, b) << '\n';
                break;
            case 2:
                std::cout << query(b, a) << '\n';
                break;
            case 3:
                std::cout << std::max(query(a, b), query(b, a)) << '\n';
                break;
            default:
                break;
        }
    }
}

Problem :
https://codeforces.com/gym/567342/problem/A

⚠️ **GitHub.com Fallback** ⚠️