使用 JavaScript 对图像进行量化并提取主要颜色

723

前言

前段时间在 Halo应用市场 中遇到希望主题和插件的封面图背景色为封面图主色的问题,于是乎需要根据封面图提取主色就想到使用 K-Means 算法来提取。

在图像处理中,图像是由像素点构成的,每个像素点都有一个颜色值,颜色值通常由 RGB 三个分量组成。因此,我们可以将图像看作是一个由颜色值构成的点云,每个点代表一个像素点。

为了更好地理解,我们可以将图像的颜色值可视化为一个 Scatter 3D 图。在 Scatter 3D 图中,每个点的坐标由 RGB 三个分量组成,点的颜色与其坐标对应的颜色值相同。

图像的颜色值量化

以下面的图片为例

k-means-example-whzc.jpg

它的色值分布为如下的图像
k-means-figure_1.png
从上述 RGB 3D Scatter Plot 图如果将相似的颜色值归为一类可以看出图像大概有三种主色调蓝色、绿色和粉色:
k-means-figure_2.jpg
如果我们从三簇中各选一个中心,如以 A、B、C三点表示 A(50, 150, 200)B(240, 150, 200)C(50, 100, 50) 并将每个数据点分配到最近的中心所在的簇中这个过程称之为聚类而这个中心称之为聚类中心,这样就可以得到 K 个以聚类中心为坐标的主色值。而 K-Means 算法是一种常用的聚类算法,它的基本思想就是将数据集分成 K 个簇,每个簇的中心点称为聚类中心,将每个数据点分配到最近的聚类中心所在的簇中。

K-Means 算法的实现过程如下:

  1. 初始化聚类中心:随机选择 K 个点作为聚类中心。

  2. 分配数据点到最近的聚类中心所在的簇中:对于每个数据点,计算它与每个聚类中心的距离,将它分配到距离最近的聚类中心所在的簇中。

  3. 更新聚类中心:对于每个簇,计算它的所有数据点的平均值,将这个平均值作为新的聚类中心。

  4. 重复步骤 2 和步骤 3,直到聚类中心不再改变或达到最大迭代次数。

在图像处理中,我们可以将每个像素点的颜色值看作是一个三维向量,使用欧几里得距离计算两个颜色值之间的距离。对于每个像素点,我们将它分配到距离最近的聚类中心所在的簇中,然后将它的颜色值替换为所在簇的聚类中心的颜色值,如 A1(10, 140, 170) 以距离它最近的距离中心 A 的坐标表示即 A1 = A(50, 150, 200)。这样,我们就可以将图像中的颜色值进行量化,将相似的颜色值归为一类。

最后,我们可以根据聚类中心的颜色值,计算每个颜色值在图像中出现的次数,并按出现次数从大到小排序,取前几个颜色作为主要颜色。

<script>
  const img = new Image();
  img.src = "https://guqing-blog.oss-cn-hangzhou.aliyuncs.com/image.jpg";
  img.setAttribute("crossOrigin", "");

  img.onload = function () {
    const canvas = document.createElement("canvas");
    const ctx = canvas.getContext("2d");
    canvas.width = img.width;
    canvas.height = img.height;
    ctx.drawImage(img, 0, 0);

    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    const data = imageData.data;

    const k = 3; // 聚类数
    const centers = quantize(data, k);
    console.log(centers)

    for (const color of centers) {
      const div = document.createElement("div");
      div.style.width = "50px";
      div.style.height = "50px";
      div.style.backgroundColor = color;
      document.body.appendChild(div);
    }
  };

  function quantize(data, k) {
    // 将颜色值转换为三维向量
    const vectors = [];
    for (let i = 0; i < data.length; i += 4) {
      vectors.push([data[i], data[i + 1], data[i + 2]]);
    }

    // 随机选择 K 个聚类中心
    const centers = [];
    for (let i = 0; i < k; i++) {
      centers.push(vectors[Math.floor(Math.random() * vectors.length)]);
    }

    // 迭代更新聚类中心
    let iterations = 0;
    while (iterations < 100) {
      // 分配数据点到最近的聚类中心所在的簇中
      const clusters = new Array(k).fill().map(() => []);
      for (let i = 0; i < vectors.length; i++) {
        let minDist = Infinity;
        let minIndex = 0;
        for (let j = 0; j < centers.length; j++) {
          const dist = distance(vectors[i], centers[j]);
          if (dist < minDist) {
            minDist = dist;
            minIndex = j;
          }
        }
        clusters[minIndex].push(vectors[i]);
      }

      // 更新聚类中心
      let converged = true;
      for (let i = 0; i < centers.length; i++) {
        const cluster = clusters[i];
        if (cluster.length > 0) {
          const newCenter = cluster
            .reduce((acc, cur) => [
              acc[0] + cur[0],
              acc[1] + cur[1],
              acc[2] + cur[2],
            ])
            .map((val) => val / cluster.length);
          if (!equal(centers[i], newCenter)) {
            centers[i] = newCenter;
            converged = false;
          }
        }
      }

      if (converged) {
        break;
      }

      iterations++;
    }

    // 将每个像素点的颜色值替换为所在簇的聚类中心的颜色值
    for (let i = 0; i < data.length; i += 4) {
      const vector = [data[i], data[i + 1], data[i + 2]];
      let minDist = Infinity;
      let minIndex = 0;
      for (let j = 0; j < centers.length; j++) {
        const dist = distance(vector, centers[j]);
        if (dist < minDist) {
          minDist = dist;
          minIndex = j;
        }
      }
      const center = centers[minIndex];
      data[i] = center[0];
      data[i + 1] = center[1];
      data[i + 2] = center[2];
    }

    // 计算每个颜色值在图像中出现的次数,并按出现次数从大到小排序
    const counts = {};
    for (let i = 0; i < data.length; i += 4) {
      const color = `rgb(${data[i]}, ${data[i + 1]}, ${data[i + 2]})`;
      counts[color] = counts[color] ? counts[color] + 1 : 1;
    }
    const sortedColors = Object.keys(counts).sort(
      (a, b) => counts[b] - counts[a]
    );

    // 取前 k 个颜色作为主要颜色
    return sortedColors.slice(0, k);
  }

  function distance(a, b) {
    return Math.sqrt(
      (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2
    );
  }

  function equal(a, b) {
    return a[0] === b[0] && a[1] === b[1] && a[2] === b[2];
  }
</script>

自动选取 K 值

在实际应用中,我们可能不知道应该选择多少个聚类中心,即 K 值。一种常用的方法是使用 Gap 统计量法,它的基本思想是比较聚类结果与随机数据集的聚类结果之间的差异,选择使差异最大的 K 值。

Gap 统计量法的实现过程如下:

  1. 对原始数据集进行 K-Means 聚类,得到聚类结果。

  2. 生成 B 个随机数据集,对每个随机数据集进行 K-Means 聚类,得到聚类结果。

  3. 计算聚类结果与随机数据集聚类结果之间的差异,使用 Gap 统计量表示。

  4. 选择使 Gap 统计量最大的 K 值。

下面是使用 JavaScript 实现 Gap 统计量法的示例代码:

function gap(data, maxK) {
  const gaps = [];
  for (let k = 1; k <= maxK; k++) {
    const quantized = quantize(data, k);
    const gap = logWk(quantized) - logWk(randomData(data.length));
    gaps.push(gap);
  }
  const maxGap = Math.max(...gaps);
  return gaps.findIndex((gap) => gap === maxGap) + 1;
}

function logWk(quantized) {
  const counts = {};
  for (let i = 0; i < quantized.length; i++) {
    counts[quantized[i]] = counts[quantized[i]] ? counts[quantized[i]] + 1 : 1;
  }
  const n = quantized.length;
  const k = Object.keys(counts).length;
  const wk = Object.values(counts).reduce((acc, cur) => acc + cur * Math.log(cur / n), 0);
  return Math.log(n) + wk / n;
}

function randomData(n) {
  const data = new Uint8ClampedArray(n * 4);
  for (let i = 0; i < data.length; i++) {
    data[i] = Math.floor(Math.random() * 256);
  }
  return data;
}

使用:

const k = gap(data, 10)
// const k = 3; // 聚类数
const centers = quantize(data, k);

好吧,挺麻烦的,最终直接将封面图再作为背景图添加 backdrop-filter 来实现了 🤐。

附录

Python 绘制图片 Scatter 3D:

import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.mplot3d import Axes3D

def visualize_rgb(image_path):
    image = plt.imread(image_path)
    height, width, _ = image.shape

    # Reshape the image array to a 2D array of pixels
    pixels = np.reshape(image, (height * width, 3))

    # Extract RGB values
    red = pixels[:, 0]
    green = pixels[:, 1]
    blue = pixels[:, 2]

    # Create 3D scatter plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(red, green, blue, c=pixels/255, alpha=0.3)

    ax.set_xlabel('Red')
    ax.set_ylabel('Green')
    ax.set_zlabel('Blue')
    ax.set_title('RGB 3D Scatter Plot')

    plt.show()

# 调用函数并传入图像路径
visualize_rgb('image.jpg')