diff --git a/.gitignore b/.gitignore index 7dc6c890..076838c7 100644 --- a/.gitignore +++ b/.gitignore @@ -122,3 +122,6 @@ node_modules exports trash + +#spyder +.spyproject diff --git a/docs/persistence.rst b/docs/persistence.rst index 4a5f4c1c..b7c198c4 100644 --- a/docs/persistence.rst +++ b/docs/persistence.rst @@ -231,7 +231,7 @@ mean you can just call `sio.load(, trusted=get_untrusted_types(file=))` on this object -- only pass the types you really trust to the ``trusted`` argument. -Supported libraries +Supported/Trusted libraries ------------------- Skops intends to support all of **scikit-learn**, that is, not only its diff --git a/skops/io/__init__.py b/skops/io/__init__.py index c60c99c6..f81f3991 100644 --- a/skops/io/__init__.py +++ b/skops/io/__init__.py @@ -1,4 +1,5 @@ -from ._persist import dump, dumps, get_untrusted_types, load, loads +from ._persist import dump, dumps, get_untrusted_types, load, loads, is_trusted from ._visualize import visualize -__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize"] +__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", + "visualize","is_trusted"] diff --git a/skops/io/_persist.py b/skops/io/_persist.py index aaed469c..f401ce0e 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -6,11 +6,14 @@ from pathlib import Path from typing import Any, BinaryIO, Optional, Sequence from zipfile import ZIP_STORED, ZipFile +from inspect import getmembers import skops +from skops.io import _trusted_types from ._audit import NODE_TYPE_MAPPING, audit_tree, get_tree -from ._utils import LoadContext, SaveContext, _get_state, get_state +from ._utils import LoadContext, SaveContext, _get_state, get_state, get_type_name + # We load the dispatch functions from the corresponding modules and register # them. Old protocols are found in the 'old/' directory, with the protocol @@ -235,3 +238,43 @@ def get_untrusted_types( untrusted_types = tree.get_unsafe_set() return sorted(untrusted_types) + +def is_trusted(obj)-> list[bool]: + """Return a list of bools specifying if passed objects are trusted. + + Parameters + ---------- + obj: list, dictionary, or Any + A list, dictionary or single object to be checked. + All objected that are trusted return True, untrusted False + For dictionaries values are checked + + Returns + ------- + trust: list of bools + A list of bools specifying if the types of the passed objects + are trusted. + """ + + val = list() # to hold trusted types + trust = list() # to determine if passed objects are trusted + + # loop over _trusted_types to find all trusted types + # and add them to a single list + for w in getmembers(_trusted_types): + if 'NAMES' in w[0] and isinstance(w[1],list): + val.append(w[1]) + + t_type = [v for vs in val for v in vs] + + # check is passed object is in trusted types + if isinstance(obj, list): + for value in obj: + trust.append(get_type_name(type(value)) in t_type) + elif isinstance(obj, dict): + for value in obj.values(): + trust.append(get_type_name(type(value)) in t_type) + else: + trust.append(get_type_name(type(obj)) in t_type) + + return trust \ No newline at end of file