Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions rest_flex_fields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@ def is_expanded(request, field: str) -> bool:
expand_value = request.query_params.get(EXPAND_PARAM)
expand_fields = []

if expand_value:
for f in expand_value.split(","):
expand_fields.extend([_ for _ in f.split(".")])

return any(field for field in expand_fields if field in WILDCARD_VALUES) or field in expand_fields
# first split on commas to get each expand
for full in expand_value.split(","):
# than split on dots to get each component that is expanded
parts = full.split(".")
for i in range(len(parts)):
# add each prefix, as each prefix is epxanded, ie
# a.b.c will add a, a.b and a.b.c to the expand_fields list
# we do this to differentiate a.b from b
expand_fields.append(".".join(parts[: i + 1]))

# WILDCARD_VALUES only expands top level fields
if "." not in field and any(field for field in expand_fields if field in WILDCARD_VALUES):
return True

return field in expand_fields


def is_included(request, field: str) -> bool:
Expand Down
43 changes: 24 additions & 19 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,27 @@ def test_should_not_be_included_and_due_to_fields_and_has_dot_notation(self):
request = MockRequest(query_params={"fields": "hobby,address"})
self.assertFalse(is_included(request, "name"))

def test_should_be_expanded(self):
request = MockRequest(query_params={"expand": "name,address"})
self.assertTrue(is_expanded(request, "name"))

def test_should_not_be_expanded(self):
request = MockRequest(query_params={"expand": "name,address"})
self.assertFalse(is_expanded(request, "hobby"))

def test_should_be_expanded_and_has_dot_notation(self):
request = MockRequest(query_params={"expand": "person.name,address"})
self.assertTrue(is_expanded(request, "name"))

def test_all_should_be_expanded(self):
request = MockRequest(query_params={"expand": WILDCARD_ALL})
self.assertTrue(is_expanded(request, "name"))

def test_asterisk_should_be_expanded(self):
request = MockRequest(query_params={"expand": WILDCARD_ASTERISK})
self.assertTrue(is_expanded(request, "name"))
def test_is_expanded(self):
test_cases = [
("a", "a", True),
("a", "b", False),
("a,b,c", "a", True),
("a,b,c", "b", True),
("a,b,c", "c", True),
("a,b,c", "d", False),
("a.b.c", "a", True),
("a.b.c", "a.b", True),
("a.b.c", "a.b.c", True),
("a.b.c", "b", False),
("a.b.c", "c", False),
("a.b.c", "d", False),
("a.b.c,d", "a", True),
("a.b.c,d", "d", True),
(WILDCARD_ASTERISK, "a", True),
(WILDCARD_ASTERISK, "a.b", False),
(WILDCARD_ALL, "a", True),
(WILDCARD_ALL, "a.b", False),
]
for expand_query_arg, field, should_be_expanded in test_cases:
request = MockRequest(query_params={"expand": expand_query_arg})
self.assertEqual(is_expanded(request, field), should_be_expanded)