问题引入¶
在上一讲进行图像变换时,由于采用了反向变换的方式,需要求解变换函数的逆函数。
- 对于线性变换,变换矩阵的逆就是逆变换的矩阵。
- 对于非线性变换,一般没有解析解,需要用到数值方法。
本讲将介绍一种数值求解方法——牛顿法,可以用来求解非线性函数的根。
数值求解的思路¶
在科学和工程中,经常需要求解方程组(例如求在某个点的逆变换结果)。
- 如果方程是线性的,可以使用线性代数的方法(例如高斯消元法)来求解。
- 如果方程是非线性的,就会比较复杂。思考如何简化这个过程。
- 考虑能否将非线性方程转化为一系列线性方程,它们通过用线性函数去近似非线性函数并求解以获得更好的解。
- 然后根据需要重复此操作多次,以获得一系列越来越好的解。
- 类似这样的方法被称为迭代算法。
牛顿法是一种常用的迭代算法。
准备工作¶
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp # 符号计算
import jax # 自动微分,如果没有的话需要用 %pip install jax 命令安装
from ipywidgets import interact, FloatSlider, IntSlider
# 设置中文显示
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC']
#plt.rcParams['axes.unicode_minus'] = False # 修复减号的显示
一维牛顿法¶
想要求解形如 $f(x) = g(x)$ 的方程,可以将所有项移到方程的一边,写成 $h(x) = 0$,其中 $h(x) = f(x) - g(x)$。
满足 $h(x^*) = 0$ 的点 $x^*$ 称为 $h$ 的根或零点。
牛顿法用于寻找零点,从而求解原始方程。先看一个直观的例子:
def straight_line(x0, y0, x, m):
"""计算通过点(x0, y0)、斜率为m的直线在x处的y值"""
return y0 + m * (x - x0)
def newton_visualization_1d(f, f_prime, n_steps, x0, x_range=(-1, 10), y_range=(-10, 70)):
"""
可视化一维牛顿法
参数:
f: 目标函数
f_prime: 函数的导数
n_steps: 迭代次数
x0: 初始点
x_range: x轴范围
y_range: y轴范围
"""
fig, ax = plt.subplots(figsize=(10, 7))
x = np.linspace(x_range[0], x_range[1], 500)
# 绘制函数
ax.plot(x, f(x), 'b-', linewidth=2, label='f(x)')
# 绘制x轴
ax.axhline(y=0, color='magenta', linestyle='--', linewidth=2)
# 标记初始点
ax.scatter([x0], [0], c='green', s=100, zorder=5)
ax.annotate(f'$x_0$', (x0, -5), fontsize=12, ha='center')
current_x = x0
for i in range(n_steps):
# 从x轴到函数曲线的垂直线
ax.plot([current_x, current_x], [0, f(current_x)], 'gray', alpha=0.5)
# 标记函数上的点
ax.scatter([current_x], [f(current_x)], c='red', s=80, zorder=5)
# 计算切线
m = f_prime(current_x)
tangent_y = straight_line(current_x, f(current_x), x, m)
ax.plot(x, tangent_y, 'b--', alpha=0.5, linewidth=1.5)
# 计算下一个点
next_x = current_x - f(current_x) / m
# 标记新点
ax.scatter([next_x], [0], c='green', s=100, zorder=5)
ax.annotate(f'$x_{i}$', (next_x, -5), fontsize=12, ha='center')
current_x = next_x
ax.set_xlim(x_range)
ax.set_ylim(y_range)
ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('y', fontsize=12)
ax.set_title(f'牛顿法可视化 ({n_steps} 次迭代)', fontsize=14)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
return current_x
# 示例1: 求解 x^2 - 2 = 0 (即 sqrt(2))
def f1(x):
return x**2 - 2
def f1_prime(x):
return 2*x
# 交互式演示
@interact(n=IntSlider(min=0, max=10, value=0, description='迭代次数'),
x0=FloatSlider(min=-10, max=10, value=6, description='初始值 x₀'))
def interactive_newton_1(n, x0):
newton_visualization_1d(f1, f1_prime, n, x0, x_range=(-1, 10))
interactive(children=(IntSlider(value=0, description='迭代次数', max=10), FloatSlider(value=6.0, description='初始值 …
# 示例2: 求解 0.2x^3 - 4x + 1 = 0
def f2(x):
return 0.2*x**3 - 4*x + 1
def f2_prime(x):
return 0.6*x**2 - 4
@interact(n=IntSlider(min=0, max=10, value=0, description='迭代次数'),
x0=FloatSlider(min=-10, max=10, value=6, description='初始值 x₀'))
def interactive_newton_2(n, x0):
newton_visualization_1d(f2, f2_prime, n, x0, x_range=(-10, 10))
interactive(children=(IntSlider(value=0, description='迭代次数', max=10), FloatSlider(value=6.0, description='初始值 …
- 牛顿法的思想:跟随函数曲线的方向去不断靠近。
- 具体做法:通过在当前位置建立切线,求它与 $x$ 轴的交点,将其作为下一个位置的 $x$ 坐标。
# 使用 SymPy 进行符号计算
x, z, eta = sp.symbols('x z eta')
m = 2
# 定义函数 f(x) = x^m - 2
f = x**m - 2
f
# 将 z 代入
f_z = f.subs(x, z)
f_z
# 将 z + η 代入,并展开
f_z_eta = sp.expand(f.subs(x, z + eta))
f_z_eta
# 计算 z 处的导数
f_prime_z = sp.diff(f, x).subs(x, z)
f_prime_z
# 线性近似
f_la_z = f_z + eta * f_prime_z
f_la_z
# 两者的差
f_z_eta - f_la_z
当 $\eta$ 很小时,$\eta^2$ 及更高次的项会非常小,可以忽略。
剩下的项要么不含 $\eta$(常数项),要么是 $\eta$ 的线性项,$\eta$ 的系数就是导数。
也就是说:导数给出了函数的线性部分。
牛顿法的数学原理¶
下面用代数的语言来解释牛顿法。
假设有一个根的猜测 $x_0$,想要找到一个更好的猜测 $x_1$。
设 $x_1 = x_0 + \delta$,其中 $x_1$ 和 $\delta$ 仍然未知。
由于希望 $x_1$ 是根,所以:
$$f(x_1) = f(x_0 + \delta) \simeq 0$$
如果已经相当接近根,那么 $\delta$ 应该很小,所以可以用切线来近似 $f$:
$$f(x_0) + \delta \, f'(x_0) \simeq 0$$
因此,
$$\delta \simeq \frac{-f(x_0)}{f'(x_0)}$$
所以,
$$x_1 = x_0 - \frac{f(x_0)}{f'(x_0)}$$
重复这个过程:
$$x_2 = x_1 - \frac{f(x_1)}{f'(x_1)}$$
一般地,有:
$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$
这就是一维牛顿法。
一维牛顿法的实现¶
def newton_1d(f, x0, n_iters=10):
"""
一维牛顿法
参数:
f: 目标函数
x0: 初始猜测
n_iterations: 迭代次数
返回:
近似根
"""
# 使用 JAX 自动求导
f_prime = jax.grad(f)
x = x0
for i in range(n_iters):
x = x - f(x) / f_prime(x)
return x
# 测试: 计算 sqrt(2)
def f_sqrt2(x):
return x**2 - 2
result = float(newton_1d(f_sqrt2, 37.0))
print(f"牛顿法结果: {result}")
print(f"sqrt(2) = {np.sqrt(2)}")
print(f"误差: {abs(result - np.sqrt(2))}")
牛顿法结果: 1.4142135381698608 sqrt(2) = 1.4142135623730951 误差: 2.4203234305630872e-08
牛顿法的收敛性¶
def newton_1d_with_history(f, x0, n_iters=10):
"""
带历史记录的一维牛顿法
返回:
最终结果和迭代历史
"""
f_prime = jax.grad(f)
x = x0
history = [x]
for i in range(n_iters):
x = x - f(x) / f_prime(x)
history.append(float(x))
return x, history
# 演示收敛速度
result, history = newton_1d_with_history(f_sqrt2, 37.0, 10)
print("迭代历史 (求解 x² - 2 = 0):")
for i, x in enumerate(history):
error = abs(x - np.sqrt(2))
print(f"迭代 {i:2d}: x = {x:20.15f}, 误差 = {error:.2e}")
迭代历史 (求解 x² - 2 = 0): 迭代 0: x = 37.000000000000000, 误差 = 3.56e+01 迭代 1: x = 18.527027130126953, 误差 = 1.71e+01 迭代 2: x = 9.317488670349121, 误差 = 7.90e+00 迭代 3: x = 4.766069412231445, 误差 = 3.35e+00 迭代 4: x = 2.592851161956787, 误差 = 1.18e+00 迭代 5: x = 1.682101488113403, 误差 = 2.68e-01 迭代 6: x = 1.435545206069946, 误差 = 2.13e-02 迭代 7: x = 1.414372086524963, 误差 = 1.59e-04 迭代 8: x = 1.414213538169861, 误差 = 2.42e-08 迭代 9: x = 1.414213538169861, 误差 = 2.42e-08 迭代 10: x = 1.414213538169861, 误差 = 2.42e-08
# 可视化收敛过程
def plot_convergence(f, x0, n_iters=15):
"""
绘制牛顿法收敛过程
"""
result, history = newton_1d_with_history(f, x0, n_iters)
errors = [abs(x - np.sqrt(2)) for x in history]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 估计值随迭代的变化
axes[0].plot(history, 'o-', markersize=8)
axes[0].axhline(y=np.sqrt(2), color='r', linestyle='--', label=f'真实值 √2 ≈ {np.sqrt(2):.6f}')
axes[0].set_xlabel('迭代次数', fontsize=12)
axes[0].set_ylabel('x值', fontsize=12)
axes[0].set_title('牛顿法收敛过程', fontsize=14)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# 误差的对数随迭代的变化
axes[1].semilogy(range(len(errors)), errors, 'o-', markersize=8)
axes[1].set_xlabel('迭代次数', fontsize=12)
axes[1].set_ylabel('误差 (对数尺度)', fontsize=12)
axes[1].set_title('牛顿法误差收敛 (二次收敛)', fontsize=14)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
plot_convergence(f_sqrt2, 37.0)
牛顿法小结¶
牛顿法的基本思想:用线性逼近非线性。
牛顿法的优点:
- 适用范围广:可以用于求解非线性方程组和寻找逆变换
- 快速收敛:在根附近具有二次收敛速度
牛顿法的缺点:
- 需要导数:需要函数的导数(可以使用自动微分技术解决)
- 初始值敏感:初始值选择不当可能导致不收敛
- 局部收敛:只能保证局部收敛,不能保证找到全局根
二维推广:一维的导数 $f'(x)$ 在二维中变成雅可比矩阵 $J$,求解公式从标量除法变成矩阵求解。
$$x_{n+1} = x_n - J^{-1} \cdot T(x_n)$$
SciPy 在 scipy.optimize 模块中提供了 root、fsolve、newton 等函数用于一般函数的数值求根。
Python 的 f-字符串¶
- 格式化字符串字面值或称f-字符串是带有 'f' 或 'F' 前缀的字符串字面值。
- 不同于其他字符串字面值,f-字符串的值不是固定的,可能包含由花括号 {} 标记的替换字段。
- 替换字段包含在运行时进行求值的表达式。
例如:
name = 'Shuimu'
f'Hello, {name} College!'
'Hello, Shuimu College!'
如果想在 f-字符串中给出真的花括号,可以使用双花括号:
f'Left {{ Right }}'
'Left { Right }'
花括号内除了表达式外,还可以加冒号以及一段格式描述(参见格式说明微语言),例如:
number = 14.3
f'{number:20.7f}'
' 14.3000000'
Python 函数装饰器¶
函数装饰器(function decorator)用于在已有函数的基础上进行装饰,获得一个新的函数。例如:
# negate 函数的作用是包装一个单个参数的函数 f(x),返回一个新的函数 g(x),使得 g(x) = -f(x)
def negate(f):
return lambda x: -f(x)
# 使用装饰器语法
@negate
def double(x):
return x*2
double(1)
-2
装饰器是一种语法糖,上述代码等价于以下代码:
# 使用普通语法
def double(x):
return x*2
double = negate(double)
double(1)
-2
思考:如何实现像 interact 这样可以指定参数的装饰器?
def plus(n):
return lambda f: lambda x: f(x) + n
@plus(3)
def double(x):
return x*2
double(2)
7
Python 字典(Dictionary)¶
字典是 Python 中非常重要的数据结构,用于存储键值对(key-value pairs)。
- 字典中的每个元素由一个键和一个值组成
- 键必须是不可变类型(如字符串、数字、元组),且唯一
- 值可以是任意类型
- 字典是可变的,可以动态添加、修改和删除元素
- Python 3.7+ 中,字典保持插入顺序
创建字典¶
# 使用花括号创建字典
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}
scores
{'Alice': 95, 'Bob': 87, 'Charlie': 92}
# 使用 dict() 函数创建字典
scores2 = dict(Alice=95, Bob=87, Charlie=92)
scores2
{'Alice': 95, 'Bob': 87, 'Charlie': 92}
# 创建空字典
empty_dict = {}
empty_dict2 = dict()
print(empty_dict, empty_dict2)
{} {}
# 使用字典推导式创建字典
# 语法: {key_expr: value_expr for item in iterable}
# 示例:将列表转换为索引-值字典
fruits = ['apple', 'banana', 'orange']
fruit_dict = {i: fruit for i, fruit in enumerate(fruits)}
fruit_dict
{0: 'apple', 1: 'banana', 2: 'orange'}
# 示例:筛选和转换
scores = {'Alice': 95, 'Bob': 67, 'Charlie': 82, 'David': 58}
passed = {name: score for name, score in scores.items() if score >= 60}
passed
{'Alice': 95, 'Bob': 67, 'Charlie': 82}
# 示例:键值互换
original = {'a': 1, 'b': 2, 'c': 3}
swapped = {v: k for k, v in original.items()}
swapped
{1: 'a', 2: 'b', 3: 'c'}
访问字典元素¶
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}
# 使用键访问值
print(scores['Alice']) # 如果键不存在会报错
95
# 使用 get() 方法,键不存在时返回 None 或默认值
print(scores.get('Bob'))
print(scores.get('David')) # 返回 None
print(scores.get('David', 0)) # 返回默认值 0
87 None 0
修改字典¶
scores = {'Alice': 95, 'Bob': 87}
# 添加或修改元素
scores['Charlie'] = 92 # 添加新元素
scores['Alice'] = 98 # 修改已有元素
scores
{'Alice': 98, 'Bob': 87, 'Charlie': 92}
# 使用 update() 方法合并另一个字典
scores.update({'David': 88, 'Eve': 91})
scores
{'Alice': 98, 'Bob': 87, 'Charlie': 92, 'David': 88, 'Eve': 91}
# 删除元素
del scores['Bob'] # 使用 del 关键字
value = scores.pop('Charlie') # 使用 pop() 方法,返回被删除的值
print(f"删除的值: {value}")
scores
删除的值: 92
{'Alice': 98, 'David': 88, 'Eve': 91}
字典的常用方法¶
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}
# 获取所有键、值、键值对
print('所有键:', list(scores.keys()))
print('所有值:', list(scores.values()))
print('所有键值对:', list(scores.items()))
所有键: ['Alice', 'Bob', 'Charlie']
所有值: [95, 87, 92]
所有键值对: [('Alice', 95), ('Bob', 87), ('Charlie', 92)]
# 遍历字典
for name, score in scores.items():
print(f'{name}: {score}')
Alice: 95 Bob: 87 Charlie: 92
# 检查键是否存在
print('Alice' in scores)
print('David' in scores)
True False
默认字典(defaultdict)¶
当访问不存在的键时,普通字典会报错,而 defaultdict 会自动创建默认值。
from collections import defaultdict
# 创建一个默认值为 0 的字典
word_count = defaultdict(int)
text = 'the quick brown fox jumps over the lazy dog the fox'
for word in text.split():
word_count[word] += 1 # 不需要检查键是否存在
dict(word_count)
{'the': 3,
'quick': 1,
'brown': 1,
'fox': 2,
'jumps': 1,
'over': 1,
'lazy': 1,
'dog': 1}
# 默认值为列表
from collections import defaultdict
grouped = defaultdict(list)
students = [('A', 'Alice'), ('B', 'Bob'), ('A', 'Amy'), ('B', 'Brian')]
for grade, name in students:
grouped[grade].append(name)
dict(grouped)
{'A': ['Alice', 'Amy'], 'B': ['Bob', 'Brian']}
字典排序¶
使用 sorted() 函数可以对字典进行排序。
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92, 'David': 88}
# 按键排序
sorted_by_key = dict(sorted(scores.items()))
print('按键排序:', sorted_by_key)
# 按值排序(升序)
sorted_by_value = dict(sorted(scores.items(), key=lambda x: x[1]))
print('按值升序:', sorted_by_value)
# 按值排序(降序)
sorted_by_value_desc = dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
print('按值降序:', sorted_by_value_desc)
按键排序: {'Alice': 95, 'Bob': 87, 'Charlie': 92, 'David': 88}
按值升序: {'Bob': 87, 'David': 88, 'Charlie': 92, 'Alice': 95}
按值降序: {'Alice': 95, 'Charlie': 92, 'David': 88, 'Bob': 87}
# 创建集合
fruits = {'apple', 'banana', 'orange', 'apple'} # 自动去重
print(fruits)
numbers = set([1, 2, 3, 2, 1]) # 从列表创建
print(numbers)
{'banana', 'orange', 'apple'}
{1, 2, 3}
# 添加和删除元素
fruits.add('grape')
fruits.remove('banana')
fruits
{'apple', 'grape', 'orange'}
# 集合运算
a = {1, 2, 3, 4}
b = {3, 4, 5, 6}
print('交集:', a & b) # 或 a.intersection(b)
print('并集:', a | b) # 或 a.union(b)
print('差集:', a - b) # 或 a.difference(b)
print('对称差:', a ^ b) # 或 a.symmetric_difference(b)
交集: {3, 4}
并集: {1, 2, 3, 4, 5, 6}
差集: {1, 2}
对称差: {1, 2, 5, 6}
# 判断包含关系
a = {1, 2, 3, 4}
b = {2, 3}
c = {1, 2, 3, 4}
print('b 是 a 的子集:', b.issubset(a)) # 或 b <= a
print('a 是 b 的超集:', a.issuperset(b)) # 或 a >= b
print('c 与 a 相等:', c == a)
print('b 是 a 的真子集:', b < a) # 真子集(不等于)
print('a 是 b 的真超集:', a > b) # 真超集(不等于)
b 是 a 的子集: True a 是 b 的超集: True c 与 a 相等: True b 是 a 的真子集: True a 是 b 的真超集: True
# 判断两个集合是否无交集
x = {1, 2}
y = {3, 4}
z = {2, 3}
print('x 和 y 无交集:', x.isdisjoint(y)) # True
print('x 和 z 无交集:', x.isdisjoint(z)) # False
x 和 y 无交集: True x 和 z 无交集: False
本讲小结¶
- 牛顿法:数值方法求解一般方程
- Python
- f-字符串
- 函数装饰器
- 字典与集合