Skip to content

Commit 9bc0d2b

Browse files
committed
Merge code_interpreter and python_interpreter
1 parent 00c125e commit 9bc0d2b

File tree

2 files changed

+97
-203
lines changed

2 files changed

+97
-203
lines changed
Lines changed: 97 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,53 @@
1+
import base64
12
import logging
2-
import signal
3+
from enum import Enum
4+
from typing import Dict, List, Optional
35

6+
import grpc
7+
import orjson as json
48
from asgiref.sync import async_to_sync
9+
from django.conf import settings
10+
from google.protobuf.json_format import MessageToDict, ParseDict
11+
from google.protobuf.struct_pb2 import Struct
512
from pydantic import Field
613

14+
from llmstack.apps.schemas import OutputTemplate
15+
from llmstack.common.runner.proto import runner_pb2, runner_pb2_grpc
716
from llmstack.processors.providers.api_processor_interface import (
817
ApiProcessorInterface,
918
ApiProcessorSchema,
1019
)
20+
from llmstack.processors.providers.promptly import Content, ContentMimeType
1121

1222
logger = logging.getLogger(__name__)
1323

1424

25+
class CodeInterpreterLanguage(str, Enum):
26+
PYTHON = "Python"
27+
28+
def __str__(self):
29+
return self.value
30+
31+
1532
class CodeInterpreterInput(ApiProcessorSchema):
16-
code: str = Field(description="The code to run")
17-
language: str = Field(
18-
default="python",
19-
description="The language of the code",
33+
code: str = Field(description="The code to run", widget="textarea")
34+
language: CodeInterpreterLanguage = Field(
35+
title="Language", description="The language of the code", default=CodeInterpreterLanguage.PYTHON
36+
)
37+
local_variables: Optional[str] = Field(
38+
description="Values for the local variables as a JSON string", widget="textarea"
2039
)
2140

2241

2342
class CodeInterpreterOutput(ApiProcessorSchema):
24-
output: str = Field(..., description="The output of the code")
43+
stdout: List[Content] = Field(default=[], description="Standard output as a list of Content objects")
44+
stderr: str = Field(default="", description="Standard error")
45+
local_variables: Optional[Dict] = Field(description="Local variables as a JSON object")
46+
exit_code: int = Field(default=0, description="Exit code of the process")
2547

2648

2749
class CodeInterpreterConfiguration(ApiProcessorSchema):
28-
pass
50+
timeout: int = Field(default=5, description="Timeout in seconds", ge=1, le=30)
2951

3052

3153
class CodeInterpreterProcessor(
@@ -41,75 +63,84 @@ def slug() -> str:
4163

4264
@staticmethod
4365
def description() -> str:
44-
return "Runs code in a sandboxed environment"
66+
return "Runs the provided code and returns the output"
4567

4668
@staticmethod
4769
def provider_slug() -> str:
4870
return "promptly"
4971

5072
@staticmethod
51-
def tool_only() -> bool:
52-
return True
73+
def get_output_template() -> Optional[OutputTemplate]:
74+
return OutputTemplate(
75+
markdown="""{% for line in stdout %}
5376
54-
def process(self) -> dict:
55-
output_stream = self._output_stream
56-
code = self._input.code
57-
58-
# Run the input code in a sandboxed subprocess environment and return
59-
# the output
60-
if not self._input.language == "python":
61-
raise Exception("Invalid language")
62-
63-
import os
64-
import shutil
65-
import subprocess
66-
import sys
67-
import tempfile
68-
import time
69-
70-
# Create a temporary directory to store the code
71-
temp_dir = tempfile.mkdtemp()
72-
# Create a temporary file to store the code
73-
temp_file = tempfile.NamedTemporaryFile(
74-
dir=temp_dir,
75-
delete=False,
76-
)
77-
# Write the code to the temporary file
78-
temp_file.write(code.encode("utf-8"))
79-
temp_file.close()
80-
81-
# Run the code in a subprocess
82-
process = subprocess.Popen(
83-
[sys.executable, temp_file.name],
84-
stdout=subprocess.PIPE,
85-
stderr=subprocess.PIPE,
86-
preexec_fn=os.setsid,
87-
)
77+
{% if line.mime_type == "text/plain" %}
78+
{{ line.data }}
79+
{% endif %}
80+
81+
{% if line.mime_type == "image/png" %}
82+
![Image](data:image/png;base64,{{line.data}})
83+
{% endif %}
8884
89-
# Wait for the process to finish or timeout
90-
timeout = 5
91-
start_time = time.time()
92-
while process.poll() is None:
93-
time.sleep(0.1)
94-
if time.time() - start_time > timeout:
95-
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
96-
raise Exception("Code timed out")
97-
98-
# Get the output
99-
output, error = process.communicate()
100-
output = output.decode("utf-8")
101-
error = error.decode("utf-8")
102-
103-
# Delete the temporary directory
104-
shutil.rmtree(temp_dir)
105-
106-
# Send the output
107-
async_to_sync(output_stream.write)(
108-
CodeInterpreterOutput(output=output),
85+
{% endfor %}"""
10986
)
11087

111-
if error:
112-
raise Exception(error)
88+
def convert_stdout_to_content(self, stdout) -> List[Content]:
89+
content = []
90+
for entry in stdout:
91+
if not entry.mime_type or entry.mime_type == runner_pb2.ContentMimeType.TEXT:
92+
content.append(Content(mime_type=ContentMimeType.TEXT, data=entry.data.decode("utf-8")))
93+
elif entry.mime_type == runner_pb2.ContentMimeType.JSON:
94+
content.append(Content(mime_type=ContentMimeType.JSON, data=entry.data.decode("utf-8")))
95+
elif entry.mime_type == runner_pb2.ContentMimeType.HTML:
96+
content.append(Content(mime_type=ContentMimeType.HTML, data=entry.data.decode("utf-8")))
97+
elif entry.mime_type == runner_pb2.ContentMimeType.PNG:
98+
data = base64.b64encode(entry.data).decode("utf-8")
99+
content.append(Content(mime_type=ContentMimeType.PNG, data=data))
100+
elif entry.mime_type == runner_pb2.ContentMimeType.JPEG:
101+
data = base64.b64encode(entry.data).decode("utf-8")
102+
content.append(Content(mime_type=ContentMimeType.JPEG, data=data))
103+
elif entry.mime_type == runner_pb2.ContentMimeType.SVG:
104+
data = base64.b64encode(entry.data).decode("utf-8")
105+
content.append(Content(mime_type=ContentMimeType.SVG, data=data))
106+
elif entry.mime_type == runner_pb2.ContentMimeType.PDF:
107+
data = base64.b64encode(entry.data).decode("utf-8")
108+
content.append(Content(mime_type=ContentMimeType.PDF, data=data))
109+
elif entry.mime_type == runner_pb2.ContentMimeType.LATEX:
110+
data = base64.b64encode(entry.data).decode("utf-8")
111+
content.append(Content(mime_type=ContentMimeType.LATEX, data=data))
112+
return content
113+
114+
def process(self) -> dict:
115+
output_stream = self._output_stream
116+
channel = grpc.insecure_channel(f"{settings.RUNNER_HOST}:{settings.RUNNER_PORT}")
117+
stub = runner_pb2_grpc.RunnerStub(channel)
118+
input_data = {}
119+
if self._input.local_variables:
120+
try:
121+
input_data = json.loads(self._input.local_variables)
122+
except Exception as e:
123+
logger.error(f"Error parsing local variables: {e}")
124+
125+
request = runner_pb2.RestrictedPythonCodeRunnerRequest(
126+
source_code=self._input.code,
127+
input_data=ParseDict(input_data, Struct()),
128+
timeout_secs=5,
129+
)
130+
response_iterator = stub.GetRestrictedPythonCodeRunner(request)
131+
for response in response_iterator:
132+
if response.state == runner_pb2.RemoteBrowserState.TERMINATED:
133+
converted_stdout = self.convert_stdout_to_content(response.stdout)
134+
async_to_sync(output_stream.write)(
135+
CodeInterpreterOutput(
136+
stdout=converted_stdout,
137+
stderr=str(response.stderr),
138+
local_variables=MessageToDict(response.local_variables) if response.local_variables else None,
139+
exit_code=0,
140+
)
141+
)
142+
break
113143

114144
output = output_stream.finalize()
145+
115146
return output

llmstack/processors/providers/promptly/python_code_runner.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

0 commit comments

Comments
 (0)