Module

class flax.nnx.Module(*args, **kwargs)

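Base class for all neural network Modules.

Layers and models should subclass this class.

Modules can contain submodules, so they can be nested in a tree structure. Submodules are assigned as regular attributes inside the __init__ method.

You can define arbitrary "forward pass" methods on your Module subclass. No method is special-cased, but __call__ is a popular choice because it lets you call the Module directly: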

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
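
eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.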

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Parameters

**attributes – Additional attributes passed to set_attributes.
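
The relationship between eval, train, and set_attributes described here can be illustrated with a small pure-Python sketch. ToyModule is a hypothetical stand-in written for this illustration, not Flax code, and custom_flag is a made-up attribute name. The sketch assumes that the mode defaults and any extra keyword arguments are simply forwarded together, and that Modules lacking an attribute are skipped.

```python
class ToyModule:
    """Hypothetical stand-in for nnx.Module with a flat list of children."""

    def __init__(self, *children, **flags):
        self.children = list(children)
        for name, value in flags.items():
            setattr(self, name, value)

    def set_attributes(self, **attributes):
        # Set each attribute only where it already exists, then recurse.
        for name, value in attributes.items():
            if hasattr(self, name):
                setattr(self, name, value)
        for child in self.children:
            child.set_attributes(**attributes)

    def eval(self, **attributes):
        # Evaluation defaults plus any caller-supplied extras.
        self.set_attributes(
            deterministic=True, use_running_average=True, **attributes
        )

    def train(self, **attributes):
        # Training defaults plus any caller-supplied extras.
        self.set_attributes(
            deterministic=False, use_running_average=False, **attributes
        )


dropout = ToyModule(deterministic=False)
batch_norm = ToyModule(use_running_average=False)
custom = ToyModule(custom_flag=False)  # hypothetical extra attribute
block = ToyModule(dropout, batch_norm, custom)

block.eval(custom_flag=True)
after_eval = (
    dropout.deterministic, batch_norm.use_running_average, custom.custom_flag
)

block.train()
after_train = (
    dropout.deterministic, batch_norm.use_running_average, custom.custom_flag
)
```

In this sketch the extra attribute is only changed by the call that names it, so custom_flag keeps the value set during eval after the later call to train.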

iter_children()

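Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except that it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields a key and a Module instance, where the key is a string naming the attribute through which the parent Module accesses that child Module.

Example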

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
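
iter_modules()

Recursively iterates over all nested Modules of the current Module, including the current Module itself.

iter_modules creates a generator that yields a path and a Module instance, where the path is a tuple of strings or integers giving the path from the root Module to that Module.

Example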

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
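
The traversal order visible in the iter_children and iter_modules outputs (children in attribute-name order, each parent after its descendants, the root last with the empty path) can be mimicked in a short pure-Python sketch. ToyModule and both helper functions are hypothetical stand-ins written for illustration, not Flax code. The ordering is inferred from the printed outputs rather than from Flax's implementation.

```python
class ToyModule:
    """Hypothetical stand-in for nnx.Module; children are plain attributes."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


def iter_children(module):
    """Yield (attribute name, child) for direct ToyModule children only."""
    for name in sorted(vars(module)):
        child = getattr(module, name)
        if isinstance(child, ToyModule):
            yield name, child


def iter_modules(module, path=()):
    """Yield (path, module) for every nested ToyModule, descendants first."""
    for name, child in iter_children(module):
        yield from iter_modules(child, (*path, name))
    yield path, module  # the module itself comes after its descendants


root = ToyModule(
    linear=ToyModule(),
    submodule=ToyModule(linear1=ToyModule(), linear2=ToyModule()),
    dropout=ToyModule(),
    batch_norm=ToyModule(),
)

child_names = [name for name, _ in iter_children(root)]
all_paths = [path for path, _ in iter_modules(root)]
```

Here child_names contains only the four direct attributes, while all_paths also reaches into submodule and ends with the empty tuple for the root.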
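
set_attributes(*filters, raise_if_not_found=True, **attributes)

Sets attributes on nested Modules, including the current Module. A Module that does not have a given attribute is skipped for that attribute.

Example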

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
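
Parameters
  • *filters – Filters that select the Modules whose attributes are set.

  • raise_if_not_found – If True (default), raises a ValueError when one of the given attributes is not found on any of the selected Modules.

  • **attributes – The attributes to set.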
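
How the filters and raise_if_not_found interact can be sketched in plain Python. The classes and the walk and set_attributes functions below are hypothetical toys written for this illustration, not Flax code. As a simplifying assumption, filters are restricted to classes checked with isinstance, whereas the real method accepts more general filters.

```python
class Dropout:
    """Hypothetical toy layer; shares only its flag name with nnx.Dropout."""

    def __init__(self):
        self.deterministic = False


class BatchNorm:
    """Hypothetical toy layer; shares only its flag name with nnx.BatchNorm."""

    def __init__(self):
        self.use_running_average = False


class Block:
    def __init__(self):
        self.dropout = Dropout()
        self.batch_norm = BatchNorm()


def walk(obj):
    """Yield obj and every nested object that has instance attributes."""
    yield obj
    for value in list(vars(obj).values()):
        if hasattr(value, "__dict__"):
            yield from walk(value)


def set_attributes(root, *filters, raise_if_not_found=True, **attributes):
    not_found = set(attributes)
    for module in walk(root):
        if filters and not isinstance(module, filters):
            continue  # module is not selected by the filters
        for name, value in attributes.items():
            if hasattr(module, name):
                setattr(module, name, value)
                not_found.discard(name)
    if raise_if_not_found and not_found:
        raise ValueError(f"Could not find attributes: {sorted(not_found)}")


block = Block()
set_attributes(block, Dropout, deterministic=True)
only_dropout = (
    block.dropout.deterministic, block.batch_norm.use_running_average
)

# No selected module has use_running_average, so this raises by default.
try:
    set_attributes(block, Dropout, use_running_average=True)
    raised = False
except ValueError:
    raised = True

# With raise_if_not_found=False the same call is a silent no-op.
set_attributes(
    block, Dropout, use_running_average=True, raise_if_not_found=False
)
```

In this sketch the error is tied to the selected Modules, not the whole tree: BatchNorm has the attribute, but the Dropout filter excludes it.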

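sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

sow() can be used to collect intermediate values without explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, named by name. The value is wrapped in a Variable of type variable_type, which is useful for filtering in split(), state(), and pop().

By default the values are stored in a tuple, and each stored value is appended to the end. This way all intermediate values can be tracked when the same Module is called multiple times.

Example usage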

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

Alternatively, custom init/reduce functions can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
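
Parameters
  • variable_type – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name – A string naming the Module attribute in which the sowed value is stored.

  • value – The value to be stored.

  • reduce_fn – The function used to combine the existing value with the new value. The default appends the value to a tuple.

  • init_fn – For the first value stored, reduce_fn is passed the result of init_fn together with the value to be stored. The default is an empty tuple.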
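
The interplay of init_fn and reduce_fn described above can be sketched in plain Python. This is a hypothetical toy, not Flax code: it stores values in a dict called store instead of wrapping them in a Variable on a Module attribute, and its defaults are written to match the tuple-append behaviour described in the parameter list.

```python
def sow(store, name, value,
        reduce_fn=lambda prev, curr: prev + (curr,),
        init_fn=lambda: ()):
    """Combine value with whatever is already stored under name."""
    # On the first call for a name, start from init_fn(); afterwards,
    # start from the previously stored value.
    prev = store[name] if name in store else init_fn()
    store[name] = reduce_fn(prev, value)


store = {}
for x in (2, 3, 4):
    # Default behaviour: append every value to a tuple.
    sow(store, "i", x)
    # Custom behaviour: running sum and running product.
    sow(store, "sum", x,
        init_fn=lambda: 0,
        reduce_fn=lambda prev, curr: prev + curr)
    sow(store, "product", x,
        init_fn=lambda: 1,
        reduce_fn=lambda prev, curr: prev * curr)
```

After the loop, store["i"] holds all three values in call order, while store["sum"] and store["product"] each hold a single reduced value.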

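train(**attributes)

Sets the Module to training mode.

train uses set_attributes to recursively set deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example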

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)

Parameters

**attributes – Additional attributes passed to set_attributes.