程序设计与计算思维¶

Computer Programming and Computational Thinking

第 6 讲:牛顿法求解方程¶

2025—2026学年度春季学期

清华大学 韩文弢

问题引入¶

在上一讲进行图像变换时,由于采用了反向变换的方式,需要求解变换函数的逆函数。

  • 对于线性变换,变换矩阵的逆就是逆变换的矩阵。
  • 对于非线性变换,一般没有解析解,需要用到数值方法。

本讲将介绍一种数值求解方法——牛顿法,可以用来求解非线性函数的根。

数值求解的思路¶

在科学和工程中,经常需要求解方程组(例如求在某个点的逆变换结果)。

  • 如果方程是线性的,可以使用线性代数的方法(例如高斯消元法)来求解。
  • 如果方程是非线性的,就会比较复杂。思考如何简化这个过程。

image.png

  • 考虑能否将非线性方程转化为一系列线性方程,它们通过用线性函数去近似非线性函数并求解以获得更好的解。
  • 然后根据需要重复此操作多次,以获得一系列越来越好的解。
  • 类似这样的方法被称为迭代算法。

牛顿法是一种常用的迭代算法。

准备工作¶

In [44]:
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp  # 符号计算
import jax  # 自动微分,如果没有的话需要用 %pip install jax 命令安装
from ipywidgets import interact, FloatSlider, IntSlider

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC']
#plt.rcParams['axes.unicode_minus'] = False  # 修复减号的显示

一维牛顿法¶

想要求解形如 $f(x) = g(x)$ 的方程,可以将所有项移到方程的一边,写成 $h(x) = 0$,其中 $h(x) = f(x) - g(x)$。

满足 $h(x^*) = 0$ 的点 $x^*$ 称为 $h$ 的根或零点。

牛顿法用于寻找零点,从而求解原始方程。先看一个直观的例子:

In [2]:
def straight_line(x0, y0, x, m):
    """计算通过点(x0, y0)、斜率为m的直线在x处的y值"""
    return y0 + m * (x - x0)

def newton_visualization_1d(f, f_prime, n_steps, x0, x_range=(-1, 10), y_range=(-10, 70)):
    """
    可视化一维牛顿法
    
    参数:
        f: 目标函数
        f_prime: 函数的导数
        n_steps: 迭代次数
        x0: 初始点
        x_range: x轴范围
        y_range: y轴范围
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    
    x = np.linspace(x_range[0], x_range[1], 500)
    
    # 绘制函数
    ax.plot(x, f(x), 'b-', linewidth=2, label='f(x)')
    
    # 绘制x轴
    ax.axhline(y=0, color='magenta', linestyle='--', linewidth=2)
    
    # 标记初始点
    ax.scatter([x0], [0], c='green', s=100, zorder=5)
    ax.annotate(f'$x_0$', (x0, -5), fontsize=12, ha='center')
    
    current_x = x0
    
    for i in range(n_steps):
        # 从x轴到函数曲线的垂直线
        ax.plot([current_x, current_x], [0, f(current_x)], 'gray', alpha=0.5)
        
        # 标记函数上的点
        ax.scatter([current_x], [f(current_x)], c='red', s=80, zorder=5)
        
        # 计算切线
        m = f_prime(current_x)
        tangent_y = straight_line(current_x, f(current_x), x, m)
        ax.plot(x, tangent_y, 'b--', alpha=0.5, linewidth=1.5)
        
        # 计算下一个点
        next_x = current_x - f(current_x) / m
        
        # 标记新点
        ax.scatter([next_x], [0], c='green', s=100, zorder=5)
        ax.annotate(f'$x_{i}$', (next_x, -5), fontsize=12, ha='center')
        
        current_x = next_x
    
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)
    ax.set_xlabel('x', fontsize=12)
    ax.set_ylabel('y', fontsize=12)
    ax.set_title(f'牛顿法可视化 ({n_steps} 次迭代)', fontsize=14)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    return current_x
In [3]:
# 示例1: 求解 x^2 - 2 = 0 (即 sqrt(2))
def f1(x):
    return x**2 - 2

def f1_prime(x):
    return 2*x

# 交互式演示
@interact(n=IntSlider(min=0, max=10, value=0, description='迭代次数'),
          x0=FloatSlider(min=-10, max=10, value=6, description='初始值 x₀'))
def interactive_newton_1(n, x0):
    newton_visualization_1d(f1, f1_prime, n, x0, x_range=(-1, 10))
interactive(children=(IntSlider(value=0, description='迭代次数', max=10), FloatSlider(value=6.0, description='初始值 …
In [4]:
# 示例2: 求解 0.2x^3 - 4x + 1 = 0
def f2(x):
    return 0.2*x**3 - 4*x + 1

def f2_prime(x):
    return 0.6*x**2 - 4

@interact(n=IntSlider(min=0, max=10, value=0, description='迭代次数'),
          x0=FloatSlider(min=-10, max=10, value=6, description='初始值 x₀'))
def interactive_newton_2(n, x0):
    newton_visualization_1d(f2, f2_prime, n, x0, x_range=(-10, 10))
interactive(children=(IntSlider(value=0, description='迭代次数', max=10), FloatSlider(value=6.0, description='初始值 …
  • 牛顿法的思想:跟随函数曲线的方向去不断靠近。
  • 具体做法:通过在当前位置建立切线,求它与 $x$ 轴的交点,将其作为下一个位置的 $x$ 坐标。

使用符号计算理解导数和非线性函数¶

可以使用符号计算(与数值计算相对应)来理解非线性(多项式)函数的情况。

观察在点 $z$ 附近用小量 $\eta$ 扰动函数 $f$ 会发生什么。

In [5]:
# 使用 SymPy 进行符号计算
x, z, eta = sp.symbols('x z eta')
m = 2

# 定义函数 f(x) = x^m - 2
f = x**m - 2
f
Out[5]:
$\displaystyle x^{2} - 2$
In [6]:
# 将 z 代入
f_z = f.subs(x, z)
f_z
Out[6]:
$\displaystyle z^{2} - 2$
In [7]:
# 将 z + η 代入,并展开
f_z_eta = sp.expand(f.subs(x, z + eta))
f_z_eta
Out[7]:
$\displaystyle \eta^{2} + 2 \eta z + z^{2} - 2$
In [8]:
# 计算 z 处的导数
f_prime_z = sp.diff(f, x).subs(x, z)
f_prime_z
Out[8]:
$\displaystyle 2 z$
In [9]:
# 线性近似
f_la_z = f_z + eta * f_prime_z
f_la_z
Out[9]:
$\displaystyle 2 \eta z + z^{2} - 2$
In [10]:
# 两者的差
f_z_eta - f_la_z
Out[10]:
$\displaystyle \eta^{2}$

当 $\eta$ 很小时,$\eta^2$ 及更高次的项会非常小,可以忽略。

剩下的项要么不含 $\eta$(常数项),要么是 $\eta$ 的线性项,$\eta$ 的系数就是导数。

也就是说:导数给出了函数的线性部分。

牛顿法的数学原理¶

下面用代数的语言来解释牛顿法。

假设有一个根的猜测 $x_0$,想要找到一个更好的猜测 $x_1$。

设 $x_1 = x_0 + \delta$,其中 $x_1$ 和 $\delta$ 仍然未知。

由于希望 $x_1$ 是根,所以:

$$f(x_1) = f(x_0 + \delta) \simeq 0$$

如果已经相当接近根,那么 $\delta$ 应该很小,所以可以用切线来近似 $f$:

$$f(x_0) + \delta \, f'(x_0) \simeq 0$$

因此,

$$\delta \simeq \frac{-f(x_0)}{f'(x_0)}$$

所以,

$$x_1 = x_0 - \frac{f(x_0)}{f'(x_0)}$$

重复这个过程:

$$x_2 = x_1 - \frac{f(x_1)}{f'(x_1)}$$

一般地,有:

$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$

这就是一维牛顿法。

一维牛顿法的实现¶

In [11]:
def newton_1d(f, x0, n_iters=10):
    """
    一维牛顿法
    
    参数:
        f: 目标函数
        x0: 初始猜测
        n_iterations: 迭代次数
    
    返回:
        近似根
    """
    # 使用 JAX 自动求导
    f_prime = jax.grad(f)
    x = x0
    for i in range(n_iters):
        x = x - f(x) / f_prime(x)
    
    return x
In [12]:
# 测试: 计算 sqrt(2)
def f_sqrt2(x):
    return x**2 - 2

result = float(newton_1d(f_sqrt2, 37.0))
print(f"牛顿法结果: {result}")
print(f"sqrt(2) = {np.sqrt(2)}")
print(f"误差: {abs(result - np.sqrt(2))}")
牛顿法结果: 1.4142135381698608
sqrt(2) = 1.4142135623730951
误差: 2.4203234305630872e-08

牛顿法的收敛性¶

In [13]:
def newton_1d_with_history(f, x0, n_iters=10):
    """
    带历史记录的一维牛顿法
    
    返回:
        最终结果和迭代历史
    """
    f_prime = jax.grad(f)
    
    x = x0
    history = [x]
    
    for i in range(n_iters):
        x = x - f(x) / f_prime(x)
        history.append(float(x))
    
    return x, history

# 演示收敛速度
result, history = newton_1d_with_history(f_sqrt2, 37.0, 10)

print("迭代历史 (求解 x² - 2 = 0):")
for i, x in enumerate(history):
    error = abs(x - np.sqrt(2))
    print(f"迭代 {i:2d}: x = {x:20.15f}, 误差 = {error:.2e}")
迭代历史 (求解 x² - 2 = 0):
迭代  0: x =   37.000000000000000, 误差 = 3.56e+01
迭代  1: x =   18.527027130126953, 误差 = 1.71e+01
迭代  2: x =    9.317488670349121, 误差 = 7.90e+00
迭代  3: x =    4.766069412231445, 误差 = 3.35e+00
迭代  4: x =    2.592851161956787, 误差 = 1.18e+00
迭代  5: x =    1.682101488113403, 误差 = 2.68e-01
迭代  6: x =    1.435545206069946, 误差 = 2.13e-02
迭代  7: x =    1.414372086524963, 误差 = 1.59e-04
迭代  8: x =    1.414213538169861, 误差 = 2.42e-08
迭代  9: x =    1.414213538169861, 误差 = 2.42e-08
迭代 10: x =    1.414213538169861, 误差 = 2.42e-08
In [14]:
# 可视化收敛过程
def plot_convergence(f, x0, n_iters=15):
    """
    绘制牛顿法收敛过程
    """
    result, history = newton_1d_with_history(f, x0, n_iters)
    errors = [abs(x - np.sqrt(2)) for x in history]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # 估计值随迭代的变化
    axes[0].plot(history, 'o-', markersize=8)
    axes[0].axhline(y=np.sqrt(2), color='r', linestyle='--', label=f'真实值 √2 ≈ {np.sqrt(2):.6f}')
    axes[0].set_xlabel('迭代次数', fontsize=12)
    axes[0].set_ylabel('x值', fontsize=12)
    axes[0].set_title('牛顿法收敛过程', fontsize=14)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # 误差的对数随迭代的变化
    axes[1].semilogy(range(len(errors)), errors, 'o-', markersize=8)
    axes[1].set_xlabel('迭代次数', fontsize=12)
    axes[1].set_ylabel('误差 (对数尺度)', fontsize=12)
    axes[1].set_title('牛顿法误差收敛 (二次收敛)', fontsize=14)
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

plot_convergence(f_sqrt2, 37.0)
No description has been provided for this image

牛顿法小结¶

牛顿法的基本思想:用线性逼近非线性。

牛顿法的优点:

  1. 适用范围广:可以用于求解非线性方程组和寻找逆变换
  2. 快速收敛:在根附近具有二次收敛速度

牛顿法的缺点:

  1. 需要导数:需要函数的导数(可以使用自动微分技术解决)
  2. 初始值敏感:初始值选择不当可能导致不收敛
  3. 局部收敛:只能保证局部收敛,不能保证找到全局根

二维推广:一维的导数 $f'(x)$ 在二维中变成雅可比矩阵 $J$,求解公式从标量除法变成矩阵求解。

$$x_{n+1} = x_n - J^{-1} \cdot T(x_n)$$

SciPy 在 scipy.optimize 模块中提供了 root、fsolve、newton 等函数用于一般函数的数值求根。

Python 的 f-字符串¶

  • 格式化字符串字面值或称f-字符串是带有 'f' 或 'F' 前缀的字符串字面值。
  • 不同于其他字符串字面值,f-字符串的值不是固定的,可能包含由花括号 {} 标记的替换字段。
  • 替换字段包含在运行时进行求值的表达式。

例如:

In [15]:
name = 'Shuimu'
f'Hello, {name} College!'
Out[15]:
'Hello, Shuimu College!'

如果想在 f-字符串中给出真的花括号,可以使用双花括号:

In [16]:
f'Left {{ Right }}'
Out[16]:
'Left { Right }'

花括号内除了表达式外,还可以加冒号以及一段格式描述(参见格式说明微语言),例如:

In [17]:
number = 14.3
f'{number:20.7f}'
Out[17]:
'          14.3000000'

Python 函数装饰器¶

函数装饰器(function decorator)用于在已有函数的基础上进行装饰,获得一个新的函数。例如:

In [18]:
# negate 函数的作用是包装一个单个参数的函数 f(x),返回一个新的函数 g(x),使得 g(x) = -f(x)
def negate(f):
    return lambda x: -f(x)
In [19]:
# 使用装饰器语法
@negate
def double(x):
    return x*2

double(1)
Out[19]:
-2

装饰器是一种语法糖,上述代码等价于以下代码:

In [20]:
# 使用普通语法
def double(x):
    return x*2
double = negate(double)

double(1)
Out[20]:
-2

思考:如何实现像 interact 这样可以指定参数的装饰器?

In [21]:
def plus(n):
    return lambda f: lambda x: f(x) + n

@plus(3)
def double(x):
    return x*2

double(2)
Out[21]:
7

Python 字典(Dictionary)¶

字典是 Python 中非常重要的数据结构,用于存储键值对(key-value pairs)。

  • 字典中的每个元素由一个键和一个值组成
  • 键必须是不可变类型(如字符串、数字、元组),且唯一
  • 值可以是任意类型
  • 字典是可变的,可以动态添加、修改和删除元素
  • Python 3.7+ 中,字典保持插入顺序

创建字典¶

In [22]:
# 使用花括号创建字典
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}
scores
Out[22]:
{'Alice': 95, 'Bob': 87, 'Charlie': 92}
In [23]:
# 使用 dict() 函数创建字典
scores2 = dict(Alice=95, Bob=87, Charlie=92)
scores2
Out[23]:
{'Alice': 95, 'Bob': 87, 'Charlie': 92}
In [24]:
# 创建空字典
empty_dict = {}
empty_dict2 = dict()
print(empty_dict, empty_dict2)
{} {}
In [25]:
# 使用字典推导式创建字典
# 语法: {key_expr: value_expr for item in iterable}

# 示例:将列表转换为索引-值字典
fruits = ['apple', 'banana', 'orange']
fruit_dict = {i: fruit for i, fruit in enumerate(fruits)}
fruit_dict
Out[25]:
{0: 'apple', 1: 'banana', 2: 'orange'}
In [26]:
# 示例:筛选和转换
scores = {'Alice': 95, 'Bob': 67, 'Charlie': 82, 'David': 58}
passed = {name: score for name, score in scores.items() if score >= 60}
passed
Out[26]:
{'Alice': 95, 'Bob': 67, 'Charlie': 82}
In [27]:
# 示例:键值互换
original = {'a': 1, 'b': 2, 'c': 3}
swapped = {v: k for k, v in original.items()}
swapped
Out[27]:
{1: 'a', 2: 'b', 3: 'c'}

访问字典元素¶

In [28]:
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}

# 使用键访问值
print(scores['Alice'])  # 如果键不存在会报错
95
In [29]:
# 使用 get() 方法,键不存在时返回 None 或默认值
print(scores.get('Bob'))
print(scores.get('David'))  # 返回 None
print(scores.get('David', 0))  # 返回默认值 0
87
None
0

修改字典¶

In [30]:
scores = {'Alice': 95, 'Bob': 87}

# 添加或修改元素
scores['Charlie'] = 92  # 添加新元素
scores['Alice'] = 98    # 修改已有元素
scores
Out[30]:
{'Alice': 98, 'Bob': 87, 'Charlie': 92}
In [31]:
# 使用 update() 方法合并另一个字典
scores.update({'David': 88, 'Eve': 91})
scores
Out[31]:
{'Alice': 98, 'Bob': 87, 'Charlie': 92, 'David': 88, 'Eve': 91}
In [32]:
# 删除元素
del scores['Bob']      # 使用 del 关键字
value = scores.pop('Charlie')  # 使用 pop() 方法,返回被删除的值
print(f"删除的值: {value}")
scores
删除的值: 92
Out[32]:
{'Alice': 98, 'David': 88, 'Eve': 91}

字典的常用方法¶

In [33]:
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92}

# 获取所有键、值、键值对
print('所有键:', list(scores.keys()))
print('所有值:', list(scores.values()))
print('所有键值对:', list(scores.items()))
所有键: ['Alice', 'Bob', 'Charlie']
所有值: [95, 87, 92]
所有键值对: [('Alice', 95), ('Bob', 87), ('Charlie', 92)]
In [34]:
# 遍历字典
for name, score in scores.items():
    print(f'{name}: {score}')
Alice: 95
Bob: 87
Charlie: 92
In [35]:
# 检查键是否存在
print('Alice' in scores)
print('David' in scores)
True
False

默认字典(defaultdict)¶

当访问不存在的键时,普通字典会报错,而 defaultdict 会自动创建默认值。

In [36]:
from collections import defaultdict

# 创建一个默认值为 0 的字典
word_count = defaultdict(int)

text = 'the quick brown fox jumps over the lazy dog the fox'
for word in text.split():
    word_count[word] += 1  # 不需要检查键是否存在

dict(word_count)
Out[36]:
{'the': 3,
 'quick': 1,
 'brown': 1,
 'fox': 2,
 'jumps': 1,
 'over': 1,
 'lazy': 1,
 'dog': 1}
In [37]:
# 默认值为列表
from collections import defaultdict

grouped = defaultdict(list)
students = [('A', 'Alice'), ('B', 'Bob'), ('A', 'Amy'), ('B', 'Brian')]

for grade, name in students:
    grouped[grade].append(name)

dict(grouped)
Out[37]:
{'A': ['Alice', 'Amy'], 'B': ['Bob', 'Brian']}

字典排序¶

使用 sorted() 函数可以对字典进行排序。

In [38]:
scores = {'Alice': 95, 'Bob': 87, 'Charlie': 92, 'David': 88}

# 按键排序
sorted_by_key = dict(sorted(scores.items()))
print('按键排序:', sorted_by_key)

# 按值排序(升序)
sorted_by_value = dict(sorted(scores.items(), key=lambda x: x[1]))
print('按值升序:', sorted_by_value)

# 按值排序(降序)
sorted_by_value_desc = dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
print('按值降序:', sorted_by_value_desc)
按键排序: {'Alice': 95, 'Bob': 87, 'Charlie': 92, 'David': 88}
按值升序: {'Bob': 87, 'David': 88, 'Charlie': 92, 'Alice': 95}
按值降序: {'Alice': 95, 'Charlie': 92, 'David': 88, 'Bob': 87}

Python 集合(Set)¶

集合是无序、不重复元素的集合。

  • 元素必须是不可变类型
  • 自动去重
  • 支持数学集合运算(交、并、差等)
In [39]:
# 创建集合
fruits = {'apple', 'banana', 'orange', 'apple'}  # 自动去重
print(fruits)

numbers = set([1, 2, 3, 2, 1])  # 从列表创建
print(numbers)
{'banana', 'orange', 'apple'}
{1, 2, 3}
In [40]:
# 添加和删除元素
fruits.add('grape')
fruits.remove('banana')
fruits
Out[40]:
{'apple', 'grape', 'orange'}
In [41]:
# 集合运算
a = {1, 2, 3, 4}
b = {3, 4, 5, 6}

print('交集:', a & b)      # 或 a.intersection(b)
print('并集:', a | b)      # 或 a.union(b)
print('差集:', a - b)      # 或 a.difference(b)
print('对称差:', a ^ b)    # 或 a.symmetric_difference(b)
交集: {3, 4}
并集: {1, 2, 3, 4, 5, 6}
差集: {1, 2}
对称差: {1, 2, 5, 6}
In [42]:
# 判断包含关系
a = {1, 2, 3, 4}
b = {2, 3}
c = {1, 2, 3, 4}

print('b 是 a 的子集:', b.issubset(a))     # 或 b <= a
print('a 是 b 的超集:', a.issuperset(b))   # 或 a >= b
print('c 与 a 相等:', c == a)
print('b 是 a 的真子集:', b < a)           # 真子集(不等于)
print('a 是 b 的真超集:', a > b)           # 真超集(不等于)
b 是 a 的子集: True
a 是 b 的超集: True
c 与 a 相等: True
b 是 a 的真子集: True
a 是 b 的真超集: True
In [43]:
# 判断两个集合是否无交集
x = {1, 2}
y = {3, 4}
z = {2, 3}

print('x 和 y 无交集:', x.isdisjoint(y))  # True
print('x 和 z 无交集:', x.isdisjoint(z))  # False
x 和 y 无交集: True
x 和 z 无交集: False

本讲小结¶

  • 牛顿法:数值方法求解一般方程
  • Python
    • f-字符串
    • 函数装饰器
    • 字典与集合