Skip to content

Commit d7754a4

Browse files
committed
feat: enhance chart sub-class handling with canonicalization and lookup
1 parent a9c4c4b commit d7754a4

File tree

1 file changed

+44
-19
lines changed

1 file changed

+44
-19
lines changed

mineru_vl_utils/post_process/image_analysis_postprocess.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,49 @@
1010
"content": ("<|content_start|>", "<|content_end|>"),
1111
}
1212

13-
CHART_SUB_CLASS_MAPPING = {
14-
"Line Chart": "line_chart", # 折线图
15-
"Bar Chart": "bar_chart", # 柱状图
16-
"Scatterplot": "scatterplot", # 散点图
17-
"Stacked Bar Chart": "stacked_bar_chart", # 堆叠柱状图
18-
"Area Chart": "area_chart", # 面积图
19-
"Bar-line Hybrid": "bar_line_hybrid", # 柱线混合图
20-
"Histogram": "histogram", # 直方图
21-
"Pie Chart": "pie_chart", # 饼图
22-
"Heatmap": "heatmap", # 热力图
23-
"Bubble Chart": "bubble_chart", # 气泡图
24-
"Radar Chart": "radar_chart", # 雷达图
25-
"Box Plot": "box_plot", # 箱线图
26-
"Geospatial Charts": "geospatial_charts", # 地理图表
27-
"Complex & Scientific": "complex_scientific", # 复杂与科学绘图
13+
# Maps each output chart sub-class identifier to the spellings accepted for it.
# The lookup table built from this mapping keys on the canonicalized spelling
# (lower-cased, punctuation and underscores turned into single spaces), so
# several spellings listed here collapse to the same lookup key.
CHART_SUB_CLASS_GROUPS: dict[str, tuple[str, ...]] = {
    # Line chart
    "line_chart": ("Line Chart", "line chart", "line graph", "line"),
    # Bar / column chart
    "bar_chart": ("Bar Chart", "bar chart", "column chart", "column", "bar"),
    # Scatter plot
    "scatterplot": ("Scatterplot", "scatter plot", "scatter"),
    # Stacked bar chart
    "stacked_bar_chart": ("Stacked Bar Chart", "stacked bar chart", "stacked bar"),
    # Area chart
    "area_chart": ("Area Chart", "area chart", "3D area chart", "3D area", "area"),
    # Bar-line hybrid chart
    "bar_line_hybrid": (
        "Bar-line Hybrid",
        "bar line hybrid",
        "bar-line hybrid",
        "bar line",
    ),
    # Histogram
    "histogram": ("Histogram",),
    # Pie chart
    "pie_chart": ("Pie Chart", "pie chart", "pie"),
    # Heat map
    "heatmap": ("Heatmap", "heat map"),
    # Bubble chart
    "bubble_chart": ("Bubble Chart", "bubble chart", "bubble"),
    # Radar chart
    "radar_chart": ("Radar Chart", "radar chart", "radar"),
    # Box plot
    "box_plot": ("Box Plot", "box plot", "boxplot", "box"),
    # Geospatial charts
    "geospatial_charts": (
        "Geospatial Charts",
        "geospatial chart",
        "geospatial",
        "map chart",
    ),
    # Complex and scientific plots
    "complex_scientific": (
        "Complex & Scientific",
        "complex and scientific",
        "complex scientific",
        "scientific",
    ),
}
2929

3030

31+
def _canonicalize_chart_sub_class(sub_class: str) -> str:
32+
normalized_sub_class = re.sub(r"\s+", " ", sub_class).strip().lower()
33+
normalized_sub_class = normalized_sub_class.replace("&", " and ")
34+
normalized_sub_class = re.sub(r"[\W_]+", " ", normalized_sub_class)
35+
return re.sub(r"\s+", " ", normalized_sub_class).strip()
36+
37+
38+
def _build_chart_sub_class_lookup(groups: dict[str, tuple[str, ...]]) -> dict[str, str]:
    """Invert a sub-class -> variants mapping into a canonical-variant lookup.

    Every variant is passed through ``_canonicalize_chart_sub_class`` and the
    result is used as the key pointing back at the sub-class that listed it.
    Variants of the same sub-class may share a canonical form.

    Args:
        groups: Mapping from output sub-class identifier to accepted spellings.

    Returns:
        Mapping from canonical spelling to output sub-class identifier.

    Raises:
        ValueError: If a variant canonicalizes to an empty string, or if two
            different sub-classes list variants with the same canonical form.
    """
    lookup: dict[str, str] = {}
    for mapped_sub_class, variants in groups.items():
        for variant in variants:
            canonical_variant = _canonicalize_chart_sub_class(variant)
            # An empty key would let blank or punctuation-only input resolve
            # to this sub-class instead of being reported as unknown.
            if not canonical_variant:
                raise ValueError(
                    f"Chart sub_class variant {variant!r} for '{mapped_sub_class}' "
                    "canonicalizes to an empty string"
                )
            # setdefault returns the earlier owner when the key already exists.
            existing_sub_class = lookup.setdefault(canonical_variant, mapped_sub_class)
            if existing_sub_class != mapped_sub_class:
                raise ValueError(
                    f"Duplicate canonical chart sub_class variant '{canonical_variant}' "
                    f"for '{existing_sub_class}' and '{mapped_sub_class}'"
                )
    return lookup
51+
52+
53+
# Canonical spelling -> output sub-class identifier, built once at import time.
# _build_chart_sub_class_lookup raises ValueError if two sub-classes in
# CHART_SUB_CLASS_GROUPS share a canonical spelling, so a conflicting table
# fails at import rather than at lookup time.
CANONICAL_CHART_SUB_CLASS_LOOKUP = _build_chart_sub_class_lookup(CHART_SUB_CLASS_GROUPS)
54+
55+
3156
def _extract_tagged_field(text: str, start_tag: str, end_tag: str) -> str:
3257
start_idx = text.find(start_tag)
3358
if start_idx < 0:
@@ -262,13 +287,13 @@ def node_fixer(match):
262287

263288

264289
def _normalize_chart_sub_class(sub_class: str) -> str:
    """Map a raw chart sub-class label to its output sub-class identifier.

    The label is canonicalized with ``_canonicalize_chart_sub_class`` and
    looked up in ``CANONICAL_CHART_SUB_CLASS_LOOKUP``.

    Args:
        sub_class: Raw chart sub-class label.

    Returns:
        The mapped identifier, or ``"complex_scientific"`` when the label is
        not in the lookup table.
    """
    canonical_key = _canonicalize_chart_sub_class(sub_class)
    mapped_sub_class = CANONICAL_CHART_SUB_CLASS_LOOKUP.get(canonical_key)
    if mapped_sub_class is not None:
        return mapped_sub_class

    # Unknown labels are not rejected: log the raw value and fall back.
    # NOTE(review): the "{}" placeholder presumably relies on a loguru-style
    # logger; confirm against the module's logger import.
    logger.warning("Unknown chart sub_class: {}; mapped to complex_scientific", sub_class)
    return "complex_scientific"
272297

273298

274299
def process_image_or_chart(content: str) -> dict[str, str]:

0 commit comments

Comments
 (0)