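Module
- class flax.nnx.Module(*args, **kwargs)
Base class for all neural network modules.
Layers and models should subclass this class.
A Module can contain submodules, so modules can be nested in a tree structure. Submodules are assigned as regular attributes inside the __init__ method.
You can define arbitrary "forward pass" methods on a Module subclass. No method is treated specially, but __call__ is a popular choice because it lets you call the Module directly: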
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
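- eval(**attributes)
Sets the Module to evaluation mode.
eval uses set_attributes to recursively set deterministic=True and use_running_average=True on every nested Module that has these attributes. It is primarily used to control the runtime behavior of Dropout and BatchNorm modules.
Example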
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- Parameters
**attributes – additional attributes passed to set_attributes.
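- iter_children()
Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except that it only iterates over the immediate children and does not recurse further down.
iter_children creates a generator that yields the key and the Module instance, where the key is a string giving the attribute name under which the child Module is stored on the current Module.
Example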
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
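- iter_modules()
Recursively iterates over all nested Modules of the current Module, including the current Module itself.
iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.
Example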
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
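- set_attributes(*filters, raise_if_not_found=True, **attributes)
Sets the attributes of nested Modules, including the current Module. A Module that does not have a given attribute is skipped for that attribute.
Example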
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules only:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
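- Parameters
*filters – filters used to select the Modules whose attributes are set.
raise_if_not_found – if True (default), a ValueError is raised when at least one of the given attributes is not found in any of the selected Modules.
**attributes – the attributes to set.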
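- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)
sow() can be used to collect intermediate values without explicitly passing a container through each Module call. sow() stores the value in a new Module attribute named by name. The value is wrapped in a Variable of type variable_type, which is useful for filtering in split(), state() and pop().
By default the values are stored in a tuple and each new value is appended at the end. This way all intermediates can be tracked when the same Module is called multiple times.
Example usage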
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)
>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
Alternatively, custom init/reduce functions can be passed:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value
>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
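- Parameters
variable_type – the Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.
name – a string denoting the Module attribute name under which the sown value is stored.
value – the value to be stored.
reduce_fn – the function used to combine the existing value with the new value. The default is to append the value to a tuple.
init_fn – for the first value stored, reduce_fn is passed the result of init_fn together with the value to be stored. The default is an empty tuple.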
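- train(**attributes)
Sets the Module to training mode.
train uses set_attributes to recursively set deterministic=False and use_running_average=False on every nested Module that has these attributes. It is primarily used to control the runtime behavior of Dropout and BatchNorm modules.
Example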
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- Parameters
**attributes – additional attributes passed to set_attributes.