Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@ node_modules

exports
trash

#spyder
.spyproject
2 changes: 1 addition & 1 deletion docs/persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ mean you can just call `sio.load(<file>,
trusted=get_untrusted_types(file=<file>))` on this object -- only pass the
types you really trust to the ``trusted`` argument.

Supported libraries
Supported/Trusted libraries
-------------------

Skops intends to support all of **scikit-learn**, that is, not only its
Expand Down
5 changes: 3 additions & 2 deletions skops/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._persist import dump, dumps, get_untrusted_types, load, loads
from ._persist import dump, dumps, get_untrusted_types, load, loads, is_trusted
from ._visualize import visualize

__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types", "visualize"]
__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types",
"visualize","is_trusted"]
45 changes: 44 additions & 1 deletion skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from pathlib import Path
from typing import Any, BinaryIO, Optional, Sequence
from zipfile import ZIP_STORED, ZipFile
from inspect import getmembers

import skops

from skops.io import _trusted_types
from ._audit import NODE_TYPE_MAPPING, audit_tree, get_tree
from ._utils import LoadContext, SaveContext, _get_state, get_state
from ._utils import LoadContext, SaveContext, _get_state, get_state, get_type_name


# We load the dispatch functions from the corresponding modules and register
# them. Old protocols are found in the 'old/' directory, with the protocol
Expand Down Expand Up @@ -235,3 +238,43 @@ def get_untrusted_types(
untrusted_types = tree.get_unsafe_set()

return sorted(untrusted_types)

def is_trusted(obj)-> list[bool]:
"""Return a list of bools specifying if passed objects are trusted.

Parameters
----------
obj: list, dictionary, or Any
A list, dictionary or single object to be checked.
All objected that are trusted return True, untrusted False
For dictionaries values are checked

Returns
-------
trust: list of bools
A list of bools specifying if the types of the passed objects
are trusted.
"""

val = list() # to hold trusted types
trust = list() # to determine if passed objects are trusted

# loop over _trusted_types to find all trusted types
# and add them to a single list
for w in getmembers(_trusted_types):
if 'NAMES' in w[0] and isinstance(w[1],list):
val.append(w[1])

t_type = [v for vs in val for v in vs]

# check is passed object is in trusted types
if isinstance(obj, list):
for value in obj:
trust.append(get_type_name(type(value)) in t_type)
elif isinstance(obj, dict):
for value in obj.values():
trust.append(get_type_name(type(value)) in t_type)
else:
trust.append(get_type_name(type(obj)) in t_type)

return trust