初始化/应用

初始化/应用#

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[source]#

创建一个应用函数来调用绑定模块的 fn

Module.apply 不同,此函数返回一个新的函数,其签名为 (variables, *args, rngs=None, **kwargs) -> T,其中 Tfn 的返回类型。如果 mutable 不是 False,则返回类型为元组,其中第二项是包含已修改变量的 FrozenDict

返回的应用函数可以直接与 JAX 转换(如 jax.jit)组合使用。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
参数
  • fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的 module 的模块实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。作为 fn 的第一个参数传递的 Module 将是模块的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/没有集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。

  • capture_intermediates – 如果 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回值

包装 fn 的应用函数。

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

创建一个 init 函数来调用绑定模块的 fn

Module.init 不同,此函数返回一个新的函数,其签名为 (rngs, *args, **kwargs) -> variables。rngs 可以是 PRNGKey 字典或单个 `PRNGKey,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。

返回的 init 函数可以直接与 JAX 转换(如 jax.jit)组合使用。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的 module 的模块实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。作为 fn 的第一个参数传递的 Module 将是模块的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/没有集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回值

包装 fn 的 init 函数。

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[source]#

创建一个 init 函数来调用绑定模块的 fn,该函数还返回函数输出。

Module.init_with_output 不同,此函数返回一个新的函数,其签名为 (rngs, *args, **kwargs) -> (T, variables),其中 Tfn 的返回类型。rngs 可以是 PRNGKey 字典或单个 `PRNGKey,它等效于传递一个字典,其中一个 PRNGKey 的名称为“params”。

返回的 init 函数可以直接与 JAX 转换(如 jax.jit)组合使用。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • fn – 应该应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的 module 的模块实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。作为 fn 的第一个参数传递的 Module 将是模块的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/没有集合是可变的。 str:单个可变集合的名称。 list:可变集合名称的列表。默认情况下,除“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改过滤行为。过滤函数采用模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回值

包装 fn 的 init 函数。