Skip to content

Commit 537702d

Browse files
fix issue when from_file_path is str instead of Path
1 parent 17b2e15 commit 537702d

File tree

2 files changed

+71
-9
lines changed

2 files changed

+71
-9
lines changed

src/geophires_x_client/geophires_input_parameters.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ class PowerPlantType(Enum):
4040

4141

4242
class GeophiresInputParameters:
43+
"""
44+
.. deprecated:: v3.9.21
45+
Use :class:`~geophires_x_client.geophires_input_parameters.ImmutableGeophiresInputParameters` instead for
46+
better performance and guardrails against erroneous usage.
47+
This class is kept for backwards compatibility, but does not work with GeophiresXClient caching and is more
48+
susceptible to potential bugs due to its mutability.
49+
"""
4350

4451
def __init__(self, params: Optional[MappingProxyType] = None, from_file_path: Optional[Path] = None):
4552
"""
@@ -100,15 +107,24 @@ class ImmutableGeophiresInputParameters(GeophiresInputParameters):
100107
"""
101108

102109
params: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({}))
103-
from_file_path: Union[Path, None] = None
110+
from_file_path: Union[Path, str, None] = None
104111

105112
# A unique ID for this instance, used for file I/O but not for hashing or equality.
106113
_instance_id: uuid.UUID = field(default_factory=uuid.uuid4, init=False, repr=False, compare=False)
107114

108115
def __post_init__(self):
109-
"""Ensures that the parameters dictionary is immutable."""
116+
"""
117+
Validates input and normalizes field types for immutability and consistency.
118+
- Ensures from_file_path is a Path object if provided as a string.
119+
- Ensures the params dictionary is an immutable mapping proxy.
120+
"""
121+
# Normalize from_file_path to a Path object. object.__setattr__ is required
122+
# because the dataclass is frozen.
123+
if self.from_file_path and isinstance(self.from_file_path, str):
124+
object.__setattr__(self, 'from_file_path', Path(self.from_file_path))
125+
126+
# Ensure params is an immutable proxy
110127
if not isinstance(self.params, MappingProxyType):
111-
# object.__setattr__ is required to modify a field in a frozen dataclass
112128
object.__setattr__(self, 'params', MappingProxyType(self.params))
113129

114130
def __hash__(self) -> int:
@@ -117,36 +133,38 @@ def __hash__(self) -> int:
117133
If a base file is used, its content is read and hashed to ensure
118134
the hash reflects a true snapshot of all inputs.
119135
"""
120-
121136
param_hash = hash(frozenset(self.params.items()))
122137

123-
if self.from_file_path is not None and self.from_file_path.exists():
138+
file_content_hash = None
139+
# self.from_file_path is now guaranteed to be a Path object or None
140+
if self.from_file_path and self.from_file_path.exists():
124141
file_content_hash = hash(self.from_file_path.read_bytes())
125142
else:
143+
# Hash the path itself if it's None or doesn't exist.
126144
file_content_hash = hash(self.from_file_path)
127145

128146
return hash((param_hash, file_content_hash))
129147

130148
def as_file_path(self) -> Path:
131149
"""
132150
Creates a temporary file representation of the parameters on demand.
133-
The resulting file path is cached for efficiency.
151+
The resulting file path is cached on the instance for efficiency.
134152
"""
135153

136-
# Return the cached path if the file has already been generated for this instance.
154+
# Use hasattr to check for the cached attribute on the frozen instance
137155
if hasattr(self, '_cached_file_path'):
138156
return self._cached_file_path
139157

140158
file_path = Path(tempfile.gettempdir(), f'geophires-input-params_{self._instance_id!s}.txt')
141159

142160
with open(file_path, 'w', encoding='UTF-8') as f:
143-
if self.from_file_path is not None:
161+
if self.from_file_path:
144162
with open(self.from_file_path, encoding='UTF-8') as base_file:
145163
f.write(base_file.read())
146164

147165
if self.params:
148166
# Ensure there is a newline between the base file content and appended params.
149-
if self.from_file_path is not None and f.tell() > 0:
167+
if self.from_file_path and f.tell() > 0:
150168
f.seek(f.tell() - 1)
151169
if f.read(1) != '\n':
152170
f.write('\n')

tests/geophires_x_client_tests/test_geophires_input_parameters.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import tempfile
22
import uuid
33
from pathlib import Path
4+
from types import MappingProxyType
45

56
from geophires_x_client import GeophiresInputParameters
67
from geophires_x_client import GeophiresXClient
8+
from geophires_x_client.geophires_input_parameters import ImmutableGeophiresInputParameters
79
from tests.base_test_case import BaseTestCase
810

911

@@ -52,3 +54,45 @@ def test_input_file_comments(self):
5254
GeophiresInputParameters(from_file_path=self._get_test_file_path('input_comments.txt'))
5355
)
5456
self.assertIsNotNone(result)
57+
58+
59+
class ImmutableGeophiresInputParametersTestCase(BaseTestCase):
60+
def test_init_with_file_path_as_string(self):
61+
"""Verify that the class can be initialized with a string path without raising an AttributeError."""
62+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp_file:
63+
tmp_file_path = tmp_file.name
64+
tmp_file.write('key,value\n')
65+
66+
# This should not raise an AttributeError
67+
params = ImmutableGeophiresInputParameters(from_file_path=tmp_file_path)
68+
69+
# Verify the path was correctly converted and can be used
70+
self.assertTrue(params.as_file_path().exists())
71+
self.assertIsInstance(params.from_file_path, Path)
72+
73+
# Clean up the temporary file
74+
Path(tmp_file_path).unlink()
75+
76+
def test_hash_equality(self):
77+
"""Verify that two objects with the same content have the same hash."""
78+
params = {'Reservoir Depth': 3, 'Gradient 1': 50}
79+
p1 = ImmutableGeophiresInputParameters(params=params)
80+
p2 = ImmutableGeophiresInputParameters(params=params)
81+
82+
self.assertIsNot(p1, p2)
83+
self.assertEqual(hash(p1), hash(p2))
84+
85+
def test_hash_inequality(self):
86+
"""Verify that two objects with different content have different hashes."""
87+
p1 = ImmutableGeophiresInputParameters(params={'Reservoir Depth': 3})
88+
p2 = ImmutableGeophiresInputParameters(params={'Reservoir Depth': 4})
89+
self.assertNotEqual(hash(p1), hash(p2))
90+
91+
def test_immutability_of_params(self):
92+
"""Verify that the params dictionary is an immutable mapping proxy."""
93+
p1 = ImmutableGeophiresInputParameters(params={'Reservoir Depth': 3})
94+
self.assertIsInstance(p1.params, MappingProxyType)
95+
96+
with self.assertRaises(TypeError):
97+
# This should fail because MappingProxyType is read-only
98+
p1.params['Reservoir Depth'] = 4

0 commit comments

Comments
 (0)