Skip to content

Commit 0a750d3

Browse files
committed
refactor: complete refactor of bar chart race, still no categories
1 parent adf8c63 commit 0a750d3

File tree

2 files changed

+173
-180
lines changed

2 files changed

+173
-180
lines changed

src/graphs/bar_chart_race.py

Lines changed: 70 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,58 @@
33
import re
44

55
class BaseDf:
    """Normalizes a raw DataFrame into the canonical bar-chart-race layout.

    Expects 3-5 columns where the last column is an ISO-8601 date, the
    second-to-last is a numeric quantity, and the leading column(s) are
    identifiers (name / category / url). After ``prepare()`` the frame has
    columns drawn from ``["name", "category", "url", "value", "date"]``.
    """

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def prepare(self) -> "BaseDf":
        """Validate and normalize the frame in place; returns self for chaining."""
        self.verify_column_count()
        self.prepare_date_column()
        self.prepare_value_column()
        self.prepare_identifier_columns()
        self.drop_other_columns()
        return self

    def verify_column_count(self):
        """Raise BaseDfException unless the frame has 3-5 columns."""
        if not (3 <= self.df.shape[1] <= 5):
            raise BaseDfException("number of columns must be between 3 and 5")

    def prepare_date_column(self):
        """Parse the last column as ISO-8601 dates, reduce to year, rename to 'date'."""
        original_name = self.df.columns[-1]
        try:
            # TODO: allow other units of time
            self.df[original_name] = pd.to_datetime(self.df[original_name], format="ISO8601").dt.year
        except (ValueError, ParserError):
            raise BaseDfException("last column must be a date column")
        self.df.rename(columns={original_name: "date"}, inplace=True)

    def prepare_value_column(self):
        """Coerce the second-to-last column to float and rename it to 'value'."""
        original_name = self.df.columns[-2]
        try:
            self.df[original_name] = self.df[original_name].astype("float")
        except ValueError:
            raise BaseDfException("second to last column must be a quantity column")
        self.df.rename(columns={original_name: "value"}, inplace=True)

    def prepare_identifier_columns(self):
        """Rename the leading columns to url/name/category based on their first value.

        A column whose first value looks like an http(s) URL becomes "url";
        the first non-url column becomes "name" and a second one "category".
        Raises BaseDfException when no "name" column can be identified.
        """
        renamer = {}
        for column in self.df.columns[:-2]:
            # Positional access: the frame's index may not start at 0
            # (label-based `[0]` would raise KeyError on a filtered frame).
            first = self.df[column].iloc[0]
            # Guard the regex: a non-string first cell (NaN, number) would
            # make re.match raise TypeError; treat it as a plain label instead.
            if isinstance(first, str) and re.match(r"^https?://.*", first):
                renamer[column] = "url"
                continue
            if "name" in renamer.values():
                renamer[column] = "category"
            else:
                renamer[column] = "name"
        # At least one human-readable label is required; this also covers the
        # case where every identifier column was classified as a url.
        if "name" not in renamer.values():
            raise BaseDfException("there should be at least one label")
        self.df.rename(columns=renamer, inplace=True)

    def drop_other_columns(self):
        """Drop any column not part of the canonical output schema."""
        keep = ["name", "category", "url", "value", "date"]
        drop = [col for col in self.df.columns if col not in keep]
        self.df.drop(columns=drop, inplace=True)
3358

3459

3560
class BaseDfException(Exception):
@@ -38,6 +63,38 @@ def __init__(self, message):
3863
return super().__init__(message)
3964

4065

66+
class DfProcessor:
67+
def __init__(self, bdf: BaseDf):
68+
self.df = bdf.df
69+
70+
def elements(self):
71+
identifiers = [col for col in ["url", "category"] if col in self.df.columns]
72+
if not identifiers:
73+
return {name: {} for name in self.df["name"].unique()}
74+
agg = {col: "first" for col in identifiers}
75+
return (
76+
self.df[["name", *identifiers]].groupby("name").agg(agg).to_dict("index")
77+
)
78+
79+
def interpolated_df(self):
80+
df = self.df
81+
mux = pd.MultiIndex.from_product(
82+
[df["name"].unique(), range(df["date"].min(), df["date"].max() + 1)],
83+
names=["name", "date"],
84+
)
85+
df = (
86+
df.drop_duplicates(["name", "date"])
87+
.set_index(["name", "date"])
88+
.reindex(mux)
89+
.reset_index()
90+
.pivot(index="date", columns=["name"], values="value")
91+
.interpolate()
92+
.melt(ignore_index=False)
93+
.reset_index()
94+
)
95+
return df
96+
97+
4198
def process_bar_chart_race(df):
4299
"""
43100
Process data for bar chart race visualization.
@@ -52,177 +109,11 @@ def process_bar_chart_race(df):
52109
except BaseDfException as e:
53110
return {"failed": e.message}
54111

55-
# 4. Identify columns with word identifiers
56-
identifier_columns = identify_word_identifier_columns(df.iloc[:, :-2])
57-
58-
new_columns = rename_columns(df.columns, identifier_columns)
59-
60-
new_data = df.copy()
61-
new_data.columns = new_columns
62-
63-
# Select the relevant columns to form the new DataFrame
64-
selected_columns = [col for col in new_columns if col in ["category", "name", "value", "date"]]
65-
processed_data = new_data[selected_columns]
66-
67-
processed = fill_nans(processed_data)
68-
69-
result = processed.to_dict(orient='records')
112+
proc = DfProcessor(bdf)
113+
_elements = proc.elements() # TODO: use this to control category + url
114+
ip = proc.interpolated_df()
115+
ip["value"] = ip["value"].astype(str)
116+
ip["date"] = ip["date"].astype(str) + "-01-01"
70117

118+
result = ip.to_dict(orient='records')
71119
return result
72-
73-
74-
def fill_nans(df, fill_NaN=True):
75-
"""
76-
Fill NaNs in dataframe using time-based linear interpolation.
77-
78-
Args:
79-
df (pd.DataFrame): A dataframe.
80-
fill_NaN (bool): Enable or disable filling NaNs with appropriate values.
81-
82-
Returns:
83-
pd.DataFrame: A DataFrame containing the cleaned data.
84-
"""
85-
86-
# Make a deep copy of the DataFrame to avoid SettingWithCopyWarning
87-
df = df.copy()
88-
89-
df['value'] = df['value'].astype(int)
90-
df['date'] = pd.to_datetime(df['date']).dt.year
91-
92-
if "category" in df.columns:
93-
94-
# Handle duplicates by aggregating values within the same year and name
95-
df_aggregated = df.groupby(["date", "name", "category"], as_index=False).agg({
96-
"value": "max",
97-
"date": "first"
98-
})
99-
100-
# Convert the year back to the "YYYY-01-01" format
101-
df_aggregated['date'] = df_aggregated['date'].astype(str) + '-01-01'
102-
103-
# Pivot the table using the original 'date' column
104-
df_pivoted = df_aggregated.pivot(index="date", columns=["name", "category"], values="value").reset_index()
105-
else:
106-
# Handle duplicates by aggregating values within the same year and name
107-
df_aggregated = df.groupby(["date", "name"], as_index=False).agg({
108-
"value": "max",
109-
"date": "first"
110-
})
111-
112-
# Convert the year back to the "YYYY-01-01" format
113-
df_aggregated['date'] = df_aggregated['date'].astype(str) + '-01-01'
114-
115-
# Pivot the table using the original 'date' column
116-
df_pivoted = df_aggregated.pivot(index="date", columns="name", values="value").reset_index()
117-
118-
if fill_NaN:
119-
120-
# Fill NaNs before the first valid value in each column with zero
121-
for col in df_pivoted.columns:
122-
first_valid_index = df_pivoted[col].first_valid_index()
123-
if first_valid_index is not None:
124-
df_pivoted.loc[:first_valid_index, col] = df_pivoted.loc[:first_valid_index, col].fillna(0)
125-
126-
# Check if NaN is the last value in a column(Edge case) and fill NaNs in each column
127-
projection_factor = 1.1
128-
129-
for col in df_pivoted.columns:
130-
last_valid_index = df_pivoted[col].last_valid_index()
131-
132-
# If there is a last valid index and the last value is NaN, project the value
133-
if last_valid_index is not None and pd.isna(df_pivoted[col].iloc[-1]):
134-
projected_value = df_pivoted[col].iloc[last_valid_index] * projection_factor
135-
df_pivoted.loc[df_pivoted.index[-1], col] = projected_value
136-
137-
df_pivoted['date'] = pd.to_datetime(df_pivoted['date'], format='%Y-%m-%d')
138-
df_pivoted.set_index('date', inplace=True)
139-
140-
# Use time based linear interpolation to fill NaNs between known values
141-
df_pivoted = df_pivoted.interpolate(method='time', limit_direction='forward', axis=0).astype(int)
142-
143-
# print(df_pivoted)
144-
df_pivoted = df_pivoted.reset_index()
145-
146-
# Flatten the MultiIndex columns if necessary
147-
df_pivoted.columns = ['_'.join(col).strip() if isinstance(col, tuple) else col for col in df_pivoted.columns]
148-
149-
if "category" in df.columns:
150-
151-
df_unpivoted = df_pivoted.melt(id_vars=['date_'], var_name='category_name', value_name='value')
152-
153-
# Split the concatenated column names back into 'category' and 'name'
154-
df_unpivoted[['category', 'name']] = df_unpivoted['category_name'].str.split('_', expand=True)
155-
df_unpivoted.drop(columns=['category_name'], inplace=True)
156-
157-
df_unpivoted.reset_index(drop=True, inplace=True)
158-
159-
# Rename columns to match the desired format
160-
df_unpivoted.rename(columns={'date_': 'date'}, inplace=True)
161-
162-
# Ensure columns are in the desired order
163-
df_unpivoted = df_unpivoted[['category', 'name', 'value', 'date']]
164-
else:
165-
166-
df_unpivoted = df_pivoted.melt(id_vars=['date'], var_name='name', value_name='value')
167-
168-
df_unpivoted.reset_index(drop=True, inplace=True)
169-
170-
# Ensure columns are in the desired order
171-
df_unpivoted = df_unpivoted[['name', 'value', 'date']]
172-
173-
df_unpivoted[['date', 'value']] = df_unpivoted[['date', 'value']].astype(str)
174-
175-
return df_unpivoted
176-
177-
return df_pivoted
178-
179-
def identify_word_identifier_columns(df):
180-
"""
181-
Identify columns with word identifiers.
182-
183-
:param data: DataFrame excluding the last two columns
184-
:return: List of column indices that match the word identifier rule
185-
"""
186-
def is_non_link_word(x):
187-
if isinstance(x, str):
188-
if re.match(r'https?://|http://', x): # Exclude URLs
189-
return False
190-
if re.match(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$', x): # Exclude date strings
191-
return False
192-
if x.replace('.', '', 1).isdigit(): # Exclude purely numeric strings
193-
return False
194-
if re.search(r'[A-Za-z]', x): # Check if it contains any alphabetic character
195-
return True
196-
return False
197-
198-
# Identify columns with word identifiers
199-
identifier_columns = [i for i in range(df.shape[1]) if df.iloc[:, i].apply(is_non_link_word).any()]
200-
return identifier_columns
201-
202-
def rename_columns(columns, identifier_columns):
203-
"""
204-
Rename columns based on identified word identifiers and the last two columns.
205-
206-
:param columns: List of original column names
207-
:param identifier_columns: List of column indices with word identifiers
208-
:return: List of new column names
209-
"""
210-
new_columns = columns.tolist()
211-
212-
# Rename the last two columns to "value" and "date"
213-
new_columns[-2] = "value"
214-
new_columns[-1] = "date"
215-
216-
# Rename based on the number of word identifier columns found
217-
if len(identifier_columns) >= 2:
218-
new_columns[identifier_columns[0]] = "category"
219-
new_columns[identifier_columns[1]] = "name"
220-
elif len(identifier_columns) == 1:
221-
new_columns[identifier_columns[0]] = "name"
222-
223-
return new_columns
224-
225-
def is_datetime_string(s):
226-
# Check if a string follows a common datetime format using regex
227-
datetime_pattern = r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$'
228-
return bool(re.match(datetime_pattern, str(s)))

0 commit comments

Comments
 (0)