初始化/应用#
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#
创建一个应用函数来调用绑定模块的
fn
。与
Module.apply
不同,此函数返回一个新的函数,其签名为(variables, *args, rngs=None, **kwargs) -> T
,其中T
是fn
的返回类型。如果mutable
不是False
,则返回类型为元组,其中第二项是包含已修改变量的FrozenDict
。返回的应用函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于将变量和 RNG 绑定到的
Module
。作为fn
的第一个参数传递的Module
将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。capture_intermediates – 如果
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn
的应用函数。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数来调用绑定模块的
fn
。与
Module.init
不同,此函数返回一个新的函数,其签名为(rngs, *args, **kwargs) -> variables
。rngs 可以是 PRNGKey 字典或单个`PRNGKey
,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于将变量和 RNG 绑定到的
Module
。作为fn
的第一个参数传递的Module
将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn
的 init 函数。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#
创建一个 init 函数来调用绑定模块的
fn
,该函数还返回函数输出。与
Module.init_with_output
不同,此函数返回一个新的函数,其签名为(rngs, *args, **kwargs) -> (T, variables)
,其中T
是fn
的返回类型。rngs 可以是 PRNGKey 字典或单个`PRNGKey
,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于将变量和 RNG 绑定到的
Module
。作为fn
的第一个参数传递的Module
将是模块的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回值
包装
fn
的 init 函数。