程序设计与计算思维¶

Computer Programming and Computational Thinking

第 8 讲:接缝裁剪、数据结构与 SVD¶

2025—2026学年度春季学期

清华大学 韩文弢

准备工作¶

In [1]:
from urllib.request import urlretrieve

import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.linalg import svd
from ipywidgets import interact, IntSlider

plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC']

第一部分:接缝裁剪(Seam Carving)算法¶

问题引入¶

问题:希望缩小图像的某个维度,同时保持主要物体的大小和比例不变。

降采样的问题:会导致图像比例失调。

如何解决?

接缝裁剪的思想¶

接缝裁剪(Seam Carving) 的思想是通过移除图像中"最不重要"的部分来缩小图像,但不改变图像中物体的大小,也就是要移除图像中的"空白区域"。

核心思想:

  • 试图找到一条接缝(seam),即从图像顶部到底部的连通像素路径,按照某种度量标准,这些像素是"最不重要"的。
  • 然后移除该接缝中的像素,得到一幅宽度减少一个像素的图像。
In [2]:
img_path, _ = urlretrieve('https://pacman.cs.tsinghua.edu.cn/~hanwentao/cpct/08/The-Persistence-of-Memory-salvador-deli-painting.jpg')
img = plt.imread(img_path)
plt.imshow(img)
plt.title('原始图像')
plt.axis('off')
plt.show()
No description has been provided for this image

能量计算¶

需要指定像素重要性的概念。接缝将累加接缝上像素的重要性,并选择使总重要性最小化的接缝。

可以将重要性定义为"像素位于边缘的程度"——即能量。使用 Sobel 边缘检测滤波器:

$$G_x = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix} \star A$$

$$G_y = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix} \star A$$

$$G_{total} = \sqrt{G_x^2 + G_y^2}$$

In [3]:
# Sobel 算子用于边缘检测
# Sx: 检测水平方向边缘(左右亮度变化)
Sx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=np.float64)
# Sy: 检测垂直方向边缘(上下亮度变化)
Sy = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=np.float64)
print("Sx (水平方向):\n", Sx)
print("\nSy (垂直方向):\n", Sy)
Sx (水平方向):
 [[ 1.  0. -1.]
 [ 2.  0. -2.]
 [ 1.  0. -1.]]

Sy (垂直方向):
 [[ 1.  2.  1.]
 [ 0.  0.  0.]
 [-1. -2. -1.]]
In [4]:
def brightness(rgb):
    """将 RGB 图像转换为灰度亮度图
    
    使用加权平均(符合人眼对不同颜色的敏感度):
    绿色权重最大(59%),红色次之(30%),蓝色最小(11%)
    """
    return 0.3 * rgb[:,:,0] + 0.59 * rgb[:,:,1] + 0.11 * rgb[:,:,2]

def edgeness(img):
    """计算图像的能量图(边缘强度)
    
    能量定义为梯度幅值,用于衡量像素的"重要性"
    边缘区域能量高(重要),平坦区域能量低(不重要)
    """
    b = brightness(img)  # 转换为灰度图
    grad_x = ndimage.convolve(b, Sx, mode='nearest')  # 水平梯度
    grad_y = ndimage.convolve(b, Sy, mode='nearest')  # 垂直梯度
    return np.sqrt(grad_x**2 + grad_y**2)  # 梯度幅值
In [5]:
edge_result = edgeness(img)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(img)
axes[0].set_title('原始图像')
axes[0].axis('off')
axes[1].imshow(edge_result, cmap='gray')
axes[1].set_title('能量图')
axes[1].axis('off')
plt.show()
No description has been provided for this image

动态规划找最小能量路径¶

在向下走的每一步,路径可以向西南、南或东南方向移动。

可以使用动态规划计算从每个像素到底部的最小累积能量。

In [6]:
def least_edgy(E):
    """使用动态规划计算从每个像素到底部的最小能量路径
    
    参数:
        E: 能量图 (m×n 矩阵)
    
    返回:
        least_E: 最小累积能量图,每个像素存储从该像素到底部的最小累积能量
        dirs: 方向图,记录每一步应该向哪个方向走(-1:左, 0:中, 1:右)
    """
    m, n = E.shape
    least_E = np.zeros((m, n))  # 存储最小累积能量
    dirs = np.zeros((m, n), dtype=np.int32)  # 存储方向信息
    
    # 最后一行:累积能量就是自身能量
    least_E[-1, :] = E[-1, :]
    
    # 从倒数第二行开始,自底向上计算
    for i in range(m - 2, -1, -1):
        for j in range(n):
            # 考虑下一行相邻的三个像素(左下、正下、右下)
            j1 = max(0, j - 1)  # 左边界
            j2 = min(j + 2, n)  # 右边界
            candidates = least_E[i + 1, j1:j2]  # 候选像素的累积能量
            min_idx = np.argmin(candidates)  # 找最小值的索引
            # 当前像素的累积能量 = 自身能量 + 下方最小累积能量
            least_E[i, j] = candidates[min_idx] + E[i, j]
            # 记录方向(-1表示左下,0表示正下,1表示右下)
            dirs[i, j] = min_idx - 1
    return least_E, dirs
In [7]:
least_e, dirs = least_edgy(edgeness(img))
plt.imshow(least_e)
plt.title('最小累积能量图')
plt.axis('off')
plt.show()
No description has been provided for this image

移除接缝¶

In [8]:
def get_seam_at(dirs, j):
    """根据方向图回溯,找到从顶部第 j 列开始的接缝路径
    
    参数:
        dirs: 方向图(由 least_edgy 返回)
        j: 起始列位置
    
    返回:
        接缝路径:[(行, 列), ...] 坐标列表
    """
    m, n = dirs.shape
    js = np.zeros(m, dtype=np.int32)  # 存储每一行的列位置
    js[0] = j  # 从第一行的第 j 列开始
    
    # 根据方向图,从上到下追踪路径
    for i in range(1, m):
        # 下一行的列位置 = 当前列位置 + 方向偏移量
        js[i] = max(0, min(js[i-1] + dirs[i-1, js[i-1]], n - 1))
    
    return list(zip(range(m), js))

def mark_path(img, path, color=(255, 0, 255)):
    """在图像上标记接缝路径(用于可视化)
    
    参数:
        img: 原始图像
        path: 接缝路径
        color: 标记颜色(默认为洋红色)
    """
    img_marked = img.copy()
    n = img.shape[1]
    for i, j in path:
        # 标记接缝像素及其左右邻居,使路径更明显
        for j_p in range(max(0, j - 1), min(j + 2, n)):
            img_marked[i, j_p] = color
    return img_marked

def rm_path(img, path):
    """从图像中移除接缝路径
    
    参数:
        img: 原始图像
        path: 要移除的接缝路径
    
    返回:
        宽度减少 1 像素的新图像
    """
    # 创建比原图窄 1 像素的新图像
    if len(img.shape) == 3:
        img_new = np.zeros((img.shape[0], img.shape[1] - 1, img.shape[2]), dtype=img.dtype)
    else:
        img_new = np.zeros((img.shape[0], img.shape[1] - 1), dtype=img.dtype)
    
    # 对每一行,移除接缝上的像素
    for i, j in path:
        img_new[i, :j] = img[i, :j]      # 复制接缝左侧的像素
        img_new[i, j:] = img[i, j+1:]    # 复制接缝右侧的像素
    return img_new
In [9]:
def show_seam(start_column):
    path = get_seam_at(dirs, start_column)
    plt.imshow(mark_path(img, path))
    plt.title(f'从列 {start_column} 开始的接缝')
    plt.axis('off')
    plt.show()

show_seam(img.shape[1] // 2)
No description has been provided for this image
In [10]:
interact(show_seam, start_column=IntSlider(min=0, max=img.shape[1]-1, value=0));
interactive(children=(IntSlider(value=0, description='start_column', max=787), Output()), _dom_classes=('widge…
In [11]:
def shrink_n(img, n):
    """移除 n 条接缝,缩小图像宽度
    
    接缝裁剪的核心流程:
    1. 计算能量图
    2. 动态规划找最小能量路径
    3. 移除该路径
    4. 重复 n 次
    
    参数:
        img: 原始图像
        n: 要移除的接缝数量
    
    返回:
        每一步结果的图像列表(用于动画演示)
    """
    current_img = img.copy()
    e = edgeness(current_img)  # 计算初始能量图
    imgs = [current_img.copy()]  # 保存每一步的结果
    
    for i in range(n):
        # 1. 动态规划计算最小能量路径
        least_E, dirs = least_edgy(e)
        
        # 2. 找到能量最小的起始位置
        min_j = np.argmin(least_E[0, :])
        
        # 3. 回溯得到接缝路径
        seam = get_seam_at(dirs, min_j)
        
        # 4. 从图像和能量图中移除接缝
        current_img = rm_path(current_img, seam)
        e = rm_path(e, seam)
        
        imgs.append(current_img.copy())
        print(f'\r已移除 {i+1}/{n} 个接缝', end='')
    
    return imgs

n_seams = min(150, img.shape[1] // 2)
carved = shrink_n(img, n_seams)
已移除 150/150 个接缝
In [12]:
def show_carved_result(n):
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].imshow(img)
    axes[0].set_title(f'原始图像 ({img.shape[1]}×{img.shape[0]})')
    axes[0].axis('off')
    axes[1].imshow(carved[n])
    axes[1].set_title(f'移除 {n} 个接缝后 ({carved[n].shape[1]}×{carved[n].shape[0]})')
    axes[1].axis('off')
    plt.show()

show_carved_result(len(carved)-1)
No description has been provided for this image
In [13]:
interact(show_carved_result, n=IntSlider(min=0, max=len(carved)-1, value=0));
interactive(children=(IntSlider(value=0, description='n', max=150), Output()), _dom_classes=('widget-interact'…

接缝裁剪总结¶

  1. 计算能量图:使用 Sobel 算子检测边缘
  2. 动态规划:从底部向上计算最小累积能量
  3. 找接缝:找到能量最小的路径
  4. 移除接缝:图像宽度减 1
  5. 重复:直到达到目标宽度

第二部分:特殊形态向量和矩阵¶

One-hot 向量(独热向量)¶

一个 one-hot 向量有一个单独的"热"元素,即在一片零中有一个单独的 1。

In [14]:
my_one_hot_vector = np.array([0, 1, 0, 0, 0, 0])
print(my_one_hot_vector)
[0 1 0 0 0 0]

你需要多少"信息"来表示一个 one-hot 向量?是 $n$ 个数字,还是两个(长度和热位置)?

Python 自定义类¶

可以在 Python 中创建自己的新类型。

In [15]:
class OneHot:
    """简易的独热向量"""
    
    def __init__(self, n, k):
        self.n = n
        self.k = k
    
    def __len__(self):
        return self.n
    
    def __getitem__(self, i):
        return int(self.k == i)
    
    def __repr__(self):
        return f"OneHot(n={self.n}, k={self.k})"
    
    def to_array(self):
        arr = np.zeros(self.n, dtype=int)
        arr[self.k] = 1
        return arr
In [16]:
my_one_hot = OneHot(6, 2)
print(f"内部表示: {my_one_hot}")
print(f"向量形式: {my_one_hot.to_array()}")
内部表示: OneHot(n=6, k=2)
向量形式: [0 0 1 0 0 0]

查看对象内部¶

__dict__ 显示存储在对象内部的属性:

In [17]:
print("存储的属性:")
for key, value in my_one_hot.__dict__.items():
    print(f"  {key}: {value}")
存储的属性:
  n: 6
  k: 2

对角矩阵¶

In [18]:
D = np.diag([5, 6, -10])
print(D)
[[  5   0   0]
 [  0   6   0]
 [  0   0 -10]]
In [19]:
print(f"密集矩阵存储: {D.nbytes} bytes")
print(f"对角线元素存储: {np.array([5, 6, -10]).nbytes} bytes")
密集矩阵存储: 72 bytes
对角线元素存储: 24 bytes

稀疏矩阵¶

一个稀疏矩阵是有很多零的矩阵。

In [20]:
from scipy.sparse import csr_matrix

dense_M = np.array([[0, 0, 9], [0, 0, 0], [12, 0, 4]])
M_sparse = csr_matrix(dense_M)
print("稀疏矩阵:")
print(M_sparse)
稀疏矩阵:
<Compressed Sparse Row sparse matrix of dtype 'int64'
	with 3 stored elements and shape (3, 3)>
  Coords	Values
  (0, 2)	9
  (2, 0)	12
  (2, 2)	4
In [21]:
print("CSR 格式内部存储:")
print(f"  data: {M_sparse.data}")
print(f"  indices: {M_sparse.indices}")
print(f"  indptr: {M_sparse.indptr}")
CSR 格式内部存储:
  data: [ 9 12  4]
  indices: [2 0 2]
  indptr: [0 1 1 3]

乘法表(外积)¶

In [22]:
mult_table = np.outer(range(1, 10), range(1, 10))
print(mult_table)
[[ 1  2  3  4  5  6  7  8  9]
 [ 2  4  6  8 10 12 14 16 18]
 [ 3  6  9 12 15 18 21 24 27]
 [ 4  8 12 16 20 24 28 32 36]
 [ 5 10 15 20 25 30 35 40 45]
 [ 6 12 18 24 30 36 42 48 54]
 [ 7 14 21 28 35 42 49 56 63]
 [ 8 16 24 32 40 48 56 64 72]
 [ 9 18 27 36 45 54 63 72 81]]
In [23]:
def show_multiplication_table(k):
    table = np.outer(range(1, k+1), range(1, k+1))
    plt.imshow(table, cmap='Reds')
    plt.xticks(range(k), range(1, k+1))
    plt.yticks(range(k), range(1, k+1))
    plt.title(f'{k}×{k} 乘法表')

show_multiplication_table(9)
No description has been provided for this image

乘法表只需要两个向量就能表示,而不是 $n^2$ 个元素。

分解乘法表¶

给定矩阵,能否找到它是哪两个向量的外积?

In [24]:
def factor(mult_table):
    mult_table = np.array(mult_table)
    v = mult_table[:, 0]
    w = mult_table[0, :]
    if v[0] != 0:
        w = w / v[0]
    if np.allclose(np.outer(v, w), mult_table):
        return v, w
    else:
        raise ValueError("输入不是乘法表")
In [25]:
v, w = factor(np.outer([1, 2, 3], [2, 2, 2]))
print(f"v = {v}, w = {w}")
v = [2 4 6], w = [1. 1. 1.]
In [26]:
try:
    factor(np.random.rand(2, 2))
except ValueError as e:
    print(f"错误: {e}")
错误: 输入不是乘法表

如果矩阵是两个或更多的乘法表的和呢?可以用奇异值分解(SVD)。


第三部分:奇异值分解(SVD)与图像压缩¶

SVD:发现矩阵的结构¶

SVD 将矩阵分解为三个矩阵的乘积:$A = U \Sigma V^T$

可以看作是将矩阵分解为多个外积(乘法表)的和。

In [27]:
np.random.seed(42)
A = np.outer(np.random.rand(3), np.random.rand(3)) + np.outer(np.random.rand(3), np.random.rand(3))
print("两个外积的和:\n", A)
两个外积的和:
 [[0.26534903 0.05963086 0.11476207]
 [1.18246876 0.16615895 0.988419  ]
 [0.86384744 0.12657835 0.69721442]]
In [28]:
U, S, Vt = svd(A)
reconstructed = np.outer(U[:, 0], Vt[0, :] * S[0]) + np.outer(U[:, 1], Vt[1, :] * S[1])
print("SVD 重建:\n", reconstructed)
print(f"\n与原矩阵的差异: {np.max(np.abs(A - reconstructed)):.10f}")
SVD 重建:
 [[0.26534903 0.05963086 0.11476207]
 [1.18246876 0.16615895 0.988419  ]
 [0.86384744 0.12657835 0.69721442]]

与原矩阵的差异: 0.0000000000

SVD 图像压缩¶

可以使用前 $k$ 个奇异值近似图像。

In [29]:
image_path, _ = urlretrieve('https://pacman.cs.tsinghua.edu.cn/~hanwentao/cpct/08/tree.png')
image = plt.imread(image_path)[:, :, :3]
plt.imshow(image)
plt.title('原始图像')
plt.axis('off')
Out[29]:
(np.float64(-0.5), np.float64(357.5), np.float64(199.5), np.float64(-0.5))
No description has been provided for this image
In [30]:
pr, pg, pb = image[:, :, 0], image[:, :, 1], image[:, :, 2]
Ur, Sr, Vtr = svd(pr, full_matrices=False)
Ug, Sg, Vtg = svd(pg, full_matrices=False)
Ub, Sb, Vtb = svd(pb, full_matrices=False)
print(f"奇异值数量: {len(Sr)}")
奇异值数量: 200
In [31]:
def show_svd_approximation(n):
    pr_approx = sum(np.outer(Ur[:, i], Vtr[i, :]) * Sr[i] for i in range(min(n, len(Sr))))
    pg_approx = sum(np.outer(Ug[:, i], Vtg[i, :]) * Sg[i] for i in range(min(n, len(Sg))))
    pb_approx = sum(np.outer(Ub[:, i], Vtb[i, :]) * Sb[i] for i in range(min(n, len(Sb))))
    img_approx = np.clip(np.stack([pr_approx, pg_approx, pb_approx], axis=2), 0, 1)
    
    original_size = image.size
    compressed_size = n * (Ur.shape[0] + Vtr.shape[1] + 1) * 3
    ratio = original_size / compressed_size
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].imshow(image)
    axes[0].set_title(f'原始图像 ({original_size:,} 像素)')
    axes[0].axis('off')
    axes[1].imshow(img_approx)
    axes[1].set_title(f'SVD 近似 (n={n}, 压缩比 {ratio:.1f}x)')
    axes[1].axis('off')
    plt.show()

show_svd_approximation(50)
No description has been provided for this image
In [32]:
interact(show_svd_approximation, n=IntSlider(min=1, max=200, value=50));
interactive(children=(IntSlider(value=50, description='n', max=200, min=1), Output()), _dom_classes=('widget-i…

本讲小结¶

本讲探讨了数据中的结构以及如何利用它:

  1. 接缝裁剪:通过动态规划找到最小能量路径,智能缩放图像
  2. 特殊形态向量和矩阵:one-hot 向量、对角矩阵、稀疏矩阵、乘法表
  3. Python 对象:使用 class 定义自定义数据结构
  4. SVD:发现矩阵隐藏结构,用于图像压缩

学到的语法¶

  • class:定义自定义数据结构
  • __dict__:查看对象内部属性
  • raise:抛出异常
  • np.diag:创建对角矩阵
  • scipy.sparse:稀疏矩阵
  • scipy.linalg.svd:奇异值分解