99
1010DEFAULT_JSON_OUTPUT_FILE = "safety_results.json"
1111
12+ # File patterns to scan in HuggingFace repos (pickle-based model files)
13+ HF_PICKLE_PATTERNS = [
14+ "*.bin" ,
15+ "*.pt" ,
16+ "*.pth" ,
17+ "*.pkl" ,
18+ "*.pickle" ,
19+ "pytorch_model.bin" ,
20+ "model.safetensors" , # Skip safetensors (safe format)
21+ ]
22+
23+ # Extensions that are known to be safe (not pickle-based)
24+ HF_SAFE_EXTENSIONS = {".safetensors" , ".json" , ".txt" , ".md" , ".yaml" , ".yml" , ".toml" }
25+
26+
27+ def _scan_huggingface (
28+ repo_id : str ,
29+ revision : str | None = None ,
30+ token : str | None = None ,
31+ json_output_path : str | None = None ,
32+ print_results : bool = False ,
33+ ) -> int :
34+ """Scan a HuggingFace Hub repository for potentially malicious pickle files.
35+
36+ Args:
37+ repo_id: HuggingFace repository ID (e.g., 'bert-base-uncased')
38+ revision: Specific revision (branch, tag, or commit) to scan
39+ token: HuggingFace API token for private repositories
40+ json_output_path: Path to write JSON results
41+ print_results: Whether to print results to console
42+
43+ Returns:
44+ Exit code (0=clean, 1=unsafe, 2=error)
45+ """
46+ try :
47+ from huggingface_hub import HfApi , hf_hub_download
48+ except ImportError :
49+ sys .stderr .write (
50+ "Error: huggingface_hub is required for --huggingface scanning.\n "
51+ "Install with: pip install fickling[huggingface]\n "
52+ )
53+ return 2 # EXIT_ERROR
54+
55+ api = HfApi (token = token )
56+
57+ # List files in the repository
58+ try :
59+ repo_info = api .repo_info (repo_id = repo_id , revision = revision , token = token )
60+ files = [f .rfilename for f in repo_info .siblings ] if repo_info .siblings else []
61+ except Exception as e :
62+ sys .stderr .write (f"Error accessing HuggingFace repository '{ repo_id } ': { e !s} \n " )
63+ return 2 # EXIT_ERROR
64+
65+ # Filter for potentially unsafe pickle files
66+ pickle_files = []
67+ for filename in files :
68+ ext = "." + filename .rsplit ("." , 1 )[- 1 ].lower () if "." in filename else ""
69+ if ext in HF_SAFE_EXTENSIONS :
70+ continue
71+ if ext in {".bin" , ".pt" , ".pth" , ".pkl" , ".pickle" } or filename .endswith (
72+ "pytorch_model.bin"
73+ ):
74+ pickle_files .append (filename )
75+
76+ if not pickle_files :
77+ if print_results :
78+ print (f"No pickle files found in { repo_id } " )
79+ return 0 # EXIT_CLEAN
80+
81+ if print_results :
82+ print (f"Scanning { len (pickle_files )} file(s) in { repo_id } ..." )
83+
84+ overall_safe = True
85+ json_output = json_output_path or DEFAULT_JSON_OUTPUT_FILE
86+
87+ for filename in pickle_files :
88+ if print_results :
89+ print (f"\n Scanning: { filename } " )
90+
91+ try :
92+ # Download the file
93+ local_path = hf_hub_download (
94+ repo_id = repo_id ,
95+ filename = filename ,
96+ revision = revision ,
97+ token = token ,
98+ )
99+
100+ # Scan the file
101+ with open (local_path , "rb" ) as f :
102+ stacked_pickled = fickle .StackedPickle .load (f , fail_on_decode_error = False )
103+
104+ for pickled in stacked_pickled :
105+ safety_results = check_safety (pickled , json_output_path = json_output )
106+
107+ if print_results :
108+ result_str = safety_results .to_string ()
109+ if result_str :
110+ print (f" { result_str } " )
111+
112+ if safety_results .severity > Severity .LIKELY_SAFE :
113+ overall_safe = False
114+ if print_results :
115+ sys .stderr .write (f" WARNING: { filename } may contain unsafe content!\n " )
116+
117+ except fickle .PickleDecodeError as e :
118+ if print_results :
119+ sys .stderr .write (f" Error parsing { filename } : { e !s} \n " )
120+ # Parsing errors are suspicious but not necessarily unsafe
121+ continue
122+ except Exception as e :
123+ if print_results :
124+ sys .stderr .write (f" Error scanning { filename } : { e !s} \n " )
125+ continue
126+
127+ if print_results :
128+ if overall_safe :
129+ print (f"\n { repo_id } : No obvious safety issues detected" )
130+ else :
131+ print (f"\n { repo_id } : Potentially unsafe content detected!" )
132+
133+ return 0 if overall_safe else 1 # EXIT_CLEAN or EXIT_UNSAFE
134+
12135
13136def main (argv : list [str ] | None = None ) -> int :
14137 if argv is None :
@@ -96,6 +219,27 @@ def main(argv: list[str] | None = None) -> int:
96219 help = "print a runtime trace while interpreting the input pickle file" ,
97220 )
98221 parser .add_argument ("--version" , "-v" , action = "store_true" , help = "print the version and exit" )
222+ parser .add_argument (
223+ "--huggingface" ,
224+ "--hf" ,
225+ type = str ,
226+ default = None ,
227+ metavar = "REPO_ID" ,
228+ help = "scan a model from HuggingFace Hub by repository ID (e.g., 'bert-base-uncased'). "
229+ "Requires huggingface_hub: pip install fickling[huggingface]" ,
230+ )
231+ parser .add_argument (
232+ "--hf-revision" ,
233+ type = str ,
234+ default = None ,
235+ help = "specific revision (branch, tag, or commit) to scan from HuggingFace Hub" ,
236+ )
237+ parser .add_argument (
238+ "--hf-token" ,
239+ type = str ,
240+ default = None ,
241+ help = "HuggingFace API token for accessing private repositories" ,
242+ )
99243
100244 args = parser .parse_args (argv [1 :])
101245
@@ -106,6 +250,16 @@ def main(argv: list[str] | None = None) -> int:
106250 print (__version__ )
107251 return 0
108252
253+ # HuggingFace scanning mode
254+ if args .huggingface is not None :
255+ return _scan_huggingface (
256+ repo_id = args .huggingface ,
257+ revision = args .hf_revision ,
258+ token = args .hf_token ,
259+ json_output_path = args .json_output ,
260+ print_results = args .print_results ,
261+ )
262+
109263 if args .create is None :
110264 if args .PICKLE_FILE == "-" :
111265 if hasattr (sys .stdin , "buffer" ) and sys .stdin .buffer is not None :
0 commit comments