Skip to content

Commit 6f855ce

Browse files
authored
Added multiple llm provider support (#4)
* added pylint
1 parent 6e55164 commit 6f855ce

File tree

12 files changed

+105
-51
lines changed

12 files changed

+105
-51
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ jobs:
1818
files: |
1919
tests/file.txt
2020
tests/file.json
21-
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
21+
provider: groq
22+
api_key: ${{ secrets.GROQ_API_KEY }}

.pre-commit-config.yaml

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
repos:
22
- repo: https://github.com/PyCQA/isort
3-
rev: 5.13.2
3+
rev: 6.0.1
44
hooks:
55
- id: isort
66

77
- repo: https://github.com/psf/black
8-
rev: 24.10.0
8+
rev: 25.1.0
99
hooks:
1010
- id: black
1111

@@ -15,10 +15,22 @@ repos:
1515
- id: check-yaml
1616
- id: end-of-file-fixer
1717
- id: no-commit-to-branch
18-
args: [-b, main, -b, master]
18+
args: [ -b, main, -b, master ]
1919

2020
- repo: https://github.com/PyCQA/flake8
21-
rev: 7.1.1
21+
rev: 7.1.2
2222
hooks:
2323
- id: flake8
24-
additional_dependencies: [flake8-pyproject]
24+
additional_dependencies: [ flake8-pyproject ]
25+
- repo: local
26+
hooks:
27+
- id: pylint
28+
name: pylint
29+
entry: pylint
30+
language: system
31+
types: [ python ]
32+
args:
33+
[
34+
"-rn", # Only display messages
35+
"-sn", # Don't display the score
36+
]

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
This Github Action uses AI to find typos and grammatical errors in specified data files.
44

55
## Usage
6+
67
Refer to [test.yml](./.github/workflows/test.yml)
8+
79
```yaml
810
- name: Lint Text in Data Files
911
uses: actions/text-linter@v1
@@ -12,5 +14,12 @@ Refer [test.yml](./.github/workflows/test.yml)
1214
files: |
1315
tests/file.txt
1416
tests/file.json
15-
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
17+
provider: groq
18+
api_key: ${{ secrets.GROQ_API_KEY }}
19+
model: # optional
1620
```
21+
22+
## Supported Models & Providers
23+
24+
All text models supported by [LiteLLM](https://docs.litellm.ai/docs/providers) can be used.
25+
The `provider` name must therefore match the provider name given in the LiteLLM documentation linked above.

action.yml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@ inputs:
1414
description: 'Comma-separated list of files to analyze'
1515
required: true
1616
type: list
17-
GROQ_API_KEY:
18-
description: 'Groq API Key'
17+
provider:
18+
description: 'Provider of the language model (e.g., openai, anthropic, groq)'
1919
required: true
20+
default: groq
21+
api_key:
22+
description: 'API Key for the specified provider'
23+
required: true
24+
model:
25+
description: 'Language model by the specified provider'
26+
required: false
2027

2128

2229
runs:
@@ -44,7 +51,9 @@ runs:
4451
env:
4552
PR_BASE: ${{ github.base_ref }}
4653
INPUT_FILES: ${{ inputs.files }}
47-
GROQ_API_KEY: ${{ inputs.GROQ_API_KEY }}
54+
PROVIDER: ${{ inputs.provider }}
55+
API_KEY: ${{ inputs.api_key }}
56+
LLM_MODEL: ${{ inputs.model }}
4857
token: ${{ inputs.token }}
4958
PR_NO: ${{ github.event.pull_request.number }}
5059
run: |

local_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from dotenv import load_dotenv
1+
from dotenv import load_dotenv, set_key
22

33
load_dotenv()
44

5+
set_key(".env", "ENVIRONMENT", "local")
6+
57
if __name__ == "__main__":
68
from src.text_linter import main
79

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,12 @@ profile = "black"
1414
line_length = 120
1515
multi_line_output = 3
1616

17+
[tool.pylint.master]
18+
jobs = 2
19+
ignore-paths = "(?!(src|tests))/*"
20+
1721
[tool.pylint.format]
1822
max-line-length = 130
23+
24+
[tool.pylint.messages_control]
25+
disable = ["I", "import-error", "missing-docstring", "duplicate-code"]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
groq==0.13.0
1+
litellm~=1.63.11
22
requests~=2.32.3

src/config.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
import os
2+
from dataclasses import dataclass
23

34

4-
def get_env(x):
5-
return os.environ[x]
5+
def get_env(x, optional=False):
6+
return os.environ.get(x) if optional else os.environ[x]
67

78

8-
GROQ_API_KEY = get_env("GROQ_API_KEY")
9-
PAT_TOKEN = get_env("token")
9+
ENVIRONMENT = get_env("ENVIRONMENT", optional=True)
1010

11-
REPO = get_env("GITHUB_REPOSITORY")
12-
PR_BASE = get_env("PR_BASE")
13-
PR_NO = get_env("PR_NO")
14-
INPUT_FILES = [*map(str.strip, get_env("INPUT_FILES").splitlines())]
11+
12+
@dataclass
13+
class GitEnv:
14+
PAT_TOKEN = get_env("token")
15+
REPO = get_env("GITHUB_REPOSITORY")
16+
PR_BASE = get_env("PR_BASE")
17+
PR_NO = get_env("PR_NO")
18+
INPUT_FILES = [*map(str.strip, get_env("INPUT_FILES").splitlines())]
19+
20+
21+
PROVIDER = get_env("PROVIDER").lower()
22+
LLM_API_KEY = get_env("API_KEY")
23+
LLM_MODELS = {"openai": "gpt-3.5-turbo", "anthropic": "claude-3-haiku-20240307", "groq": "llama3-70b-8192"}
24+
LLM_MODEL = get_env("LLM_MODEL", optional=True) or LLM_MODELS[PROVIDER]
src/llm_service.py (renamed from src/groq_ai.py)

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
import json
2-
from functools import cache
32

4-
from groq import NOT_GIVEN, BadRequestError, Groq
3+
from litellm import BadRequestError, completion, validate_environment
54

6-
from src.config import GROQ_API_KEY
5+
from src.config import LLM_API_KEY, LLM_MODEL, PROVIDER
76

8-
LLM_MODEL = ["llama-3.1-8b-instant", "llama3-70b-8192"][1]
97

8+
def validate_model(provider: str = PROVIDER, model: str = LLM_MODEL) -> None:
9+
model_str = f"{provider}/{model}"
10+
assert not validate_environment(model_str)["keys_in_environment"], f"Invalid value : {model}"
1011

11-
@cache
12-
def get_groq_client():
13-
return Groq(api_key=GROQ_API_KEY)
1412

15-
16-
def ask_groq(query_text: str, system_content: str, json_schema: dict = None) -> dict:
17-
client = get_groq_client()
18-
response_format = NOT_GIVEN
13+
def ask_llm(query_text: str, system_content: str, json_schema: dict = None, provider=PROVIDER) -> dict:
14+
response_format = None
1915
if json_schema:
2016
system_content = f"{system_content}\n Use this json schema to reply {json.dumps(json_schema)}"
2117
response_format = {"type": "json_object"}
2218
try:
23-
chat_completion = client.chat.completions.create(
19+
# Send a message to the model
20+
print("calling llm...")
21+
response = completion(
22+
model=f"{provider}/{LLM_MODEL}",
23+
api_key=LLM_API_KEY,
2424
messages=[{"role": "system", "content": system_content}, {"role": "user", "content": query_text}],
25-
model=LLM_MODEL,
2625
response_format=response_format,
2726
)
28-
json_str = chat_completion.choices[0].message.content
29-
json_str = json_str.strip()
27+
json_str = response["choices"][0]["message"]["content"]
3028
except BadRequestError as e:
31-
failed_json = e.body["error"]["failed_generation"] # noqa
32-
json_str = failed_json.replace('"""', '"').replace("\n", "\\n")
29+
json_str = e.message # reusing variable for showing error
30+
31+
if json_str.startswith("litellm"):
32+
raise RuntimeError(json_str)
33+
3334
if not json_schema:
3435
return json_str
3536
try:
36-
# json_str = re.sub("\n +", "", json_str)
3737
return json.loads(json_str)
3838
except json.JSONDecodeError as e:
3939
print(e.doc.encode(), e.pos)
@@ -47,7 +47,7 @@ def find_typos(query_text):
4747
" and list mistakes (if any) along with suggestion, not fixes."
4848
"Keep it short, no explanation. Do not modify urls, usernames and hashtags."
4949
)
50-
ans = ask_groq(query_text, system_content, json_schema={"suggestions": ["str"]})
50+
ans = ask_llm(query_text, system_content, json_schema={"suggestions": ["str"]})
5151
return ans["suggestions"]
5252

5353

@@ -57,12 +57,12 @@ def fix_typos(query_text):
5757
"without any additional info like headings, footers, etc."
5858
"Do not modify urls, usernames and hashtags. `~~~` is a phrase seperator, keep it as it is."
5959
)
60-
return ask_groq(query_text, system_content)
60+
return ask_llm(query_text, system_content)
6161

6262

6363
if __name__ == "__main__":
6464
from dotenv import load_dotenv # noqa
6565

6666
load_dotenv()
67-
answer = ask_groq("Debuging is hard.", system_content="")
68-
print(answer)
67+
answer = ask_llm("Debuging is hard.", system_content="")
68+
print(f"Response from model: {answer}")

src/text_linter.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import requests
44

5-
from src.config import INPUT_FILES, PAT_TOKEN, PR_BASE, PR_NO, REPO
6-
from src.groq_ai import find_typos
5+
from src.config import ENVIRONMENT, GitEnv
6+
from src.llm_service import find_typos, validate_model
77

88

9-
def process_diff(file_path, base_branch=PR_BASE):
9+
def process_diff(file_path, base_branch=GitEnv.PR_BASE):
1010
try:
1111
# Get diff from PR
1212
diff_command = f"git diff -U0 origin/{base_branch}... -- {file_path}"
@@ -21,19 +21,24 @@ def process_diff(file_path, base_branch=PR_BASE):
2121

2222

2323
def post_comment(comment):
24-
url = f"https://api.github.com/repos/{REPO}/issues/{PR_NO}/comments"
24+
print(comment)
25+
if ENVIRONMENT == "local":
26+
return
27+
28+
url = f"https://api.github.com/repos/{GitEnv.REPO}/issues/{GitEnv.PR_NO}/comments"
2529
headers = {
2630
"Accept": "application/vnd.github+json",
27-
"Authorization": f"Bearer {PAT_TOKEN}",
31+
"Authorization": f"Bearer {GitEnv.PAT_TOKEN}",
2832
"X-GitHub-Api-Version": "2022-11-28",
2933
}
3034
resp = requests.post(url, headers=headers, json={"body": comment}, timeout=300)
3135
resp.raise_for_status()
3236

3337

3438
def main():
39+
validate_model()
3540
# Process diff(s) for each file
36-
results = {file_path: process_diff(file_path) for file_path in INPUT_FILES if file_path}
41+
results = {file_path: process_diff(file_path) for file_path in GitEnv.INPUT_FILES if file_path}
3742

3843
# Create markdown comment for fixes
3944
flag = False
@@ -49,7 +54,6 @@ def main():
4954
else:
5055
comment = "### No typos found"
5156

52-
print(comment)
5357
post_comment(comment)
5458

5559

0 commit comments

Comments
 (0)