Skip to content

Commit 40e89c9

Browse files
committed
feat: class to manage base bcr data and some tests
1 parent 6b65b17 commit 40e89c9

File tree

3 files changed

+76
-26
lines changed

3 files changed

+76
-26
lines changed

src/graphs/bar_chart_race.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,43 @@
11
import pandas as pd
2+
from pandas.errors import ParserError
23
import re
34

5+
class BaseDf:
6+
def __init__(self, df):
7+
self.df = df
8+
9+
def prepare(self):
10+
self.verify_column_count()
11+
self.prepare_date_column()
12+
self.prepare_value_column()
13+
14+
def verify_column_count(self):
15+
if not (3 <= self.df.shape[1] <= 6):
16+
raise BaseDfException("number of columns must be between 3 and 6")
17+
18+
def prepare_date_column(self):
19+
original_name = self.df.columns[-1]
20+
try:
21+
self.df[original_name] = pd.to_datetime(self.df[original_name], format="ISO8601")
22+
except (ValueError, ParserError):
23+
raise BaseDfException("last column must be a date column")
24+
self.df.rename(columns={original_name: "date"})
25+
26+
def prepare_value_column(self):
27+
original_name = self.df.columns[-2]
28+
try:
29+
self.df[original_name] = self.df[original_name].astype("float")
30+
except ValueError:
31+
raise BaseDfException("second to last column must be a quantity column")
32+
self.df.rename(columns={original_name: "value"})
33+
34+
35+
class BaseDfException(Exception):
36+
def __init__(self, message):
37+
self.message = message
38+
return super().__init__(message)
39+
40+
441
def process_bar_chart_race(df):
542
"""
643
Process data for bar chart race visualization.
@@ -9,19 +46,11 @@ def process_bar_chart_race(df):
946
:return: Processed data suitable for bar chart race or error message
1047
"""
1148

12-
# RULES
13-
14-
# 1. Check if the number of columns is between 3 and 6
15-
if not (3 <= df.shape[1] <= 6):
16-
return {"failed": "data failed bar chart race rule - number of columns must be between 3 and 6"}
17-
18-
# 2. Check if the last column is a date column
19-
if not df.iloc[:, -1].apply(is_datetime_string).any():
20-
return {"failed": "data failed bar chart race rule - last column must be a date column"}
21-
22-
# 3. Check if the second to last column is a number-string column
23-
if not df.iloc[:, -2].apply(lambda x: isinstance(x, str) and x.replace('.', '', 1).isdigit()).any():
24-
return {"failed": "data failed bar chart race rule - second to last column must be a quantity column"}
49+
bdf = BaseDf(df)
50+
try:
51+
bdf.prepare()
52+
except BaseDfException as e:
53+
return {"failed": e.message}
2554

2655
# 4. Identify columns with word identifiers
2756
identifier_columns = identify_word_identifier_columns(df.iloc[:, :-2])

src/graphs/tests.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import pandas as pd
22
from django.test import TestCase
3+
from django.utils.timezone import now
34

45
from graphs.table import process_table
6+
from graphs.bar_chart_race import process_bar_chart_race
57

68

79
class TestHelper:
@@ -39,3 +41,18 @@ def test_basic_table(self):
3941
},
4042
],
4143
)
44+
45+
def test_bcr_base_errors(self):
46+
df = TestHelper.mock_df_table()
47+
msg = process_bar_chart_race(df)["failed"]
48+
self.assertEqual(msg, "number of columns must be between 3 and 6")
49+
df["new column"] = "abcdef"
50+
msg = process_bar_chart_race(df)["failed"]
51+
self.assertEqual(msg, "last column must be a date column")
52+
df["new column"] = now()
53+
msg = process_bar_chart_race(df)["failed"]
54+
self.assertEqual(msg, "second to last column must be a quantity column")
55+
df["new column"] = 2
56+
df["last column"] = now()
57+
res = process_bar_chart_race(df)
58+
self.assertNotIn("failed", res)

src/query/sparql.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,32 @@ def df_from_query(sparql_string):
1010
:param sparql_string: SPARQL query string
1111
:return: DataFrame containing the results
1212
"""
13-
url = "https://query.wikidata.org/sparql"
14-
params = {
15-
"query": sparql_string,
16-
"format": "json"
17-
}
18-
headers = {'User-agent': 'Wiki-Infographics 1.0'}
19-
response = requests.get(url, headers=headers, params=params)
13+
response = get_response(sparql_string)
2014

2115
if response.status_code != 200:
2216
error_response = extract_error_message(response.text)
2317
return {"error": error_response}
2418

2519
data = response.json()
26-
results = data['results']['bindings']
27-
variables = data['head']['vars']
20+
results = data["results"]["bindings"]
21+
variables = data["head"]["vars"]
2822

29-
# Convert results to DataFrame
30-
df = pd.DataFrame([{var: binding.get(var, {}).get('value', None) for var in variables} for binding in results])
23+
df = pd.DataFrame(
24+
[
25+
{var: binding.get(var, {}).get("value", None) for var in variables}
26+
for binding in results
27+
]
28+
)
3129

3230
return df
3331

3432

33+
def get_response(sparql_string):
34+
url = "https://query.wikidata.org/sparql"
35+
params = {"query": sparql_string, "format": "json"}
36+
headers = {"User-agent": "Wiki-Infographics 1.0"}
37+
return requests.get(url, headers=headers, params=params)
38+
3539

3640
def extract_error_message(error_str):
3741
"""
@@ -43,10 +47,10 @@ def extract_error_message(error_str):
4347
Returns:
4448
str: Formatted error message in the form of "MalformedQueryException: [error details]."
4549
"""
46-
50+
4751
pattern = r"MalformedQueryException: [^.]*\."
4852
match = re.search(pattern, error_str)
49-
53+
5054
if match:
5155
return match.group(0)
5256
else:

0 commit comments

Comments
 (0)