Skip to content

Commit 41953e9

Browse files
authored
Merge pull request #346 from semantic-systems/refactor_hf
Refactor Huggingface Sources
2 parents 861362b + 0f813af commit 41953e9

File tree

3 files changed

+242
-174
lines changed

3 files changed

+242
-174
lines changed

config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,7 @@ class Config:
122122
},
123123
"module": "huggingface_models",
124124
"search-endpoint": f"https://huggingface.co/api/models?limit={NUMBER_OF_RECORDS_FOR_SEARCH_ENDPOINT}&search=",
125+
"get-resource-endpoint": f"https://huggingface.co/api/models/",
125126
},
126127
"Huggingface - Datasets": {
127128
"logo": {
@@ -133,6 +134,7 @@ class Config:
133134
},
134135
"module": "huggingface_datasets",
135136
"search-endpoint": f"https://huggingface.co/api/datasets?limit={NUMBER_OF_RECORDS_FOR_SEARCH_ENDPOINT}&search=",
137+
"get-resource-endpoint": f"https://huggingface.co/api/datasets/",
136138
},
137139
"OPENAIRE - Products": {
138140
"logo": {

sources/huggingface_datasets.py

Lines changed: 109 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -1,72 +1,121 @@
1-
from objects import thing, Article, Author, Dataset, Person
2-
from sources import data_retriever
3-
import utils
4-
from main import app
5-
6-
@utils.handle_exceptions
7-
def search(source: str, search_term: str, results, failed_sources):
8-
search_result = data_retriever.retrieve_data(source=source,
9-
base_url=app.config['DATA_SOURCES'][source].get('search-endpoint', ''),
10-
search_term=search_term,
11-
failed_sources=failed_sources)
12-
13-
total_hits = len(search_result)
14-
15-
if int(total_hits) > 0:
16-
utils.log_event(type="info", message=f"{source} - {total_hits} records matched")
17-
18-
for hit in search_result:
19-
20-
dataset = map_entry_to_dataset(hit)
21-
results['resources'].append(dataset)
22-
23-
def map_entry_to_dataset(record) -> Dataset:
1+
from typing import Union, Dict, Any, List, Iterable
242

25-
dataset = Dataset() # thing -> CreateWork -> Dataset
26-
27-
dataset.identifier = record.get("id", "")
28-
dataset.name = record.get("id", "")
29-
dataset.additionalType = "DATASET"
30-
dataset.url = "https://huggingface.co/datasets/" + record.get("id", "")
31-
dataset.description = utils.remove_html_tags(record.get("description", ""))
32-
dataset.abstract = dataset.description
33-
dataset.license = record.get("license", {}).get("id", "")
34-
dataset.datePublished = record.get("createdAt", "")
35-
dataset.dateModified = record.get("lastModified", "")
3+
import utils
4+
from config import Config
5+
from sources.base import BaseSource
6+
from sources import data_retriever
7+
from objects import thing, Article, Author, Dataset, Person
368

37-
# much metadata is contained in the tags
38-
tags = record.get("tags", [])
399

40-
dataset.inLanguage = [t.split("language:")[1] for t in tags if t.startswith("language:")]
41-
dataset.genre = ", ".join(t.split("task_categories:")[1] for t in tags if t.startswith("task_categories:"))
42-
dataset.encodingFormat = ", ".join(t.split("format:")[1] for t in tags if t.startswith("format:"))
43-
dataset.countryOfOrigin = next((t.split("region:")[1] for t in tags if t.startswith("region:")), "")
44-
dataset.keywords = tags
class HuggingFaceDatasets(BaseSource):
    """Search and single-resource retrieval for the Huggingface datasets API."""

    SOURCE = "Huggingface - Datasets"
    # Endpoints are resolved once, at class-creation time, from the central config.
    SEARCH_ENDPOINT = Config.DATA_SOURCES[SOURCE].get('search-endpoint', '')
    RESOURCE_ENDPOINT = Config.DATA_SOURCES[SOURCE].get("get-resource-endpoint", "")

    def fetch(self, search_term: str, failed_sources: list | None = None) -> Dict[str, Any]:
        """
        Fetch raw json from the source using the given search term.

        :param search_term: free-text query forwarded to the search endpoint.
        :param failed_sources: list that data_retriever records failures into;
            a fresh list is created when omitted.
        :return: parsed JSON payload, or {} when retrieval yielded nothing.
        """
        # BUGFIX: previous signature used a mutable default (`failed_sources: list = []`),
        # which is shared across all calls; use the None-sentinel idiom instead.
        if failed_sources is None:
            failed_sources = []
        return data_retriever.retrieve_data(
            source=self.SOURCE,
            base_url=self.SEARCH_ENDPOINT,
            search_term=search_term,
            failed_sources=failed_sources,
        ) or {}

    def extract_hits(self, raw: Dict[str, Any]) -> Iterable[Dict[str, Any]]:
        """
        Extract the list of hits from the raw JSON response. Should return an
        iterable of hit dicts. The Huggingface search response is already a
        flat collection of records, so the raw payload is returned unchanged.
        """
        return raw

    def map_hit(self, source_name: str, hit: Dict[str, Any]) -> Dataset:
        """
        Map a single hit dict from the source to a Dataset object (objects.py).

        :param source_name: label stored on the provenance `thing` record.
        :param hit: one record as returned by the Huggingface API.
        :return: populated Dataset instance.
        """
        dataset = Dataset()  # thing -> CreativeWork -> Dataset

        dataset.identifier = hit.get("id", "")
        dataset.name = hit.get("id", "")
        dataset.additionalType = "DATASET"
        dataset.url = "https://huggingface.co/datasets/" + hit.get("id", "")
        dataset.description = utils.remove_html_tags(hit.get("description", ""))
        dataset.abstract = dataset.description
        dataset.license = hit.get("license", {}).get("id", "")
        dataset.datePublished = hit.get("createdAt", "")
        dataset.dateModified = hit.get("lastModified", "")

        # much metadata is contained in the tags
        tags = hit.get("tags", [])

        dataset.inLanguage = [t.split("language:")[1] for t in tags if t.startswith("language:")]
        dataset.genre = ", ".join(t.split("task_categories:")[1] for t in tags if t.startswith("task_categories:"))
        dataset.encodingFormat = ", ".join(t.split("format:")[1] for t in tags if t.startswith("format:"))
        dataset.countryOfOrigin = next((t.split("region:")[1] for t in tags if t.startswith("region:")), "")
        dataset.keywords = tags

        # NOTE(review): intentionally overwrites the license read above — a
        # "license:" tag wins; without one the license is cleared to "".
        dataset.license = next((t.split("license:")[1] for t in tags if t.startswith("license:")), "")

        # First truthy flag wins, in decreasing order of restriction.
        dataset.creativeWorkStatus = (
            "disabled" if hit.get("disabled")
            else "private" if hit.get("private")
            else "gated" if hit.get("gated")
            else "public"
        )

        if hit.get("author"):
            dataset.author = [Author(name=hit["author"])]
        dataset.publisher = dataset.author[0].name if dataset.author else ""

        # Provenance record pointing back at this source.
        _source = thing()
        _source.name = source_name
        _source.originalSource = dataset.publisher
        _source.identifier = dataset.identifier
        _source.url = dataset.url
        dataset.source.append(_source)

        return dataset

    def search(self, source_name: str, search_term: str, results: dict, failed_sources: list) -> None:
        """
        Fetch json from the source, extract hits, map them to objects, and insert them in-place into the results dict.

        :param source_name: display name of this source (unused; SOURCE is used for logging/mapping).
        :param search_term: free-text query.
        :param results: dict with a 'resources' list that mapped Datasets are appended to.
        :param failed_sources: list that retrieval failures are recorded into.
        """
        search_result = self.fetch(search_term, failed_sources)

        total_hits = len(search_result)
        # len() already returns an int — no int() cast needed.
        if total_hits > 0:
            utils.log_event(type="info", message=f"{self.SOURCE} - {total_hits} records matched")

        for hit in search_result:
            dataset = self.map_hit(self.SOURCE, hit)
            results['resources'].append(dataset)

    def get_resource(self, doi: str) -> Dataset | None:
        """
        Retrieve and map detail metadata for a single dataset.

        :param doi: dataset identifier on Huggingface (appended to RESOURCE_ENDPOINT).
        :return: mapped Dataset, or None when retrieval failed.
        """
        search_result = data_retriever.retrieve_object(
            source=self.SOURCE,
            base_url=self.RESOURCE_ENDPOINT,
            identifier=doi,
            quote=False,
        )
        if search_result:
            dataset = self.map_hit(self.SOURCE, search_result)
            utils.log_event(type="info", message=f"{self.SOURCE} - retrieved dataset details")
            return dataset
        else:
            utils.log_event(type="error", message=f"{self.SOURCE} - failed to retrieve dataset details")
            return None
54107

55-
if record.get("author"):
56-
dataset.author = [Author(name=record["author"])]
57-
dataset.publisher = dataset.author[0].name if dataset.author else ""
58108

59-
_source = thing()
60-
_source.name = 'Huggingface - Datasets'
61-
_source.originalSource = dataset.publisher
62-
_source.identifier = dataset.identifier
63-
_source.url = dataset.url
64-
dataset.source.append(_source)
@utils.handle_exceptions
def search(source: str, search_term: str, results, failed_sources) -> None:
    """
    Module-level entrypoint to search Huggingface Datasets.

    Delegates to HuggingFaceDatasets.search, mutating `results` and
    `failed_sources` in place.
    """
    handler = HuggingFaceDatasets()
    handler.search(source, search_term, results, failed_sources)
65115

66-
return dataset
67116

68117
@utils.handle_exceptions
def get_resource(source: str, source_id: str, doi: str) -> Dataset | None:
    """
    Retrieve detailed information for the dataset.

    :param source: source name (not used by the delegate; kept for interface parity).
    :param source_id: internal identifier (not used by the delegate).
    :param doi: dataset identifier on Huggingface.
    :return: dataset, or None when retrieval failed.
    """
    handler = HuggingFaceDatasets()
    return handler.get_resource(doi)

0 commit comments

Comments
 (0)