实现 Fruchterman 布局算法 - antvis/g-webgl-compute GitHub Wiki

问题背景

G6 DEMO

尽管节点数目并不多(本例中为 200+),但每个节点需要和其他节点进行大量计算,并且需要进行一定次数的迭代才能达到稳定(本例中为 8000 次)。 因此性能主要在计算而非渲染上。

在本文中,我们将尝试把这部分节点位置计算放在 GPU 侧,渲染仍使用 G6 原有的 Canvas 2D/SVG 技术。

image

创建计算管线

这里有三点需要说明:

  • 我们希望每一个线程组负责计算一个节点的位置,因此我们设置一个一维的线程网格,长度等于节点数目。
  • 和之前 Reduce 和向量加法的例子不同,布局计算需要运行多次达到稳定状态,这里我们设置 8000 次。
  • 由于 Shader 中我们需要遍历所有节点和每个节点的所有边,因此涉及到循环长度,受限于 Shader 语法对于循环长度的严格限制(必须为常量或常量表达式),这里我们传入运行时计算的常量,即节点数和所有节点包含的最大边数目。
const compute = this.world.createComputePipeline({
  shader: gCode,
  dispatch: [numParticles, 1, 1], // 线程数等同节点数目
  maxIteration: MAX_ITERATION, // 8000 次迭代
  onCompleted: (finalParticleData) => {
    // 使用 g-canvas 渲染,数据中包含了最终的节点位置
  }
});

// 省略其他变量设置
this.world.setBinding(
  compute,
  'MAX_EDGE_PER_VERTEX',
  this.maxEdgePerVetex,
);
this.world.setBinding(compute, 'VERTEX_COUNT', numParticles);

数据结构设计

详见 https://github.com/antvis/GWebGPUEngine/issues/1

Shader 编写

如果不熟悉该布局算法,可以参考 G6 的原版 TS 实现

简单来说在每次迭代中,每个节点位置需要通过斥力、引力和重力计算完成更新:

下面的代码应该很容易理解:

import { globalInvocationID } from 'g-webgpu';

const SPEED_DIVISOR = 800;
const MAX_EDGE_PER_VERTEX;
const VERTEX_COUNT;

// 每个线程组仅包含一个线程
@numthreads(1, 1, 1)
class Fruchterman {
  @in @out
  u_Data: vec4[];

  @in
  u_K: float;

  @in
  u_K2: float;

  @in
  u_Gravity: float;

  @in
  u_Speed: float;

  @in
  u_MaxDisplace: float;

  calcRepulsive(i: int, currentNode: vec4): vec2 {
    let dx = 0, dy = 0;
    for (let j = 0; j < VERTEX_COUNT; j++) {
      if (i != j + 1) {
        const nextNode = this.u_Data[j];
        const xDist = currentNode[0] - nextNode[0];
        const yDist = currentNode[1] - nextNode[1];
        const dist = sqrt(xDist * xDist + yDist * yDist) + 0.01;
        if (dist > 0.0) {
          const repulsiveF = this.u_K2 / dist;
          dx += xDist / dist * repulsiveF;
          dy += yDist / dist * repulsiveF;
        }
      }
    }
    return [dx, dy];
  }

  calcGravity(currentNode: vec4): vec2 {
    const d = sqrt(currentNode[0] * currentNode[0] + currentNode[1] * currentNode[1]);
    const gf = 0.01 * this.u_K * this.u_Gravity * d;
    return [gf * currentNode[0] / d, gf * currentNode[1] / d];
  }

  calcAttractive(currentNode: vec4): vec2 {
    let dx = 0, dy = 0;
    const arr_offset = int(floor(currentNode[2] + 0.5));
    const length = int(floor(currentNode[3] + 0.5));
    const node_buffer: vec4;
    for (let p = 0; p < MAX_EDGE_PER_VERTEX; p++) {
      if (p >= length) break;
      const arr_idx = arr_offset + p;
      // when arr_idx % 4 == 0 update currentNodedx_buffer
      const buf_offset = arr_idx - arr_idx / 4 * 4;
      if (p == 0 || buf_offset == 0) {
        node_buffer = this.u_Data[int(arr_idx / 4)];
      }
      const float_j = buf_offset == 0 ? node_buffer[0] :
                      buf_offset == 1 ? node_buffer[1] :
                      buf_offset == 2 ? node_buffer[2] :
                                        node_buffer[3];
      const nextNode = this.u_Data[int(float_j)];
      const xDist = currentNode[0] - nextNode[0];
      const yDist = currentNode[1] - nextNode[1];
      const dist = sqrt(xDist * xDist + yDist * yDist) + 0.01;
      const attractiveF = dist * dist / this.u_K;
      if (dist > 0.0) {
        dx -= xDist / dist * attractiveF;
        dy -= yDist / dist * attractiveF;
      }
    }
    return [dx, dy];
  }

  @main
  compute() {
    const i = globalInvocationID.x;
    const currentNode = this.u_Data[i];
    let dx = 0, dy = 0;

    if (i > VERTEX_COUNT) {
      this.u_Data[i] = currentNode;
      return;
    }

    // repulsive
    const repulsive = this.calcRepulsive(i, currentNode);
    dx += repulsive[0];
    dy += repulsive[1];

    // attractive
    const attractive = this.calcAttractive(currentNode);
    dx += attractive[0];
    dy += attractive[1];

    // gravity
    const gravity = this.calcGravity(currentNode);
    dx -= gravity[0];
    dy -= gravity[1];

    // speed
    dx *= this.u_Speed;
    dy *= this.u_Speed;

    // move
    const distLength = sqrt(dx * dx + dy * dy);
    if (distLength > 0.0) {
      const limitedDist = min(this.u_MaxDisplace * this.u_Speed, distLength);
      
      // 设置当前节点的最终位置
      this.u_Data[i] = [
        currentNode[0] + dx / distLength * limitedDist,
        currentNode[1] + dy / distLength * limitedDist,
        currentNode[2],
        currentNode[3]
      ];
    }
  }
}

Benchmarks

以下都在 Chrome Canary 中运行

计算时间 DEMO
CPU(WebWorker) 80s DEMO
运行时编译 + 计算 WebGL 5.05s
运行时编译 + 计算 WebGPU 2.5s
预编译 + 计算 WebGL 4.3s
预编译 + 计算 WebGPU 0.71s