#https://github.com/keras-team/keras/blob/v3.8.0/keras/src/saving/serialization_lib.py
def _retrieve_class_or_fn(
name, registered_name, module, obj_type, full_config, custom_objects=None
):
# .... Truncated ...
# Otherwise, attempt to retrieve the class object given the `module`
# and `class_name`. Import the module, find the class.
try:
mod = importlib.import_module(module)
except ModuleNotFoundError:
raise TypeError(...)
obj = vars(mod).get(name, None)
# Special case for keras.metrics.metrics
if obj is None and registered_name is not None:
obj = vars(mod).get(registered_name, None)
if obj is not None:
return obj
raise TypeError(...)