准备工作¶
In [1]:
from urllib.request import urlretrieve
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.linalg import svd
from ipywidgets import interact, IntSlider
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC']
第一部分:接缝裁剪(Seam Carving)算法¶
接缝裁剪的思想¶
接缝裁剪(Seam Carving) 的思想是通过移除图像中"最不重要"的部分来缩小图像,但不改变图像中物体的大小,也就是要移除图像中的"空白区域"。
核心思想:
- 试图找到一条接缝(seam),即从图像顶部到底部的连通像素路径,按照某种度量标准,这些像素是"最不重要"的。
- 然后移除该接缝中的像素,得到一幅宽度减少一个像素的图像。
In [2]:
img_path, _ = urlretrieve('https://pacman.cs.tsinghua.edu.cn/~hanwentao/cpct/08/The-Persistence-of-Memory-salvador-deli-painting.jpg')
img = plt.imread(img_path)
plt.imshow(img)
plt.title('原始图像')
plt.axis('off')
plt.show()
能量计算¶
需要指定像素重要性的概念。接缝将累加接缝上像素的重要性,并选择使总重要性最小化的接缝。
可以将重要性定义为"像素位于边缘的程度"——即能量。使用 Sobel 边缘检测滤波器:
$$G_x = \begin{bmatrix} 1 & 0 & -1 \\ 2 & 0 & -2 \\ 1 & 0 & -1 \end{bmatrix} \star A$$
$$G_y = \begin{bmatrix} 1 & 2 & 1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix} \star A$$
$$G_{total} = \sqrt{G_x^2 + G_y^2}$$
In [3]:
# Sobel 算子用于边缘检测
# Sx: 检测水平方向边缘(左右亮度变化)
Sx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=np.float64)
# Sy: 检测垂直方向边缘(上下亮度变化)
Sy = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=np.float64)
print("Sx (水平方向):\n", Sx)
print("\nSy (垂直方向):\n", Sy)
Sx (水平方向): [[ 1. 0. -1.] [ 2. 0. -2.] [ 1. 0. -1.]] Sy (垂直方向): [[ 1. 2. 1.] [ 0. 0. 0.] [-1. -2. -1.]]
In [4]:
def brightness(rgb):
"""将 RGB 图像转换为灰度亮度图
使用加权平均(符合人眼对不同颜色的敏感度):
绿色权重最大(59%),红色次之(30%),蓝色最小(11%)
"""
return 0.3 * rgb[:,:,0] + 0.59 * rgb[:,:,1] + 0.11 * rgb[:,:,2]
def edgeness(img):
"""计算图像的能量图(边缘强度)
能量定义为梯度幅值,用于衡量像素的"重要性"
边缘区域能量高(重要),平坦区域能量低(不重要)
"""
b = brightness(img) # 转换为灰度图
grad_x = ndimage.convolve(b, Sx, mode='nearest') # 水平梯度
grad_y = ndimage.convolve(b, Sy, mode='nearest') # 垂直梯度
return np.sqrt(grad_x**2 + grad_y**2) # 梯度幅值
In [5]:
edge_result = edgeness(img)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(img)
axes[0].set_title('原始图像')
axes[0].axis('off')
axes[1].imshow(edge_result, cmap='gray')
axes[1].set_title('能量图')
axes[1].axis('off')
plt.show()
In [6]:
def least_edgy(E):
"""使用动态规划计算从每个像素到底部的最小能量路径
参数:
E: 能量图 (m×n 矩阵)
返回:
least_E: 最小累积能量图,每个像素存储从该像素到底部的最小累积能量
dirs: 方向图,记录每一步应该向哪个方向走(-1:左, 0:中, 1:右)
"""
m, n = E.shape
least_E = np.zeros((m, n)) # 存储最小累积能量
dirs = np.zeros((m, n), dtype=np.int32) # 存储方向信息
# 最后一行:累积能量就是自身能量
least_E[-1, :] = E[-1, :]
# 从倒数第二行开始,自底向上计算
for i in range(m - 2, -1, -1):
for j in range(n):
# 考虑下一行相邻的三个像素(左下、正下、右下)
j1 = max(0, j - 1) # 左边界
j2 = min(j + 2, n) # 右边界
candidates = least_E[i + 1, j1:j2] # 候选像素的累积能量
min_idx = np.argmin(candidates) # 找最小值的索引
# 当前像素的累积能量 = 自身能量 + 下方最小累积能量
least_E[i, j] = candidates[min_idx] + E[i, j]
# 记录方向(-1表示左下,0表示正下,1表示右下)
dirs[i, j] = min_idx - 1
return least_E, dirs
In [7]:
least_e, dirs = least_edgy(edgeness(img))
plt.imshow(least_e)
plt.title('最小累积能量图')
plt.axis('off')
plt.show()
移除接缝¶
In [8]:
def get_seam_at(dirs, j):
"""根据方向图回溯,找到从顶部第 j 列开始的接缝路径
参数:
dirs: 方向图(由 least_edgy 返回)
j: 起始列位置
返回:
接缝路径:[(行, 列), ...] 坐标列表
"""
m, n = dirs.shape
js = np.zeros(m, dtype=np.int32) # 存储每一行的列位置
js[0] = j # 从第一行的第 j 列开始
# 根据方向图,从上到下追踪路径
for i in range(1, m):
# 下一行的列位置 = 当前列位置 + 方向偏移量
js[i] = max(0, min(js[i-1] + dirs[i-1, js[i-1]], n - 1))
return list(zip(range(m), js))
def mark_path(img, path, color=(255, 0, 255)):
"""在图像上标记接缝路径(用于可视化)
参数:
img: 原始图像
path: 接缝路径
color: 标记颜色(默认为洋红色)
"""
img_marked = img.copy()
n = img.shape[1]
for i, j in path:
# 标记接缝像素及其左右邻居,使路径更明显
for j_p in range(max(0, j - 1), min(j + 2, n)):
img_marked[i, j_p] = color
return img_marked
def rm_path(img, path):
"""从图像中移除接缝路径
参数:
img: 原始图像
path: 要移除的接缝路径
返回:
宽度减少 1 像素的新图像
"""
# 创建比原图窄 1 像素的新图像
if len(img.shape) == 3:
img_new = np.zeros((img.shape[0], img.shape[1] - 1, img.shape[2]), dtype=img.dtype)
else:
img_new = np.zeros((img.shape[0], img.shape[1] - 1), dtype=img.dtype)
# 对每一行,移除接缝上的像素
for i, j in path:
img_new[i, :j] = img[i, :j] # 复制接缝左侧的像素
img_new[i, j:] = img[i, j+1:] # 复制接缝右侧的像素
return img_new
In [9]:
def show_seam(start_column):
path = get_seam_at(dirs, start_column)
plt.imshow(mark_path(img, path))
plt.title(f'从列 {start_column} 开始的接缝')
plt.axis('off')
plt.show()
show_seam(img.shape[1] // 2)
In [10]:
interact(show_seam, start_column=IntSlider(min=0, max=img.shape[1]-1, value=0));
interactive(children=(IntSlider(value=0, description='start_column', max=787), Output()), _dom_classes=('widge…
In [11]:
def shrink_n(img, n):
"""移除 n 条接缝,缩小图像宽度
接缝裁剪的核心流程:
1. 计算能量图
2. 动态规划找最小能量路径
3. 移除该路径
4. 重复 n 次
参数:
img: 原始图像
n: 要移除的接缝数量
返回:
每一步结果的图像列表(用于动画演示)
"""
current_img = img.copy()
e = edgeness(current_img) # 计算初始能量图
imgs = [current_img.copy()] # 保存每一步的结果
for i in range(n):
# 1. 动态规划计算最小能量路径
least_E, dirs = least_edgy(e)
# 2. 找到能量最小的起始位置
min_j = np.argmin(least_E[0, :])
# 3. 回溯得到接缝路径
seam = get_seam_at(dirs, min_j)
# 4. 从图像和能量图中移除接缝
current_img = rm_path(current_img, seam)
e = rm_path(e, seam)
imgs.append(current_img.copy())
print(f'\r已移除 {i+1}/{n} 个接缝', end='')
return imgs
n_seams = min(150, img.shape[1] // 2)
carved = shrink_n(img, n_seams)
已移除 150/150 个接缝
In [12]:
def show_carved_result(n):
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(img)
axes[0].set_title(f'原始图像 ({img.shape[1]}×{img.shape[0]})')
axes[0].axis('off')
axes[1].imshow(carved[n])
axes[1].set_title(f'移除 {n} 个接缝后 ({carved[n].shape[1]}×{carved[n].shape[0]})')
axes[1].axis('off')
plt.show()
show_carved_result(len(carved)-1)
In [13]:
interact(show_carved_result, n=IntSlider(min=0, max=len(carved)-1, value=0));
interactive(children=(IntSlider(value=0, description='n', max=150), Output()), _dom_classes=('widget-interact'…
接缝裁剪总结¶
- 计算能量图:使用 Sobel 算子检测边缘
- 动态规划:从底部向上计算最小累积能量
- 找接缝:找到能量最小的路径
- 移除接缝:图像宽度减 1
- 重复:直到达到目标宽度
第二部分:特殊形态向量和矩阵¶
One-hot 向量(独热向量)¶
一个 one-hot 向量有一个单独的"热"元素,即在一片零中有一个单独的 1。
In [14]:
my_one_hot_vector = np.array([0, 1, 0, 0, 0, 0])
print(my_one_hot_vector)
[0 1 0 0 0 0]
你需要多少"信息"来表示一个 one-hot 向量?是 $n$ 个数字,还是两个(长度和热位置)?
Python 自定义类¶
可以在 Python 中创建自己的新类型。
In [15]:
class OneHot:
"""简易的独热向量"""
def __init__(self, n, k):
self.n = n
self.k = k
def __len__(self):
return self.n
def __getitem__(self, i):
return int(self.k == i)
def __repr__(self):
return f"OneHot(n={self.n}, k={self.k})"
def to_array(self):
arr = np.zeros(self.n, dtype=int)
arr[self.k] = 1
return arr
In [16]:
my_one_hot = OneHot(6, 2)
print(f"内部表示: {my_one_hot}")
print(f"向量形式: {my_one_hot.to_array()}")
内部表示: OneHot(n=6, k=2) 向量形式: [0 0 1 0 0 0]
查看对象内部¶
__dict__ 显示存储在对象内部的属性:
In [17]:
print("存储的属性:")
for key, value in my_one_hot.__dict__.items():
print(f" {key}: {value}")
存储的属性: n: 6 k: 2
对角矩阵¶
In [18]:
D = np.diag([5, 6, -10])
print(D)
[[ 5 0 0] [ 0 6 0] [ 0 0 -10]]
In [19]:
print(f"密集矩阵存储: {D.nbytes} bytes")
print(f"对角线元素存储: {np.array([5, 6, -10]).nbytes} bytes")
密集矩阵存储: 72 bytes 对角线元素存储: 24 bytes
稀疏矩阵¶
一个稀疏矩阵是有很多零的矩阵。
In [20]:
from scipy.sparse import csr_matrix
dense_M = np.array([[0, 0, 9], [0, 0, 0], [12, 0, 4]])
M_sparse = csr_matrix(dense_M)
print("稀疏矩阵:")
print(M_sparse)
稀疏矩阵: <Compressed Sparse Row sparse matrix of dtype 'int64' with 3 stored elements and shape (3, 3)> Coords Values (0, 2) 9 (2, 0) 12 (2, 2) 4
In [21]:
print("CSR 格式内部存储:")
print(f" data: {M_sparse.data}")
print(f" indices: {M_sparse.indices}")
print(f" indptr: {M_sparse.indptr}")
CSR 格式内部存储: data: [ 9 12 4] indices: [2 0 2] indptr: [0 1 1 3]
乘法表(外积)¶
In [22]:
mult_table = np.outer(range(1, 10), range(1, 10))
print(mult_table)
[[ 1 2 3 4 5 6 7 8 9] [ 2 4 6 8 10 12 14 16 18] [ 3 6 9 12 15 18 21 24 27] [ 4 8 12 16 20 24 28 32 36] [ 5 10 15 20 25 30 35 40 45] [ 6 12 18 24 30 36 42 48 54] [ 7 14 21 28 35 42 49 56 63] [ 8 16 24 32 40 48 56 64 72] [ 9 18 27 36 45 54 63 72 81]]
In [23]:
def show_multiplication_table(k):
table = np.outer(range(1, k+1), range(1, k+1))
plt.imshow(table, cmap='Reds')
plt.xticks(range(k), range(1, k+1))
plt.yticks(range(k), range(1, k+1))
plt.title(f'{k}×{k} 乘法表')
show_multiplication_table(9)
乘法表只需要两个向量就能表示,而不是 $n^2$ 个元素。
分解乘法表¶
给定矩阵,能否找到它是哪两个向量的外积?
In [24]:
def factor(mult_table):
mult_table = np.array(mult_table)
v = mult_table[:, 0]
w = mult_table[0, :]
if v[0] != 0:
w = w / v[0]
if np.allclose(np.outer(v, w), mult_table):
return v, w
else:
raise ValueError("输入不是乘法表")
In [25]:
v, w = factor(np.outer([1, 2, 3], [2, 2, 2]))
print(f"v = {v}, w = {w}")
v = [2 4 6], w = [1. 1. 1.]
In [26]:
try:
factor(np.random.rand(2, 2))
except ValueError as e:
print(f"错误: {e}")
错误: 输入不是乘法表
如果矩阵是两个或更多的乘法表的和呢?可以用奇异值分解(SVD)。
第三部分:奇异值分解(SVD)与图像压缩¶
In [27]:
np.random.seed(42)
A = np.outer(np.random.rand(3), np.random.rand(3)) + np.outer(np.random.rand(3), np.random.rand(3))
print("两个外积的和:\n", A)
两个外积的和: [[0.26534903 0.05963086 0.11476207] [1.18246876 0.16615895 0.988419 ] [0.86384744 0.12657835 0.69721442]]
In [28]:
U, S, Vt = svd(A)
reconstructed = np.outer(U[:, 0], Vt[0, :] * S[0]) + np.outer(U[:, 1], Vt[1, :] * S[1])
print("SVD 重建:\n", reconstructed)
print(f"\n与原矩阵的差异: {np.max(np.abs(A - reconstructed)):.10f}")
SVD 重建: [[0.26534903 0.05963086 0.11476207] [1.18246876 0.16615895 0.988419 ] [0.86384744 0.12657835 0.69721442]] 与原矩阵的差异: 0.0000000000
SVD 图像压缩¶
可以使用前 $k$ 个奇异值近似图像。
In [29]:
image_path, _ = urlretrieve('https://pacman.cs.tsinghua.edu.cn/~hanwentao/cpct/08/tree.png')
image = plt.imread(image_path)[:, :, :3]
plt.imshow(image)
plt.title('原始图像')
plt.axis('off')
Out[29]:
(np.float64(-0.5), np.float64(357.5), np.float64(199.5), np.float64(-0.5))
In [30]:
pr, pg, pb = image[:, :, 0], image[:, :, 1], image[:, :, 2]
Ur, Sr, Vtr = svd(pr, full_matrices=False)
Ug, Sg, Vtg = svd(pg, full_matrices=False)
Ub, Sb, Vtb = svd(pb, full_matrices=False)
print(f"奇异值数量: {len(Sr)}")
奇异值数量: 200
In [31]:
def show_svd_approximation(n):
pr_approx = sum(np.outer(Ur[:, i], Vtr[i, :]) * Sr[i] for i in range(min(n, len(Sr))))
pg_approx = sum(np.outer(Ug[:, i], Vtg[i, :]) * Sg[i] for i in range(min(n, len(Sg))))
pb_approx = sum(np.outer(Ub[:, i], Vtb[i, :]) * Sb[i] for i in range(min(n, len(Sb))))
img_approx = np.clip(np.stack([pr_approx, pg_approx, pb_approx], axis=2), 0, 1)
original_size = image.size
compressed_size = n * (Ur.shape[0] + Vtr.shape[1] + 1) * 3
ratio = original_size / compressed_size
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(image)
axes[0].set_title(f'原始图像 ({original_size:,} 像素)')
axes[0].axis('off')
axes[1].imshow(img_approx)
axes[1].set_title(f'SVD 近似 (n={n}, 压缩比 {ratio:.1f}x)')
axes[1].axis('off')
plt.show()
show_svd_approximation(50)
In [32]:
interact(show_svd_approximation, n=IntSlider(min=1, max=200, value=50));
interactive(children=(IntSlider(value=50, description='n', max=200, min=1), Output()), _dom_classes=('widget-i…
本讲小结¶
本讲探讨了数据中的结构以及如何利用它:
- 接缝裁剪:通过动态规划找到最小能量路径,智能缩放图像
- 特殊形态向量和矩阵:one-hot 向量、对角矩阵、稀疏矩阵、乘法表
- Python 对象:使用
class定义自定义数据结构 - SVD:发现矩阵隐藏结构,用于图像压缩
学到的语法¶
class:定义自定义数据结构__dict__:查看对象内部属性raise:抛出异常np.diag:创建对角矩阵scipy.sparse:稀疏矩阵scipy.linalg.svd:奇异值分解