Skip to content

Commit 0707bd4

Browse files
authored
Merge pull request #32 from sypht-team/add-async-workfows
add run_workflow_async method, blackify
2 parents 2e40100 + d6cb9c6 commit 0707bd4

File tree

2 files changed

+73
-20
lines changed

2 files changed

+73
-20
lines changed

sypht/client.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def _create_session(self):
7575
return requests.Session()
7676

7777
def _authenticate_v2(self, endpoint, client_id, client_secret, audience):
78-
basic_auth_slug = b64encode((client_id + ":" + client_secret).encode("utf-8")).decode(
79-
"utf-8"
80-
)
78+
basic_auth_slug = b64encode(
79+
(client_id + ":" + client_secret).encode("utf-8")
80+
).decode("utf-8")
8181
result = self.requests.post(
8282
endpoint,
8383
headers={
@@ -95,7 +95,9 @@ def _authenticate_v2(self, endpoint, client_id, client_secret, audience):
9595
return result["access_token"], result["expires_in"]
9696

9797
def _authenticate_v1(self, endpoint, client_id, client_secret, audience):
98-
endpoint = endpoint or os.environ.get("SYPHT_AUTH_ENDPOINT", SYPHT_LEGACY_AUTH_ENDPOINT)
98+
endpoint = endpoint or os.environ.get(
99+
"SYPHT_AUTH_ENDPOINT", SYPHT_LEGACY_AUTH_ENDPOINT
100+
)
99101
result = self.requests.post(
100102
endpoint,
101103
data={
@@ -107,7 +109,9 @@ def _authenticate_v1(self, endpoint, client_id, client_secret, audience):
107109
).json()
108110

109111
if result.get("error_description"):
110-
raise Exception("Authentication failed: {}".format(result["error_description"]))
112+
raise Exception(
113+
"Authentication failed: {}".format(result["error_description"])
114+
)
111115

112116
return result["access_token"], result["expires_in"]
113117

@@ -214,13 +218,36 @@ def upload(
214218

215219
if "fileId" not in result:
216220
raise Exception(
217-
"Upload failed with response: {}".format("\n" + json.dumps(result, indent=2))
221+
"Upload failed with response: {}".format(
222+
"\n" + json.dumps(result, indent=2)
223+
)
218224
)
219225

220226
return result["fileId"]
221227

222228
def run_workflow(self, workflow, inputs, step=None, endpoint=None, headers=None):
223-
endpoint = urljoin(endpoint or self.base_endpoint, f"workflows/{workflow}/invoke")
229+
endpoint = urljoin(
230+
endpoint or self.base_endpoint, f"workflows/{workflow}/invoke"
231+
)
232+
headers = headers or {}
233+
headers = self._get_headers(**headers)
234+
return self._parse_response(
235+
self.requests.post(
236+
endpoint,
237+
data=json.dumps(
238+
{
239+
"step_id": step,
240+
"inputs": inputs,
241+
}
242+
),
243+
headers=headers,
244+
)
245+
)
246+
247+
def run_workflow_async(
248+
self, workflow, inputs, step=None, endpoint=None, headers=None
249+
):
250+
endpoint = urljoin(endpoint or self.base_endpoint, f"workflows/{workflow}/jobs")
224251
headers = headers or {}
225252
headers = self._get_headers(**headers)
226253
return self._parse_response(
@@ -332,7 +359,9 @@ def get_file(self, file_id, endpoint=None, headers=None):
332359
return self._parse_response(self.requests.get(endpoint, headers=headers))
333360

334361
def get_file_data(self, file_id, endpoint=None, headers=None):
335-
endpoint = urljoin(endpoint or self.base_endpoint, f"app/docs/{file_id}/download")
362+
endpoint = urljoin(
363+
endpoint or self.base_endpoint, f"app/docs/{file_id}/download"
364+
)
336365
headers = headers or {}
337366
headers = self._get_headers(**headers)
338367
response = self.requests.get(endpoint, headers=headers)
@@ -401,9 +430,13 @@ def get_annotations_for_docs(self, doc_ids, endpoint=None):
401430
headers = self._get_headers()
402431
headers["Accept"] = "application/json"
403432
headers["Content-Type"] = "application/json"
404-
return self._parse_response(self.requests.post(endpoint, data=body, headers=headers))
433+
return self._parse_response(
434+
self.requests.post(endpoint, data=body, headers=headers)
435+
)
405436

406-
def set_company_annotations(self, doc_id, annotations, company_id=None, endpoint=None):
437+
def set_company_annotations(
438+
self, doc_id, annotations, company_id=None, endpoint=None
439+
):
407440
data = {
408441
"origin": "external",
409442
"fields": [
@@ -481,7 +514,9 @@ def set_files_for_tag(self, tag, file_ids, company_id=None, endpoint=None):
481514
headers["Accept"] = "application/json"
482515
headers["Content-Type"] = "application/json"
483516
return self._parse_response(
484-
self.requests.put(endpoint, data=json.dumps({"docs": file_ids}), headers=headers)
517+
self.requests.put(
518+
endpoint, data=json.dumps({"docs": file_ids}), headers=headers
519+
)
485520
)
486521

487522
def add_files_to_tag(self, tag, file_ids, company_id=None, endpoint=None):
@@ -494,7 +529,9 @@ def add_files_to_tag(self, tag, file_ids, company_id=None, endpoint=None):
494529
headers["Accept"] = "application/json"
495530
headers["Content-Type"] = "application/json"
496531
return self._parse_response(
497-
self.requests.patch(endpoint, data=json.dumps({"docs": file_ids}), headers=headers)
532+
self.requests.patch(
533+
endpoint, data=json.dumps({"docs": file_ids}), headers=headers
534+
)
498535
)
499536

500537
def remove_file_from_tag(self, file_id, tag, company_id=None, endpoint=None):
@@ -529,7 +566,9 @@ def set_tags_for_file(self, file_id, tags, company_id=None, endpoint=None):
529566
headers["Accept"] = "application/json"
530567
headers["Content-Type"] = "application/json"
531568
return self._parse_response(
532-
self.requests.put(endpoint, data=json.dumps({"tags": tags}), headers=headers)
569+
self.requests.put(
570+
endpoint, data=json.dumps({"tags": tags}), headers=headers
571+
)
533572
)
534573

535574
def add_tags_to_file(self, file_id, tags, company_id=None, endpoint=None):
@@ -542,7 +581,9 @@ def add_tags_to_file(self, file_id, tags, company_id=None, endpoint=None):
542581
headers["Accept"] = "application/json"
543582
headers["Content-Type"] = "application/json"
544583
return self._parse_response(
545-
self.requests.patch(endpoint, data=json.dumps({"tags": tags}), headers=headers)
584+
self.requests.patch(
585+
endpoint, data=json.dumps({"tags": tags}), headers=headers
586+
)
546587
)
547588

548589
def get_entity(self, entity_id, entity_type, company_id=None, endpoint=None):
@@ -586,7 +627,9 @@ def get_many_entities(self, entity_type, entities, company_id=None, endpoint=Non
586627
self.requests.post(endpoint, data=json.dumps(entities), headers=headers)
587628
)
588629

589-
def list_entities(self, entity_type, company_id=None, page=None, limit=None, endpoint=None):
630+
def list_entities(
631+
self, entity_type, company_id=None, page=None, limit=None, endpoint=None
632+
):
590633
"""Get list of entity_ids by pagination."""
591634
company_id = company_id or self.company_id
592635
entity_type = quote_plus(entity_type)
@@ -602,9 +645,13 @@ def list_entities(self, entity_type, company_id=None, page=None, limit=None, end
602645
params["page"] = page
603646
if limit:
604647
params["limit"] = int(limit)
605-
return self._parse_response(self.requests.get(endpoint, headers=headers, params=params))
648+
return self._parse_response(
649+
self.requests.get(endpoint, headers=headers, params=params)
650+
)
606651

607-
def get_all_entity_ids(self, entity_type, verbose=True, company_id=None, endpoint=None):
652+
def get_all_entity_ids(
653+
self, entity_type, verbose=True, company_id=None, endpoint=None
654+
):
608655
"""Get all entity_ids for specified entity_type.
609656
610657
Returns list of objects if verbose (by default):
@@ -687,7 +734,9 @@ def set_many_entities(
687734
for batch in _iter_chunked_sequence(entities, batch_size):
688735
responses.append(
689736
self._parse_response(
690-
self.requests.post(endpoint, data=json.dumps(batch), headers=headers)
737+
self.requests.post(
738+
endpoint, data=json.dumps(batch), headers=headers
739+
)
691740
)
692741
)
693742
return responses
@@ -720,7 +769,9 @@ def update_specification(self, specification, endpoint=None):
720769
headers["Accept"] = "application/json"
721770
headers["Content-Type"] = "application/json"
722771
return self._parse_response(
723-
self.requests.post(endpoint, data=json.dumps(specification), headers=headers)
772+
self.requests.post(
773+
endpoint, data=json.dumps(specification), headers=headers
774+
)
724775
)
725776

726777
def submit_task(

tests/tests_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class DataExtraction(unittest.TestCase):
2020
def setUp(self):
2121
warnings.simplefilter("ignore", category=ResourceWarning)
2222

23-
self.sypht_client = SyphtClient(os.environ["CLIENT_ID"], os.environ["CLIENT_SECRET"])
23+
self.sypht_client = SyphtClient(
24+
os.environ["CLIENT_ID"], os.environ["CLIENT_SECRET"]
25+
)
2426

2527
def test_with_wrong_fieldset(self):
2628
with self.assertRaises(Exception) as context:

0 commit comments

Comments
 (0)