Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .asset import Asset
from .common import logger
from .schemas import EmptyOutput
from .auth import BasicAuth, TokenAuth, OAuth, NoAuth
from .auth import BasicAuth, TokenAuth, OAuth, NoAuth, CertificateAuth
from .request_maker import make_request

app = App(
Expand Down Expand Up @@ -72,12 +72,8 @@ def test_connectivity(soar: SOARClient, asset: Asset) -> None:
app.register_action(action_delete.delete_data, action_type="generic", read_only=False)
app.register_action(action_head.get_headers, action_type="investigate")
app.register_action(action_options.get_options, action_type="investigate")
app.register_action(
put_file.put_file, action_type="generic", read_only=False, verbose=put_file.VERBOSE
)
app.register_action(
get_file.get_file, action_type="investigate", verbose=get_file.VERBOSE
)
app.register_action(put_file.put_file, action_type="generic", read_only=False, verbose=put_file.VERBOSE)
app.register_action(get_file.get_file, action_type="investigate", verbose=get_file.VERBOSE)


def get_auth_method(asset: Asset, soar_client: SOARClient):
Expand All @@ -88,7 +84,9 @@ def get_auth_method(asset: Asset, soar_client: SOARClient):
authentication method to use (Basic, Token, OAuth, or None) and returns
an instance of the corresponding strategy class.
"""
if asset.username and asset.password:
if asset.public_cert and asset.private_key:
return CertificateAuth(asset)
elif asset.username and asset.password:
return BasicAuth(asset)
elif asset.auth_token_name and asset.auth_token:
return TokenAuth(asset)
Expand Down
22 changes: 7 additions & 15 deletions src/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,12 @@ class Asset(BaseAsset):
description="Type of authentication token",
default="ph-auth-token",
)
auth_token: str = AssetField(
required=False, description="Value of authentication token", sensitive=True
)
username: str = AssetField(
required=False, description="Username (for HTTP basic auth)"
)
password: str = AssetField(
required=False, description="Password (for HTTP basic auth)", sensitive=True
)
oauth_token_url: str = AssetField(
required=False, description="URL to fetch oauth token from"
)
auth_token: str = AssetField(required=False, description="Value of authentication token", sensitive=True)
username: str = AssetField(required=False, description="Username (for HTTP basic auth)")
password: str = AssetField(required=False, description="Password (for HTTP basic auth)", sensitive=True)
oauth_token_url: str = AssetField(required=False, description="URL to fetch oauth token from")
client_id: str = AssetField(required=False, description="Client ID (for OAuth)")
client_secret: str = AssetField(
required=False, description="Client Secret (for OAuth)", sensitive=True
)
client_secret: str = AssetField(required=False, description="Client Secret (for OAuth)", sensitive=True)
scope: str = AssetField(required=False, description="Scope for OAuth")
timeout: float = AssetField(required=False, description="Timeout for HTTP calls")
test_http_method: str = AssetField(
Expand All @@ -61,3 +51,5 @@ class Asset(BaseAsset):
"PATCH",
],
)
public_cert: str = AssetField(required=False, sensitive=True, description="Public part of the client certificate")
private_key: str = AssetField(required=False, sensitive=True, description="Private key for the client certificate")
22 changes: 16 additions & 6 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,9 @@ def _generate_new_token(self):
access_token = json.loads(response.text).get("access_token")

except requests.exceptions.RequestException as e:
raise ActionFailure(
f"Error fetching OAuth token from {token_url}. Details: {e}"
) from e
raise ActionFailure(f"Error fetching OAuth token from {token_url}. Details: {e}") from e
except json.JSONDecodeError as e:
raise ActionFailure(
"Error parsing response from server while fetching token"
) from e
raise ActionFailure("Error parsing response from server while fetching token") from e

if not access_token:
raise ActionFailure("Access token not found in response body")
Expand Down Expand Up @@ -158,6 +154,20 @@ def create_auth(self, headers: dict) -> tuple[None, dict]:
return None, headers


class CertificateAuth(Authorization):
"""
Implements authentication using client-side certificates.
"""

def __init__(self, asset):
self.public_cert = asset.public_cert
self.private_key = asset.private_key

def create_auth(self, headers: dict) -> tuple[tuple[str, str], dict]:
logger.info("Using Certificate-based authentication")
return (self.public_cert, self.private_key), headers


class NoAuth(Authorization):
"""
Represents an anonymous request with no authentication.
Expand Down
53 changes: 36 additions & 17 deletions src/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from .common import logger
from .schemas import ParsedResponseBody

import tempfile
import os
from contextlib import contextmanager


def process_xml_response(response) -> dict:
try:
Expand All @@ -34,13 +38,9 @@ def process_json_response(response) -> dict:
try:
return ParsedResponseBody(**response.json())
except json.JSONDecodeError as e:
raise ActionFailure(
f"Server claimed JSON but failed to parse. Error: {e}"
) from e
raise ActionFailure(f"Server claimed JSON but failed to parse. Error: {e}") from e
except ValidationError as e:
raise ActionFailure(
f"Response JSON did not match expected structure. Details: {e}"
) from e
raise ActionFailure(f"Response JSON did not match expected structure. Details: {e}") from e


def process_html_response(response) -> ParsedResponseBody:
Expand All @@ -58,11 +58,7 @@ def process_html_response(response) -> ParsedResponseBody:


def process_empty_response(content_type) -> dict:
message = (
"Response includes a file"
if "octet-stream" in content_type
else "Empty response body"
)
message = "Response includes a file" if "octet-stream" in content_type else "Empty response body"
return {"message": message}


Expand Down Expand Up @@ -92,16 +88,12 @@ def parse_headers(headers_str: Optional[str]) -> dict:
parsed_headers = json.loads(headers_str)

except json.JSONDecodeError as e:
error_message = (
f"Failed to parse headers. Ensure it's a valid JSON object. Error: {e}"
)
error_message = f"Failed to parse headers. Ensure it's a valid JSON object. Error: {e}"
logger.error(error_message)
raise ActionFailure(error_message) from e

if not isinstance(parsed_headers, dict):
raise ActionFailure(
"Headers parameter must be a valid JSON object (dictionary)."
)
raise ActionFailure("Headers parameter must be a valid JSON object (dictionary).")

return parsed_headers

Expand All @@ -128,3 +120,30 @@ def handle_various_response(response):
else:
raw_body = response.text
return parsed_body, raw_body


@contextmanager
def temp_cert_files(public_cert_data: str, private_key_data: str):
"""
Context manager to create temporary files for public certificate and private key.
"""
public_cert_path = None
private_key_path = None
try:
if public_cert_data:
with tempfile.NamedTemporaryFile(delete=False) as f_pub:
f_pub.write(public_cert_data.encode("utf-8"))
public_cert_path = f_pub.name

if private_key_data:
with tempfile.NamedTemporaryFile(delete=False) as f_priv:
f_priv.write(private_key_data.encode("utf-8"))
private_key_path = f_priv.name

Comment on lines +133 to +142
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think cert based required both public and private keys to be not null. We should raise if both aren't a non empty string

yield (public_cert_path, private_key_path)

finally:
if public_cert_path and os.path.exists(public_cert_path):
os.remove(public_cert_path)
if private_key_path and os.path.exists(private_key_path):
os.remove(private_key_path)
97 changes: 79 additions & 18 deletions src/request_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from .asset import Asset
from .auth import OAuth
from .common import logger
from .auth import Authorization, CertificateAuth
from .helpers import temp_cert_files


def make_request(
Expand All @@ -49,20 +51,37 @@ def make_request(

logger.info(f"Making {method} request to: {full_url}")

body = (
UnicodeDammit(body).unicode_markup.encode("utf-8")
if isinstance(body, str)
else body
)
body = UnicodeDammit(body).unicode_markup.encode("utf-8") if isinstance(body, str) else body

from .app import get_auth_method

auth_method = get_auth_method(asset, soar)

if isinstance(auth_method, CertificateAuth):
return _execute_certificate_request(auth_method, full_url, method, body, verify, parsed_headers, output, asset)
else:
return _execute_standard_request(auth_method, full_url, method, body, verify, parsed_headers, output, asset)


def _execute_standard_request(
auth_method: Authorization,
full_url: str,
method: str,
body: Optional[str],
verify: bool,
headers: dict,
output_cls: type[ActionOutput],
asset: Asset,
) -> ActionOutput:
"""
Executes a standard HTTP request (non-certificate based).
Handles OAuth token refresh logic.
"""
retries = 1
response = None

while retries >= 0:
from .app import get_auth_method

auth_method = get_auth_method(asset, soar)
auth_object, final_headers = auth_method.create_auth(parsed_headers)
auth_object, final_headers = auth_method.create_auth(headers.copy())

try:
response = requests.request(
Expand All @@ -75,26 +94,68 @@ def make_request(
timeout=asset.timeout,
)
response.raise_for_status()

break

except requests.exceptions.RequestException as e:
if isinstance(auth_method, OAuth) and retries > 0:
logger.warning(
"Request failed with 401, token might be expired. Forcing a refresh."
)
auth_method.get_token(force_new=True)
logger.warning("Request failed with 401, token might be expired. Forcing a refresh.")
auth_method.get_token(force_new=True, full_url=full_url)
retries -= 1
continue
else:
raise ActionFailure(
f"Request failed for {full_url}. Details: {e}"
) from e
raise ActionFailure(f"Request failed for {full_url}. Details: {e}") from e

if response is None:
raise ActionFailure(f"Request failed for {full_url} and no response was received after retries.")

parsed_body, raw_body = helpers.handle_various_response(response)
logger.info(f"Successfully processed data. Status: {response.status_code}")

return output_cls(
status_code=response.status_code,
location=full_url,
method=method,
parsed_response_body=parsed_body,
response_body=raw_body,
response_headers=str(dict(response.headers)),
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this part be common for both _execute functions?


def _execute_certificate_request(
auth_method: CertificateAuth,
full_url: str,
method: str,
body: Optional[str],
verify: bool,
headers: dict,
output_cls: type[ActionOutput],
asset: Asset,
) -> ActionOutput:
"""
Executes an HTTP request using client-side certificates.
"""
public_cert_data, private_key_data = auth_method.create_auth(headers.copy())

with temp_cert_files(public_cert_data, private_key_data) as cert_param:
try:
response = requests.request(
method=method,
url=full_url,
cert=cert_param,
data=body,
verify=verify,
headers=headers,
timeout=asset.timeout,
)
response.raise_for_status()

except requests.exceptions.RequestException as e:
raise ActionFailure(f"Certificate-based request failed for {full_url}. Details: {e}") from e

parsed_body, raw_body = helpers.handle_various_response(response)
logger.info(f"Successfully processed data. Status: {response.status_code}")

return output(
return output_cls(
status_code=response.status_code,
location=full_url,
method=method,
Expand Down