Skip to content

Commit dbaeb3f

Browse files
__radd__ in Borehole
1 parent f563585 commit dbaeb3f

File tree

3 files changed

+75
-18
lines changed

3 files changed

+75
-18
lines changed

pygfunction/boreholes.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
from typing import Union
23
import warnings
34

45
import numpy as np
@@ -55,11 +56,11 @@ def __repr__(self):
5556
f' orientation={self.orientation})')
5657
return s
5758

58-
def __add__(self, other: Self):
59+
def __add__(self, other: Union[Self, list]):
5960
"""
6061
Adds two boreholes together to form a borefield
6162
"""
62-
if not isinstance(other, self.__class__):
63+
if not isinstance(other, (self.__class__, list)):
6364
# Check if other is a borefield and try the operation using
6465
# other.__radd__
6566
try:
@@ -70,12 +71,41 @@ def __add__(self, other: Self):
7071
f'Expected Borefield, list or Borehole input;'
7172
f' got {other}'
7273
)
74+
elif isinstance(other, list):
75+
# Create a borefield from the borehole and a list
76+
from .borefield import Borefield
77+
field = Borefield.from_boreholes([self] + other)
7378
else:
7479
# Create a borefield from the two boreholes
7580
from .borefield import Borefield
7681
field = Borefield.from_boreholes([self, other])
7782
return field
7883

84+
def __radd__(self, other: Union[Self, list]):
85+
"""
86+
Adds two boreholes together to form a borefield
87+
"""
88+
if not isinstance(other, (self.__class__, list)):
89+
# Check if other is a borefield and try the operation using
90+
# other.__radd__
91+
try:
92+
field = other.__add__(self)
93+
except:
94+
# Invalid input
95+
raise TypeError(
96+
f'Expected Borefield, list or Borehole input;'
97+
f' got {other}'
98+
)
99+
elif isinstance(other, list):
100+
# Create a borefield from the borehole and a list
101+
from .borefield import Borefield
102+
field = Borefield.from_boreholes(other + [self])
103+
else:
104+
# Create a borefield from the two boreholes
105+
from .borefield import Borefield
106+
field = Borefield.from_boreholes([other, self])
107+
return field
108+
79109
def distance(self, target):
80110
"""
81111
Evaluate the distance between the current borehole and a target

tests/borefield_test.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,30 @@ def test_borefield_init(field, request):
3232
H, D, r_b, x, y, tilt=tilt, orientation=orientation)
3333
assert borefield == borefield_from_boreholes
3434

35+
36+
# Test Borefield.__add__ and Borefield.__radd__
37+
@pytest.mark.parametrize("field, other_field, field_list, other_field_list", [
38+
# Using Borefield and Borehole objects
39+
('single_borehole', 'two_boreholes_inclined', False, False),
40+
('ten_boreholes_rectangular', 'single_borehole_short', False, False),
41+
('ten_boreholes_rectangular', 'two_boreholes_inclined', False, False),
42+
# Using Borefield as lists
43+
('ten_boreholes_rectangular', 'two_boreholes_inclined', False, True),
44+
('ten_boreholes_rectangular', 'two_boreholes_inclined', True, False),
45+
])
46+
def test_borefield_add(field, other_field, field_list, other_field_list, request):
47+
field = request.field
48+
other_field = request.other_field
49+
reference_field = gt.borefield.Borefield.from_boreholes(
50+
field.to_boreholes() + other_field.to_boreholes()
51+
)
52+
if field_list:
53+
field = field.to_boreholes()
54+
if other_field_list:
55+
other_field = other_field.to_boreholes()
56+
assert field + other_field_list == reference_field
57+
58+
3559
# Test borefield comparison using __eq__
3660
@pytest.mark.parametrize("field, other_field, expected", [
3761
# Fields that are equal
@@ -53,6 +77,7 @@ def test_borefield_eq(field, other_field, expected, request):
5377
other_field = request.getfixturevalue(other_field)
5478
assert (borefield == other_field) == expected
5579

80+
5681
# Test borefield comparison using __ne__
5782
@pytest.mark.parametrize("field, other_field, expected", [
5883
# Fields that are equal
@@ -75,16 +100,6 @@ def test_borefield_ne(field, other_field, expected, request):
75100
assert (borefield != other_field) == expected
76101

77102

78-
def test_borefield_add():
79-
borehole = gt.boreholes.Borehole(100, 1, 0.075, 15, 10)
80-
borefield = gt.borefield.Borefield.rectangle_field(2, 1, 6, 6, 100, 1, 0.075)
81-
borefield_2 = gt.borefield.Borefield.from_boreholes([borehole, gt.boreholes.Borehole(110, 1, 0.075, 20, 15)])
82-
assert borefield + borehole == gt.borefield.Borefield.from_boreholes(borefield.to_boreholes() + [borehole])
83-
assert borehole + borefield == gt.borefield.Borefield.from_boreholes(borefield.to_boreholes() + [borehole])
84-
assert borefield + [borehole] == gt.borefield.Borefield.from_boreholes(borefield.to_boreholes() + [borehole])
85-
assert borefield + borefield_2 == gt.borefield.Borefield.from_boreholes(borefield.to_boreholes()+borefield_2.to_boreholes())
86-
87-
88103
# =============================================================================
89104
# Test evaluate_g_function (vertical boreholes)
90105
# =============================================================================

tests/boreholes_test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ def test_borehole_init():
3333
])
3434

3535

36+
# Test Borehole.__add__
37+
@pytest.mark.parametrize("borehole, other_borehole, borehole_list, other_borehole_list", [
38+
('single_borehole', 'single_borehole_short', False, False),
39+
('single_borehole', 'single_borehole_short', True, False),
40+
('single_borehole', 'single_borehole_short', False, True),
41+
])
42+
def test_borehole_add(borehole, other_borehole, borehole_list, other_borehole_list, request):
43+
borehole = request.borehole
44+
other_borehole = request.other_borehole
45+
field = gt.borefield.Borefield.from_boreholes(
46+
[borehole, other_borehole])
47+
if borehole_list:
48+
borehole = [borehole]
49+
if other_borehole_list:
50+
other_borehole = [other_borehole]
51+
assert field == borehole + other_borehole
52+
53+
3654
# Test Borehole.distance
3755
@pytest.mark.parametrize("borehole1, borehole2", [
3856
# Same borehole
@@ -307,9 +325,3 @@ def test_circle_field(N, R):
307325
len(field) == 1 or np.isclose(np.min(dis), B_min),
308326
len(field) == 1 or np.max(dis) <= (2 + 1e-6) * R,
309327
])
310-
311-
312-
def test_add_boreholes():
313-
borehole1 = gt.boreholes.Borehole(100, 1, 0.075, 0, 0)
314-
borehole2 = gt.boreholes.Borehole(110, 1, 0.075, 0, 0)
315-
assert gt.borefield.Borefield.from_boreholes([borehole1, borehole2]) == borehole1 + borehole2

0 commit comments

Comments
 (0)