<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="utf-8">
  <title>WebGPU subgroups sample</title>
</head>

<body>
  <h1>WebGPU subgroups sample</h1>
  <!-- A <p> cannot contain a <ul> or <pre>; the intro is its own paragraph. -->
  <p>This sample shows:</p>
  <ul>
    <li>Checking whether the adapter supports the <code>subgroups</code> feature.</li>
    <li>Requesting a device with the <code>subgroups</code> feature enabled.</li>
    <li>Inspecting <code>adapter.info.subgroupMinSize</code> and
      <code>adapter.info.subgroupMaxSize</code>.</li>
    <li>A shader that uses the <code>subgroupExclusiveMul()</code> built-in function to compute factorials
      of the subgroup invocation ID, without reading or writing memory to communicate intermediate results.</li>
  </ul>
  <hr>
  <!-- Filled at runtime by the script's log() helper. -->
  <pre id="logs"></pre>
</body>
</html>
/* Page styles for the WebGPU subgroups sample. */

/* Base typography and page margin. */
body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

/* Remove the space below the page heading. */
h1 {
  margin-block-end: 0;
}
/* Render links as blocks with spacing below them.
   NOTE(review): this pen's markup contains no <a> elements, so this rule
   currently has no effect. */
a {
  display: block;
  margin-block-end: 1.4em;
}

/* Add space above the runtime log output area. */
#logs {
  margin-top: 1em;
}

/* On screens at least 640px wide, center the body horizontally and cap its
   width at 640px minus 2em. */
@media screen and (min-width: 640px) {
  body {
    margin: 2em auto;
    max-width: calc(640px - 2em);
  }
}
/**
 * Entry point.
 *
 * Steps:
 * 1. Checks for WebGPU support and the "subgroups" adapter feature.
 * 2. Logs the adapter's subgroup size range.
 * 3. Runs a compute shader that uses subgroupExclusiveMul() to compute, per
 *    invocation, the factorial of its subgroup_invocation_id.
 * 4. Reads the results back and logs them grouped by workgroup and subgroup.
 *
 * Any exception thrown (or promise rejected) along the way is reported via
 * log() instead of becoming an unhandled promise rejection.
 */
(async function () {
  try {
    if (!navigator.gpu) {
      log("This browser does not support WebGPU.");
      return;
    }

    // requestAdapter() resolves to null (it does not reject) when no suitable
    // adapter is available, so check before touching adapter.features.
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      log("No WebGPU adapter is available.");
      return;
    }
    if (!adapter.features.has("subgroups")) {
      log('The "subgroups" feature is not supported.');
      return;
    }

    log(`- adapter.info.vendor = '${adapter.info.vendor}'`);
    log(`- adapter.info.architecture = '${adapter.info.architecture}'`);
    log(`- adapter.info.subgroupMinSize = ${adapter.info.subgroupMinSize}`);
    log(`- adapter.info.subgroupMaxSize = ${adapter.info.subgroupMaxSize}`);

    const device = await adapter.requestDevice({
      requiredFeatures: ["subgroups"]
    });

    const workgroupSize = device.limits.maxComputeWorkgroupSizeX;
    log(
      `- workgroupSize = ${workgroupSize} (= device.limits.maxComputeWorkgroupSizeX)`
    );

    const numWorkgroups = 2; // Show at least 2 workgroups.
    const N = numWorkgroups * workgroupSize;

    const bufSize = N * Float32Array.BYTES_PER_ELEMENT; // 32-bit floats take up 4 bytes each
    const gpuBuffer = device.createBuffer({
      size: bufSize,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
    });
    // Mappable staging buffer that the shader output is copied into for readback.
    const outputBuffer = device.createBuffer({
      size: bufSize,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
    });

    const subgroupSizeBufferSize = Uint32Array.BYTES_PER_ELEMENT; // 32-bit unsigned integer takes up 4 bytes
    const subgroupSizeBuffer = device.createBuffer({
      size: subgroupSizeBufferSize,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
    });
    const subgroupSizeOutputBuffer = device.createBuffer({
      size: subgroupSizeBufferSize,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
    });

    // Here's the compute shader.
    const shader = device.createShaderModule({
      code: `
    enable subgroups;

    @group(0) @binding(0) var<storage,read_write> factorials_buffer: array<f32>;
    @group(0) @binding(1) var<storage,read_write> subgroup_size: u32;

    override workgroupSize: u32;

    @compute @workgroup_size(workgroupSize)
    fn simple_factorial(
        @builtin(global_invocation_id) gid: vec3u,
        @builtin(subgroup_invocation_id) sid: u32,
        @builtin(subgroup_size) ssize: u32) {

      // Each invocation contributes 1 + its subgroup ID, so:
      //   invocation ID 0 contributes 1
      //   invocation ID 1 contributes 2
      //   invocation ID 2 contributes 3
      //   invocation ID 3 contributes 4
      //   invocation ID 4 contributes 5
      //
      // Calculate the prefix product, excluding its own element:
      //   ID 0:    1                = 1    == 0!
      //   ID 1:    1                = 1    == 1!
      //   ID 2:    1 * 2            = 2    == 2!
      //   ID 3:    1 * 2 * 3        = 6    == 3!
      //   ID 4:    1 * 2 * 3 * 4    = 24   == 4!
      //      ...
      // This is all done without reading or writing memory.
      let result = subgroupExclusiveMul(f32(1 + sid));

      // Write to the output.
      factorials_buffer[gid.x] = result;
      if gid.x == 0 { // avoid a data race.
        subgroup_size = ssize;
      }
    }
  `
    });

    // Do the factorial.
    const pipeline = device.createComputePipeline({
      layout: "auto",
      compute: { module: shader, constants: { workgroupSize } }
    });

    const bindGroup = device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: gpuBuffer } },
        { binding: 1, resource: { buffer: subgroupSizeBuffer } }
      ]
    });

    const encoder = device.createCommandEncoder();
    const computeEncoder = encoder.beginComputePass();
    computeEncoder.setPipeline(pipeline);
    computeEncoder.setBindGroup(0, bindGroup);
    computeEncoder.dispatchWorkgroups(numWorkgroups);
    computeEncoder.end();
    encoder.copyBufferToBuffer(gpuBuffer, 0, outputBuffer, 0, bufSize);
    encoder.copyBufferToBuffer(
      subgroupSizeBuffer,
      0,
      subgroupSizeOutputBuffer,
      0,
      subgroupSizeBufferSize
    );
    device.queue.submit([encoder.finish()]);

    // Copy the mapped data out (slice) so each buffer can be unmapped right
    // away; views from getMappedRange() are detached once unmap() is called.
    await outputBuffer.mapAsync(GPUMapMode.READ);
    const outputs = new Float32Array(outputBuffer.getMappedRange()).slice();
    outputBuffer.unmap();

    await subgroupSizeOutputBuffer.mapAsync(GPUMapMode.READ);
    const subgroupSize = new Uint32Array(
      subgroupSizeOutputBuffer.getMappedRange()
    )[0];
    subgroupSizeOutputBuffer.unmap();

    logWorkgroups(outputs, subgroupSize, N, workgroupSize);
  } catch (err) {
    // Surface failures (e.g. a rejected requestDevice or mapAsync) on the page.
    log(`Error: ${err && err.message ? err.message : err}`);
  }
})();

/* Utils */
/**
 * Appends one line to the #logs element.
 *
 * `s` is inserted as HTML, because callers pass markup such as "<hr>" and
 * "<br>". A line break is appended after it. insertAdjacentHTML appends the
 * new fragment without re-serializing and re-parsing everything already in
 * the log, which `innerHTML +=` does on every call (quadratic over a run).
 *
 * @param {string} s HTML fragment to append. It is not escaped, so it must
 *     not contain untrusted text.
 */
function log(s) {
  document.getElementById("logs").insertAdjacentHTML("beforeend", s + "<br/>");
}

/**
 * Logs the shader output grouped by workgroup, with one subgroup per line.
 *
 * @param {ArrayLike<number>} o  Per-invocation outputs, indexed by global invocation ID.
 * @param {number} subgroupSize  Subgroup size reported by the shader.
 * @param {number} N             Number of entries of `o` to print.
 * @param {number} workgroupSize Invocations per workgroup.
 *
 * Subgroup lines restart at every workgroup boundary, because a subgroup
 * never spans workgroups. If `subgroupSize` is not a positive number (e.g.
 * the shader never wrote it), each workgroup is printed on a single line.
 * NOTE(review): the grouping assumes consecutive invocation IDs form a
 * subgroup; the actual mapping is implementation-defined.
 */
function logWorkgroups(o, subgroupSize, N, workgroupSize) {
  log(`<hr>Ran shader and got subgroup size = ${subgroupSize}`);
  log(`The output, grouped into workgroups, and one subgroup per line is:`);
  const lineLength = subgroupSize > 0 ? subgroupSize : workgroupSize;
  let line = "";
  for (let i = 0; i < N; i++) {
    const local = i % workgroupSize; // index within the current workgroup
    if (local == 0) {
      line += `<br>workgroup ${Math.floor(i / workgroupSize)}:`;
    }
    if (local % lineLength == 0) {
      line += `<br>${String(i).padStart(3)}:  `;
    }
    line += ` ${o[i]}`;
  }
  log(line);
}

External CSS

This Pen doesn't use any external CSS resources.

External JavaScript

This Pen doesn't use any external JavaScript resources.