
在jax中,进行自动微分的核心工具是jax.grad。然而,与pytorch等框架不同,jax秉持函数式编程范式,其jax.grad函数期望接收一个以待微分参数为输入的纯函数,并返回该函数对这些参数的梯度。当我们在自定义的类结构中封装模型参数时,直接对模型输出求导往往无法得到我们期望的模型权重梯度。
例如,原始代码中尝试使用grads = jax.grad(criterion)(out, target)。这里的criterion函数接收的是模型的输出out和目标值target。jax.grad会计算criterion对out和target的梯度。由于out和target通常不是模型的可训练参数(如权重和偏置),因此得到的梯度并非我们所寻求的模型参数梯度,而可能只是损失函数对其直接输入的梯度。
为了正确计算模型内部参数(如线性层的权重和偏置)的梯度,我们需要解决两个关键问题:
- 梯度函数输入: jax.grad需要一个函数,其第一个(或指定)参数就是我们希望求导的参数集合(例如,整个模型实例或一个包含所有权重的PyTree)。
- 参数结构可识别性: JAX需要能够理解自定义类中哪些部分是可训练参数,以及如何遍历这些参数。这意味着自定义类需要被注册为JAX的“PyTree”结构。
为了满足jax.grad对函数输入的要求,我们需要创建一个新的损失函数,该函数将模型实例(或其参数)作为其第一个参数。然后,该函数在内部调用模型进行前向传播,并计算损失。
考虑以下重构后的损失函数:
def compute_loss(model_instance, inputs, target):
"""
计算模型在给定输入和目标下的损失。
model_instance: 模型的实例,包含所有参数。
inputs: 模型的输入数据。
target: 目标值。
"""
output = model_instance(inputs)
loss_value = criterion(output, target)
return loss_value 现在,我们可以使用jax.grad来计算compute_loss函数对model_instance的梯度。argnums=0指示jax.grad对第一个参数(即model_instance)求导。
# 假设 model, data, target 已经定义 grads = jax.grad(compute_loss, argnums=0)(model, data, target)
然而,仅仅重构损失函数是不够的。jax.grad仍然需要知道如何“深入”到model实例内部,找到weights和biases等JAX数组并计算它们的梯度。这就引出了第二个解决方案。
解决方案二:注册自定义类为 PyTreeJAX的PyTree(Python Tree)是一种数据结构,它可以是任意嵌套的Python容器(如列表、元组、字典、dataclass),其叶子节点是JAX数组。jax.grad以及其他JAX转换(如jax.jit, jax.vmap)能够自动遍历PyTree结构。当JAX遇到一个非标准Python容器(例如我们自定义的Module子类实例)时,它不知道如何处理其内部结构。为了让JAX理解自定义类,我们需要将其注册为PyTree。
jax.tree_util.register_pytree_node函数用于注册自定义类型。它需要三个参数:
- cls: 要注册的类。
- flatten_func: 一个函数,接收cls的实例,返回一个元组(children, static_data)。
- children: 一个PyTree,包含所有可变(通常是JAX数组)的子组件,这些是jax.grad需要跟踪的部分。
- static_data: 一个元组,包含所有不可变(通常是Python原生类型)的静态数据,这些数据在PyTree扁平化和重建过程中保持不变。
- unflatten_func: 一个函数,接收static_data和children,重建cls的实例。
下面,我们将为示例中的Linear、Activation和Model类注册为PyTree。
Teleporthq
一体化AI网站生成器,能够快速设计和部署静态网站
182
查看详情
注册 Linear 类
Linear类包含weights和biases这两个JAX数组(可变参数),以及in_features和out_features这两个整数(静态数据)。
import jax
import jax.numpy as jnp
from jax import tree_util
# ... (Module, Linear, Activation, Model 类的定义保持不变) ...
# 注册 Linear 类为 PyTree
def _linear_flatten(obj):
# children 是可变部分(JAX数组),需要被跟踪梯度
children = (obj.weights, obj.biases)
# static_data 是不可变部分,不需要跟踪梯度
static_data = (obj.in_features, obj.out_features)
return children, static_data
def _linear_unflatten(static_data, children):
in_features, out_features = static_data
weights, biases = children
# 创建一个新的 Linear 实例,并直接设置其权重和偏置
# 注意:这里需要一个key来初始化,但在unflatten时我们只是重建,
# 实际的key在模型初始化时已经使用过。为了兼容,我们传递一个dummy key
# 或者修改Linear的__init__方法使其可以接受预先存在的weights/biases
# 更好的方式是修改Linear的__init__以支持从现有参数重建
# 但为了保持原始结构,我们暂时用一个dummy key,并手动设置参数
new_instance = Linear(key=jax.random.PRNGKey(0),
in_features=in_features,
out_features=out_features)
new_instance.weights = weights
new_instance.biases = biases
return new_instance
tree_util.register_pytree_node(Linear, _linear_flatten, _linear_unflatten) 注意: 在_linear_unflatten中,Linear的__init__方法需要一个key。在实际场景中,为了更好地支持PyTree重建,Linear的__init__可能需要修改,允许直接传入weights和biases,而不是通过key生成。为了与原问题代码兼容,我们在此使用了一个dummy key并在创建实例后手动赋值。
注册 Activation 类Activation类没有可训练参数,只有静态信息(即无)。
# 注册 Activation 类为 PyTree
def _activation_flatten(obj):
children = () # 没有可训练参数
static_data = () # 没有静态属性需要保留
return children, static_data
def _activation_unflatten(static_data, children):
return Activation() # 直接创建实例
tree_util.register_pytree_node(Activation, _activation_flatten, _activation_unflatten) 注册 Model 类
Model类包含linear和activation这两个子模块,它们本身也是PyTree。Model类没有额外的静态数据需要保留。
# 注册 Model 类为 PyTree
def _model_flatten(obj):
# children 是其子模块,它们本身也是 PyTree
children = (obj.linear, obj.activation)
static_data = () # Model本身没有额外的静态属性需要保留
return children, static_data
def _model_unflatten(static_data, children):
linear_module, activation_module = children
# 创建一个新的 Model 实例,并直接设置其子模块
# 类似 Linear,Model 的 __init__ 也需要 key, in_features, out_features
# 同样为了兼容,这里传递 dummy values 并手动设置子模块
new_instance = Model(key=jax.random.PRNGKey(0),
in_features=1, out_features=1) # dummy values
new_instance.linear = linear_module
new_instance.activation = activation_module
return new_instance
tree_util.register_pytree_node(Model, _model_flatten, _model_unflatten) 再次注意: _model_unflatten也面临与_linear_unflatten类似的问题,Model的__init__需要key, in_features, out_features。在实际生产代码中,应设计更灵活的__init__方法或使用更高级的PyTree构建方式。
完整示例代码将上述解决方案整合到原始代码中:
import jax
import jax.numpy as jnp
from jax import tree_util
# --- Module 类定义 ---
class Module:
def __init__(self) -> None:
pass
def __call__(self, inputs: jax.Array):
return self.forward(inputs)
class Linear(Module):
def __init__(self, key: jax.Array, in_features: int, out_features: int) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
key_w, key_b = jax.random.split(key=key, num=2)
self.weights = jax.random.normal(key=key_w, shape=(out_features, in_features))
self.biases = jax.random.normal(key=key_b, shape=(out_features,))
def forward(self, inputs: jax.Array) -> jax.Array:
out = jnp.dot(self.weights, inputs) + self.biases
return out
class Activation(Module):
def __init__(self) -> None:
super().__init__()
pass
def forward(self, inputs: jax.Array) -> jax.Array:
return jax.nn.sigmoid(inputs)
class Model(Module):
def __init__(self, key: jax.Array, in_features: int, out_features: int) -> None:
super().__init__()
self.linear = Linear(key=key, in_features=in_features, out_features=out_features)
self.activation = Activation()
def forward(self, inputs: jax.Array) -> jax.Array:
out = self.linear(inputs)
out = self.activation(out)
return out
def criterion(output: jax.Array, target: jax.Array):
return ((target - output) ** 2).sum()
# --- PyTree 注册 ---
def _linear_flatten(obj):
children = (obj.weights, obj.biases)
static_data = (obj.in_features, obj.out_features)
return children, static_data
def _linear_unflatten(static_data, children):
in_features, out_features = static_data
weights, biases = children
# 为了 PyTree 重建,可以修改 Linear 的 __init__ 以接受预设参数
# 这里为了兼容,使用 dummy key 并手动赋值
new_instance = Linear(key=jax.random.PRNGKey(0), in_features=in_features, out_features=out_features)
new_instance.weights = weights
new_instance.biases = biases
return new_instance
tree_util.register_pytree_node(Linear, _linear_flatten, _linear_unflatten)
def _activation_flatten(obj):
children = ()
static_data = ()
return children, static_data
def _activation_unflatten(static_data, children):
return Activation()
tree_util.register_pytree_node(Activation, _activation_flatten, _activation_unflatten)
def _model_flatten(obj):
children = (obj.linear, obj.activation)
static_data = ()
return children, static_data
def _model_unflatten(static_data, children):
linear_module, activation_module = children
# 同样,为了兼容,使用 dummy values 并手动赋值
new_instance = Model(key=jax.random.PRNGKey(0), in_features=1, out_features=1)
new_instance.linear = linear_module
new_instance.activation = activation_module
return new_instance
tree_util.register_pytree_node(Model, _model_flatten, _model_unflatten)
# --- 重构损失函数 ---
def compute_loss(model_instance, inputs, target):
output = model_instance(inputs)
loss_value = criterion(output, target)
return loss_value
if __name__ == "__main__":
in_features: int = 4
out_features: int = 1
key = jax.random.PRNGKey(67)
model = Model(key=key, in_features=in_features, out_features=out_features)
key = jax.random.PRNGKey(68)
data = jax.random.normal(key=key, shape=(in_features,))
target = jnp.array([2.0]) # 确保 target 是一个 JAX 数组
out = model(data)
print(f"Model output: {out = }")
loss = compute_loss(model_instance=model, inputs=data, target=target)
print(f"Computed loss: {loss = }")
# 计算模型参数的梯度
grads = jax.grad(compute_loss, argnums=0)(model, data, target)
print(f"\nComputed gradients for model parameters: {grads = }")
# 验证梯度结构
print(f"\nGradient for Linear weights: {grads.linear.weights}")
print(f"Gradient for Linear biases: {grads.linear.biases}") 运行上述代码,你将看到grads是一个与model结构相同的PyTree,其中包含了linear.weights和linear.biases的梯度。
注意事项与最佳实践- 参数管理: 在JAX的函数式编程范式中,模型参数通常作为独立的PyTree结构进行管理,并通过函数传递,而不是存储在可变的对象中。虽然上述方法通过PyTree注册使得自定义类可微分,但这仍然保留了一定的面向对象风格。更“JAX-y”的方式是显式地将参数和模型逻辑分离。
- __init__方法设计: 在_unflatten函数中,我们不得不使用dummy key并手动设置属性,这是因为原始__init__方法强制要求key来初始化参数。在设计自定义模块时,考虑为__init__添加一个可选参数,允许直接传入预先存在的权重和偏置,以简化PyTree的重建过程。
- 框架辅助: 对于复杂的模型结构和参数管理,手动注册PyTree可能会变得繁琐且容易出错。推荐使用专门为JAX设计的深度学习框架,如Flax或Equinox。这些框架提供了预注册的模块类和参数管理工具,极大地简化了模型构建和梯度计算的过程。例如,Equinox的eqx.Module会自动处理PyTree注册,并支持更直观的参数管理。
在JAX中为自定义类计算梯度,需要理解并遵循其核心的函数式编程原则。这主要涉及两个步骤:首先,将损失函数重构为接收模型(或其参数)作为输入的纯函数;其次,通过jax.tree_util.register_pytree_node将自定义类注册为JAX可识别的PyTree结构。通过这两个步骤,JAX的jax.grad就能够正确遍历模型内部的JAX数组,并计算出所需的参数梯度。对于更复杂的模型,利用Flax或Equinox等框架可以提供更高效、更简洁的解决方案。
以上就是JAX中自定义模块的梯度计算:PyTree注册与函数式转换的详细内容,更多请关注知识资源分享宝库其它相关文章!
相关标签: python node 工具 ai 深度学习 pytorch Python 面向对象 封装 子类 可变参数 数据结构 对象 pytorch 重构 大家都在看: 检测字符串中是否包含元音字母的 Python 方法 Python 检测 Ctrl+R 组合键并重启程序教程 使用Python监听Ctrl+R组合键并重启程序 使用 Python 在 Synapse Notebook 中替换表格参数值 使用 Python 检测 Ctrl+R 组合键并重启程序






发表评论:
◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。