Skip to content

Commit 21e5e9d

Browse files
adding plotly
instructing the template with the opportunity to use plotly
1 parent a5d6a46 commit 21e5e9d

File tree

15 files changed

+270
-7
lines changed

15 files changed

+270
-7
lines changed

pandasai/core/code_execution/environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def get_environment() -> dict:
2929
"pd": import_dependency("pandas"),
3030
"plt": import_dependency("matplotlib.pyplot"),
3131
"np": import_dependency("numpy"),
32+
"px": import_dependency("plotly.express"),
3233
}
3334

3435
return env

pandasai/core/code_generation/code_cleaning.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ def clean_code(self, code: str) -> str:
136136
tuple: Cleaned code as a string and a list of additional dependencies.
137137
"""
138138
code = self._replace_output_filenames_with_temp_chart(code)
139+
code = self._replace_output_filenames_with_temp_json_chart(code)
139140

140-
# If plt.show is in the code, remove that line
141-
code = re.sub(r"plt.show\(\)", "", code)
141+
# If plt.show or fig.show is in the code, remove that line
142+
code = re.sub(r"[a-z].show\(\)", "", code)
142143

143144
tree = ast.parse(code)
144145
new_body = []
@@ -166,3 +167,16 @@ def _replace_output_filenames_with_temp_chart(self, code: str) -> str:
166167
lambda m: f"{m.group(1)}{chart_path}{m.group(1)}",
167168
code,
168169
)
170+
171+
def _replace_output_filenames_with_temp_json_chart(self, code: str) -> str:
    """
    Redirect every quoted ``*.json`` file name in ``code`` to a temporary chart file.

    One target path of the form ``temp_chart_<uuid4>.json`` under
    ``DEFAULT_CHART_DIRECTORY`` is built per call. Each single- or
    double-quoted literal that ends in ``.json`` is replaced by that path,
    keeping the quote character the literal was written with (in case of
    usage of plotly).

    Args:
        code: Source code to rewrite.

    Returns:
        The code with every matching ``.json`` literal replaced.
    """
    json_literal = r"""(['"])([^'"]*\.json)\1"""
    target = os.path.join(
        DEFAULT_CHART_DIRECTORY, f"temp_chart_{uuid.uuid4()}.json"
    )
    # Backslashes are doubled before insertion; presumably so a Windows-style
    # path stays a valid string literal inside the rewritten code.
    escaped_target = target.replace("\\", "\\\\")

    def _swap(match):
        quote = match.group(1)
        return f"{quote}{escaped_target}{quote}"

    return re.sub(json_literal, _swap, code)

pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def execute_sql_query(sql_query: str) -> pd.Dataframe
1010
"""This method connects to the database, executes the sql query and returns the dataframe"""
1111
</function>
1212

13+
For the charts, you can either use `matplotlib.pyplot` or `plotly.express` to generate the charts.
14+
If you use `plotly.express`, you have to save each chart as a dictionary into a JSON file.
15+
1316
{% if last_code_generated != "" and context.memory.count() > 0 %}
1417
{{ last_code_generated }}
1518
{% else %}
@@ -31,4 +34,4 @@ At the end, declare "result" variable as a dictionary of type and value.
3134

3235
Generate python code and return full updated code:
3336

34-
### Note: Use only relevant table for query and do aggregation, sorting, joins and grouby through sql query
37+
### Note: Use only relevant table for query and do aggregation, sorting, joins and group by through sql query
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% if not output_type %}
2-
type (possible values "string", "number", "dataframe", "plot"). No other type available. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
2+
type (possible values "string", "number", "dataframe", "plot", "iplot"). No other type available. "plot" is when "matplotlib" is used; "iplot" when "plotly" is used. Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } or { "type": "iplot", "value": "temp_chart.json" }
33
{% elif output_type == "number" %}
44
type (must be "number"), value must be int. Example: { "type": "number", "value": 125 }
55
{% elif output_type == "string" %}
@@ -8,4 +8,6 @@ type (must be "string"), value must be string. Example: { "type": "string", "val
88
type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) }
99
{% elif output_type == "plot" %}
1010
type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }
11+
{% elif output_type == "iplot" %}
12+
type (must be "iplot"), value must be string. Example: { "type": "iplot", "value": "temp_chart.json" }
1113
{% endif %}

pandasai/core/response/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .chart import ChartResponse
33
from .dataframe import DataFrameResponse
44
from .error import ErrorResponse
5+
from .interactive_chart import InteractiveChartResponse
56
from .number import NumberResponse
67
from .parser import ResponseParser
78
from .string import StringResponse
@@ -10,6 +11,7 @@
1011
"ResponseParser",
1112
"BaseResponse",
1213
"ChartResponse",
14+
"InteractiveChartResponse",
1315
"DataFrameResponse",
1416
"NumberResponse",
1517
"StringResponse",
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
import os
3+
from typing import Any
4+
5+
from .base import BaseResponse
6+
7+
8+
class InteractiveChartResponse(BaseResponse):
    """
    Response holding an interactive chart.

    The wrapped value may be a dict, a path to an existing JSON file, or a
    JSON-encoded string.
    """

    def __init__(self, value: Any, last_code_executed: str):
        """
        Store the chart value and the code that produced it.

        Args:
            value: Chart as a dict, a JSON file path, or a JSON string.
            last_code_executed: Code that was run to obtain the chart.
        """
        super().__init__(value, "ichart", last_code_executed)

    def _get_chart(self) -> dict:
        """
        Resolve the stored value into the parsed chart.

        Returns:
            The value itself when it is a dict; otherwise the parsed JSON
            read from the file at that path if it exists, or parsed from
            the string directly.

        Raises:
            ValueError: If the value is neither a dict nor a str.
        """
        chart = self.value
        if isinstance(chart, dict):
            return chart

        if not isinstance(chart, str):
            raise ValueError("Invalid value type for InteractiveChartResponse. Expected dict or str.")

        # A string that is not an existing path is treated as raw JSON text.
        if not os.path.exists(chart):
            return json.loads(chart)

        with open(chart, "rb") as source:
            return json.load(source)

    def save(self, path: str):
        """
        Write the resolved chart to ``path`` as JSON.

        Args:
            path: Destination file path.
        """
        # Resolve first so a bad value fails before the target is opened.
        chart = self._get_chart()
        with open(path, "w") as target:
            json.dump(chart, target)

    def __str__(self) -> str:
        """Return the value unchanged if it is a str, else its JSON dump."""
        if isinstance(self.value, str):
            return self.value
        return json.dumps(self.value)

    def get_dict_image(self) -> dict:
        """Return the resolved chart (same result as ``_get_chart``)."""
        return self._get_chart()

pandasai/core/response/parser.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .base import BaseResponse
99
from .chart import ChartResponse
1010
from .dataframe import DataFrameResponse
11+
from .interactive_chart import InteractiveChartResponse
1112
from .number import NumberResponse
1213
from .string import StringResponse
1314

@@ -26,6 +27,8 @@ def _generate_response(self, result: dict, last_code_executed: str = None):
2627
return DataFrameResponse(result["value"], last_code_executed)
2728
elif result["type"] == "plot":
2829
return ChartResponse(result["value"], last_code_executed)
30+
elif result["type"] == "iplot":
31+
return InteractiveChartResponse(result["value"], last_code_executed)
2932
else:
3033
raise InvalidOutputValueMismatch(f"Invalid output type: {result['type']}")
3134

@@ -72,4 +75,16 @@ def _validate_response(self, result: dict):
7275
"Invalid output: Expected a plot save path str but received an incompatible type."
7376
)
7477

78+
elif result["type"] == "iplot":
79+
if not isinstance(result["value"], (str, dict)):
80+
raise InvalidOutputValueMismatch(
81+
"Invalid output: Expected a plot save path str but received an incompatible type."
82+
)
83+
84+
path_to_plot_pattern = r"^(\/[\w.-]+)+(/[\w.-]+)*$|^[^\s/]+(/[\w.-]+)*$"
85+
if not bool(re.match(path_to_plot_pattern, result["value"])):
86+
raise InvalidOutputValueMismatch(
87+
"Invalid output: Expected a plot save path str but received an incompatible type."
88+
)
89+
7590
return True

poetry.lock

Lines changed: 47 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ seaborn = "^0.12.2"
2828
sqlglot = "^25.0.3"
2929
pyarrow = "^14.0.1"
3030
pyyaml = "^6.0.2"
31+
plotly = "^6.1.1"
3132

3233
[tool.poetry.group.dev]
3334
optional = true

tests/unit_tests/core/code_execution/test_environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_get_environment_with_secure_mode(self, mock_import_dependency):
1818
self.assertIn("pd", env)
1919
self.assertIn("plt", env)
2020
self.assertIn("np", env)
21+
self.assertIn("px", env)
2122

2223
@patch("pandasai.core.code_execution.environment.import_dependency")
2324
def test_get_environment_without_secure_mode(self, mock_import_dependency):
@@ -28,6 +29,7 @@ def test_get_environment_without_secure_mode(self, mock_import_dependency):
2829
self.assertIn("pd", env)
2930
self.assertIn("plt", env)
3031
self.assertIn("np", env)
32+
self.assertIn("px", env)
3133
self.assertIsInstance(env["pd"], MagicMock)
3234

3335
@patch("pandasai.core.code_execution.environment.importlib.import_module")

0 commit comments

Comments
 (0)