flax.errors package
Flax has the following classes of errors.
- exception flax.errors.AlreadyExistsError(path)
This error is raised when attempting to overwrite an existing file via copy.
You can pass overwrite=True to disable this behavior and overwrite existing files.
- exception flax.errors.ApplyModuleInvalidMethodError(method)
When calling Module.apply(), you can specify the method to apply using the method parameter. This error is thrown if the provided value is neither a method of the Module nor a function with at least one argument. Learn more in the reference docs for Module.apply().
- exception flax.errors.ApplyScopeInvalidVariablesStructureError(variables)
This error is thrown when the dict passed as variables to apply() has an extra 'params' layer, i.e. {'params': {'params': ...}}. For more explanation on variable dicts, please see flax.core.variables.
- exception flax.errors.ApplyScopeInvalidVariablesTypeError
When calling Module.apply(), the first argument should be a variable dict. For more explanation on variable dicts, please see flax.core.variables.
- exception flax.errors.AssignSubModuleError(cls)
You are only allowed to create submodules in two places:
  - If your Module is noncompact: inside Module.setup().
  - If your Module is compact: inside the method wrapped in nn.compact.
For instance, the following code throws this error, because nn.Conv is created in __call__, which is not marked as compact:

    import jax.numpy as jnp
    from jax import random
    from flax import linen as nn

    class Foo(nn.Module):
      def setup(self):
        pass

      def __call__(self, x):
        conv = nn.Conv(features=3, kernel_size=3)

    Foo().init(random.key(0), jnp.zeros((1,)))

Note that this error is also thrown if you only partially define a Module inside setup:

    import functools

    class Foo(nn.Module):
      def setup(self):
        self.conv = functools.partial(nn.Conv, features=3)

      def __call__(self, x):
        x = self.conv(kernel_size=4)(x)
        return x

    Foo().init(random.key(0), jnp.zeros((1,)))

In this case, self.conv(kernel_size=4) is called from __call__, which is disallowed because it is neither within setup nor a method wrapped in nn.compact.
- exception flax.errors.CallCompactUnboundModuleError
This error occurs when you are trying to call a Module directly, rather than through Module.apply(). For instance, the error will be raised when trying to run this code:

    from flax import linen as nn
    import jax.numpy as jnp

    test_dense = nn.Dense(10)
    test_dense(jnp.ones((5, 5)))

Instead, you should pass the variables (parameters and other state) via Module.apply() (or use Module.init() to get initial variables):

    from jax import random

    variables = test_dense.init(random.key(0), jnp.ones((5, 5)))
    y = test_dense.apply(variables, jnp.ones((5, 5)))
- exception flax.errors.CallSetupUnboundModuleError
This error occurs when you are trying to call .setup() directly. For instance, the error will be raised when trying to run this code:

    from flax import linen as nn
    import jax.numpy as jnp

    class MyModule(nn.Module):
      def setup(self):
        self.submodule = MySubModule()

    module = MyModule()
    module.setup()  # <-- ERROR!
    submodule = module.submodule

In general, you shouldn't call .setup() yourself. If you need access to a field or submodule defined inside setup, you can instead create a function to extract it and pass it to nn.apply:

    # setup() will be called automatically by nn.apply
    def get_submodule(module):
      return module.submodule.clone()  # avoid leaking the Scope

    empty_variables = {}  # you can also use the real variables
    submodule = nn.apply(get_submodule, module)(empty_variables)
This error occurs when you are trying to call nn.share_scope on an unbound Module. For instance, when you try to use nn.share_scope at the top level:

    from flax import linen as nn

    class CustomDense(nn.Dense):
      def __call__(self, x):
        return super().__call__(x) + 1

    custom_dense = CustomDense(5)
    dense = nn.Dense(5)  # has the parameters

    nn.share_scope(custom_dense, dense)  # <-- ERROR!
- exception flax.errors.CallUnbindOnUnboundModuleError
This error occurs when you are trying to call .unbind() on an unbound Module. For instance, when you try running the following example, an error will be raised:

    from flax import linen as nn

    class MyModule(nn.Module):
      @nn.compact
      def __call__(self, x):
        return nn.Dense(features=10)(x)

    module = MyModule()
    module.unbind()  # <-- ERROR!

Instead, you should bind the Module to a variable collection before calling .unbind():

    bound_module = module.bind(variables)
    ...  # do something with bound_module
    module, variables = bound_module.unbind()  # <-- OK! Returns the unbound Module and its variables.
- exception flax.errors.CursorFindError(cursor=None, cursor2=None)
Error when calling Cursor.find(). This error occurs if no object or more than one object is found, given the conditions of the cond_fn.
- exception flax.errors.DescriptorAttributeError
This error occurs when a property of a Module tries to access a non-existent attribute.
For example, the error will be raised when trying to run this code:

    class Foo(nn.Module):
      @property
      def prop(self):
        return self.non_existent_field  # ERROR!

      def __call__(self, x):
        return self.prop

    foo = Foo()
    variables = foo.init(jax.random.key(0), jnp.ones(shape=(1, 8)))
- exception flax.errors.IncorrectPostInitOverrideError
This error occurs when you override .__post_init__() without calling super().__post_init__(). For example, the error will be raised when trying to run this code:

    from flax import linen as nn
    import jax.numpy as jnp
    import jax

    class A(nn.Module):
      x: float

      def __post_init__(self):
        self.x_square = self.x ** 2
        # super().__post_init__()  <-- forgot to add this line

      @nn.compact
      def __call__(self, input):
        return input + 3

    r = A(x=3)
    r.init(jax.random.key(2), jnp.ones(3))
- exception flax.errors.InvalidCheckpointError(path, step)
A checkpoint cannot be stored in a directory that already has a checkpoint at the current or a later step.
You can pass overwrite=True to disable this behavior and overwrite existing checkpoints in the target directory.
- exception flax.errors.InvalidFilterError(filter_like)
A filter should be either a boolean, a string or a container object.
- exception flax.errors.InvalidInstanceModuleError
This error occurs when you are trying to call .init(), .init_with_output(), .apply() or .bind() on the Module class itself, instead of on an instance of the Module class. For example, the error will be raised when trying to run this code:

    class B(nn.Module):
      @nn.compact
      def __call__(self, x):
        return x

    k = random.key(0)
    x = random.uniform(random.key(1), (2,))
    B.init(k, x)  # B is a Module class, not a Module instance like B()
    B.apply(vs, x)  # similar issue: apply is called on the class instead of an instance
- exception flax.errors.InvalidRngError(msg)
All rngs used in a Module should be passed to Module.init() and Module.apply() appropriately. We explain both separately using the following example:

    class Bar(nn.Module):
      @nn.compact
      def __call__(self, x):
        some_param = self.param('some_param', nn.initializers.zeros_init(), (1, ))
        dropout_rng = self.make_rng('dropout')
        x = nn.Dense(features=4)(x)
        ...

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = Bar()(x)
        ...

PRNGs for Module.init()
In this example, two rngs are used:
  - params is used for initializing the parameters of the model. This rng is used to initialize the some_param parameter, and to initialize the weights of the Dense Module used in Bar.
  - dropout is used for the dropout rng that is used in Bar.
So, Foo is initialized as follows:

    init_rngs = {'params': random.key(0), 'dropout': random.key(1)}
    variables = Foo().init(init_rngs, init_inputs)

If a Module only requires an rng for params, you can use:

    SomeModule().init(rng, ...)  # Shorthand for {'params': rng}

PRNGs for Module.apply()
When applying Foo, only the rng for dropout is needed, because params is only used for initializing the Module parameters:

    Foo().apply(variables, inputs, rngs={'dropout': random.key(2)})

If a Module only requires an rng for params, you don't have to provide rngs for apply at all:

    SomeModule().apply(variables, inputs)  # rngs=None
- exception flax.errors.InvalidScopeError(scope_name)
A temporary Scope is only valid within the context in which it is created:

    with Scope(variables, rngs=rngs).temporary() as root:
      ...  # Here root is valid.
    # Here root is invalid.
- exception flax.errors.JaxTransformError
JAX transforms and Flax modules cannot be mixed.
JAX's functional transformations expect pure functions. When you want to use JAX transformations inside Flax models, you should make use of the Flax transformation wrappers (e.g. flax.linen.vmap, flax.linen.scan, etc.).
- exception flax.errors.LazyInitError(partial_val)
The lazy init function has return values that cannot be computed.
This happens when an argument passed to lazy_init as a jax.ShapeDtypeStruct affects the initialized variables. Make sure the init function only uses the shape and dtype of such arguments, or pass an actual JAX array if this is impossible. Example:

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        # This parameter depends on the values of the input x;
        # this causes an error when using lazy_init.
        k = self.param("kernel", lambda _: x)
        return x * k

    Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32))
- exception flax.errors.MPACheckpointingRequiredError(path, step)
To optimally save and restore a multiprocess array (GDA or jax Array output by pjit), use GlobalAsyncCheckpointManager.
You can create a GlobalAsyncCheckpointManager at the top level and pass it as an argument:

    from jax.experimental.gda_serialization import serialization as gdas

    gda_manager = gdas.GlobalAsyncCheckpointManager()
    save_checkpoint(..., gda_manager=gda_manager)
- exception flax.errors.MPARestoreDataCorruptedError(step, path)
A multiprocess array stored in Google Cloud Storage doesn't contain a "commit_success.txt" file, which should be written at the end of the save.
Failing to find this file could indicate that your saved GDA data is corrupted.
- exception flax.errors.MPARestoreTargetRequiredError(path, step, key=None)
Provide a valid target when restoring a checkpoint with a multiprocess array.
Multiprocess arrays need a sharding (global meshes and partition specs) to be initialized. Therefore, to restore a checkpoint that contains a multiprocess array, make sure the target you pass contains valid multiprocess arrays at the corresponding locations in the tree structure. If you cannot provide a full valid target, consider passing allow_partial_mpa_restoration=True.
- exception flax.errors.ModifyScopeVariableError(col, variable_name, scope_path)
You cannot update a variable if the collection it belongs to is immutable.
When you are applying a Module, you should specify which variable collections are mutable:

    class MyModule(nn.Module):
      @nn.compact
      def __call__(self, x):
        ...
        var = self.variable('batch_stats', 'mean', ...)
        var.value = ...
        ...

    v = MyModule().init(...)
    ...
    logits = MyModule().apply(v, batch)  # This throws an error.
    logits, updates = MyModule().apply(v, batch, mutable=['batch_stats'])  # This works.
- exception flax.errors.MultipleMethodsCompactError
The @compact decorator may be added to at most one method in a Flax module. In order to resolve this, you can:
  - remove @compact and define submodules and variables using Module.setup(), or
  - use two separate modules that each have a unique @compact method.
Restricting @compact to a single method means there is no need for an equivalent of Haiku's hk.transparent, and it makes submodules much simpler because there is no need to prefix the method names.
- exception flax.errors.NameInUseError(key_type, value, module_name)
This error is raised when trying to create a submodule, param, or variable with an existing name. They are all considered to be in the same namespace.
Sharing Submodules
This is the wrong pattern for sharing submodules:

    y = nn.Dense(features=3, name='bar')(x)
    z = nn.Dense(features=3, name='bar')(x+epsilon)

Instead, modules should be shared by instance:

    dense = nn.Dense(features=3, name='bar')
    y = dense(x)
    z = dense(x+epsilon)

If submodules are not provided with a name, a unique name will be given to them automatically:

    class MyModule(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = MySubModule()(x)
        x = MySubModule()(x)  # This is fine.
        return x

Parameters and Variables
A parameter name can collide with a submodule or variable, since they are all stored in the same variable dict:

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        bar = self.param('bar', nn.initializers.zeros_init(), (1, ))
        embed = nn.Embed(num_embeddings=2, features=5, name='bar')  # <-- ERROR!

Variables should also have unique names, even if they have their own collection:

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, inputs):
        _ = self.param('mean', nn.initializers.lecun_normal(), (2, 2))
        _ = self.variable('stats', 'mean', jnp.zeros, (2, 2))  # <-- ERROR!
- exception flax.errors.PartitioningUnspecifiedError(target)
This error is raised when trying to add an axis to a Partitioned variable by using a transformation (e.g. scan, vmap) without specifying the "partition_name" in the metadata_params dict.
- exception flax.errors.ReservedModuleAttributeError(annotations)
This error is thrown when creating a Module that is using reserved attributes.
The following attributes are reserved:
  - parent: The parent Module of this Module.
  - name: The name of this Module.
- exception flax.errors.ScopeCollectionNotFound(col_name, var_name, scope_path)
This error is thrown when trying to access a variable from an empty collection.
There are two common causes:
  - The collection was not passed to apply correctly. For example, you might have used module.apply(params, ...) instead of module.apply({'params': params}, ...).
  - The collection is empty because the variables need to be initialized. In this case, you should have made the collection mutable during apply (e.g. module.apply(variables, ..., mutable=['state'])).
- exception flax.errors.ScopeParamNotFoundError(param_name, scope_path)
This error is thrown when trying to access a parameter that does not exist.
For instance, in the code below, the initialized embedding name 'embedding' does not match the apply name 'embed':

    class Embed(nn.Module):
      num_embeddings: int
      features: int

      @nn.compact
      def __call__(self, inputs, embed_name='embedding'):
        inputs = inputs.astype('int32')
        embedding = self.param(embed_name,
                               jax.nn.initializers.lecun_normal(),
                               (self.num_embeddings, self.features))
        return embedding[inputs]

    model = Embed(4, 8)
    variables = model.init(random.key(0), jnp.ones((5, 5, 1)))
    _ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed')
- exception flax.errors.ScopeParamShapeError(param_name, scope_path, value_shape, init_shape)
This error is thrown when the shape of an existing parameter is different from the shape of the return value of the init_fn. This can happen when the shape provided during Module.apply() is different from the one used when initializing the module.
For instance, the following code throws this error because the apply shape ((5, 5)) is different from the init shape ((5, 5, 1)). As a result, the shape of the kernel during init is (1, 8), and the shape during apply is (5, 8), which results in this error:

    import jax.numpy as jnp
    from jax import lax, random
    from flax import linen as nn
    from flax.linen.initializers import lecun_normal

    class NoBiasDense(nn.Module):
      features: int = 8

      @nn.compact
      def __call__(self, x):
        kernel = self.param('kernel', lecun_normal(), (x.shape[-1], self.features))  # <--- ERROR
        y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))
        return y

    variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1)))
    _ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
- exception flax.errors.ScopeVariableNotFoundError(name, col, scope_path)
This error is thrown when trying to use a variable in a Scope in a collection that is immutable. In order to create this variable, mark the collection as mutable explicitly using the mutable keyword in Module.apply().
- exception flax.errors.SetAttributeFrozenModuleError(module_cls, attr_name, attr_val)
You can only assign Module attributes to self inside Module.setup(). Outside of that method, the Module instance is frozen (i.e., immutable). This behavior is similar to frozen Python dataclasses. For instance, this error is raised in the following case:

    class SomeModule(nn.Module):
      @nn.compact
      def __call__(self, x, num_features=10):
        self.num_features = num_features  # <-- ERROR!
        x = nn.Dense(self.num_features)(x)
        return x

    s = SomeModule().init(random.key(0), jnp.ones((5, 5)))

Similarly, the error is raised when trying to modify a submodule's attributes after constructing it, even if this is done in the setup() method of the parent module:

    class Foo(nn.Module):
      def setup(self):
        self.dense = nn.Dense(features=10)
        self.dense.features = 20  # <--- This is not allowed

      def __call__(self, x):
        return self.dense(x)
- exception flax.errors.SetAttributeInModuleSetupError
You are not allowed to modify Module class attributes in setup():

    class Foo(nn.Module):
      features: int = 6

      def setup(self):
        self.features = 3  # <-- ERROR

      def __call__(self, x):
        return nn.Dense(self.features)(x)

    variables = Foo().init(random.key(0), jnp.ones((1, )))

Instead, these attributes should be set when initializing the Module:

    class Foo(nn.Module):
      features: int = 6

      @nn.compact
      def __call__(self, x):
        return nn.Dense(self.features)(x)

    variables = Foo(features=3).init(random.key(0), jnp.ones((1, )))

Modules need to stay frozen; otherwise they could not be safely cloned, which Flax relies on for lifted transformations.
- exception flax.errors.TransformTargetError(target)
Linen transformations must be applied to Module classes or to functions taking a Module instance as the first argument.
This error occurs when passing an invalid target to a Linen transform (nn.vmap, nn.scan, etc.), for example when trying to transform a Module instance:

    nn.vmap(nn.Dense(features))(x)  # raises TransformTargetError

You can transform the nn.Dense class directly instead:

    nn.vmap(nn.Dense)(features)(x)

Or you can create a function that takes the module instance as the first argument:

    class BatchDense(nn.Module):
      @nn.compact
      def __call__(self, x):
        return nn.vmap(
            lambda mdl, x: mdl(x),
            variable_axes={'params': 0},
            split_rngs={'params': True})(nn.Dense(3), x)
- exception flax.errors.TransformedMethodReturnValueError(name)
Transformed Module methods cannot return other Modules or Variables.
- exception flax.errors.TraverseTreeError(update_fn, cond_fn)
Error when calling Cursor._traverse_tree(). This function has two modes:
  - If update_fn is not None, it will traverse the tree and return a generator of tuples containing the path where the update_fn was applied and the newly modified value.
  - If cond_fn is not None, it will traverse the tree and return a generator of tuple paths that fulfilled the conditions of the cond_fn.
This error occurs if both update_fn and cond_fn are None, or if both are not None.