Skip to content

Commit be5184a

Browse files
committed
update tasks.py
1 parent 24d1dc9 commit be5184a

File tree

1 file changed

+51
-56
lines changed

1 file changed

+51
-56
lines changed

tasks.py

Lines changed: 51 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,65 @@
11
import glob
2-
import operator
32
import os
4-
import pkg_resources
5-
import platform
6-
import re
73
import shutil
84
import stat
5+
import sys
96
from pathlib import Path
107

8+
import tomli
119
from invoke import task
10+
from packaging.requirements import Requirement
11+
from packaging.version import Version
1212

1313

14-
COMPARISONS = {
15-
'>=': operator.ge,
16-
'>': operator.gt,
17-
'<': operator.lt,
18-
'<=': operator.le
19-
}
14+
def _get_minimum_versions(dependencies, python_version):
15+
min_versions = {}
16+
for dependency in dependencies:
17+
if '@' in dependency:
18+
name, url = dependency.split(' @ ')
19+
min_versions[name] = f'{url}#egg={name}'
20+
continue
2021

22+
req = Requirement(dependency)
23+
if ';' in dependency:
24+
marker = req.marker
25+
if marker and not marker.evaluate({'python_version': python_version}):
26+
continue # python version does not match
2127

22-
@task
23-
def check_dependencies(c):
24-
c.run('python -m pip check')
28+
if req.name not in min_versions:
29+
min_version = next((spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None)
30+
if min_version:
31+
min_versions[req.name] = f'{req.name}=={min_version}'
32+
33+
elif '@' not in min_versions[req.name]:
34+
existing_version = Version(min_versions[req.name].split('==')[1])
35+
new_version = next((spec.version for spec in req.specifier if spec.operator in ('>=', '==')), existing_version)
36+
if new_version > existing_version:
37+
min_versions[req.name] = f'{req.name}=={new_version}'
38+
39+
return list(min_versions.values())
2540

2641

2742
@task
28-
def unit(c):
29-
c.run('python -m pytest --cov=sigllm --cov-report=xml')
43+
def install_minimum(c):
44+
with open('pyproject.toml', 'rb') as pyproject_file:
45+
pyproject_data = tomli.load(pyproject_file)
3046

47+
dependencies = pyproject_data.get('project', {}).get('dependencies', [])
48+
python_version = '.'.join(map(str, sys.version_info[:2]))
49+
minimum_versions = _get_minimum_versions(dependencies, python_version)
3150

32-
def _validate_python_version(line):
33-
is_valid = True
34-
for python_version_match in re.finditer(r"python_version(<=?|>=?|==)\'(\d\.?)+\'", line):
35-
python_version = python_version_match.group(0)
36-
comparison = re.search(r'(>=?|<=?|==)', python_version).group(0)
37-
version_number = python_version.split(comparison)[-1].replace("'", "")
38-
comparison_function = COMPARISONS[comparison]
39-
is_valid = is_valid and comparison_function(
40-
pkg_resources.parse_version(platform.python_version()),
41-
pkg_resources.parse_version(version_number),
42-
)
51+
if minimum_versions:
52+
c.run(f'python -m pip install {" ".join(minimum_versions)}')
4353

44-
return is_valid
54+
55+
@task
56+
def check_dependencies(c):
57+
c.run('python -m pip check')
4558

4659

4760
@task
48-
def install_minimum(c):
49-
with open('setup.py', 'r') as setup_py:
50-
lines = setup_py.read().splitlines()
51-
52-
versions = []
53-
started = False
54-
for line in lines:
55-
if started:
56-
if line == ']':
57-
started = False
58-
continue
59-
60-
line = line.strip()
61-
if _validate_python_version(line):
62-
requirement = re.match(r'[^>]*', line).group(0)
63-
requirement = re.sub(r"""['",]""", '', requirement)
64-
version = re.search(r'>=?(\d\.?)+\w*', line).group(0)
65-
if version:
66-
version = re.sub(r'>=?', '==', version)
67-
version = re.sub(r"""['",]""", '', version)
68-
requirement += version
69-
versions.append(requirement)
70-
71-
elif (line.startswith('install_requires = [')):
72-
started = True
73-
74-
c.run(f'python -m pip install {" ".join(versions)}')
61+
def unit(c):
62+
c.run('python -m pytest --cov=sigllm --cov-report=xml')
7563

7664

7765
@task
@@ -110,8 +98,15 @@ def tutorials(c):
11098
@task
11199
def lint(c):
112100
check_dependencies(c)
113-
c.run('flake8 sigllm tests')
114-
c.run('isort -c --recursive sigllm tests')
101+
c.run('ruff check .')
102+
c.run('ruff format --check --diff .')
103+
104+
105+
@task
106+
def fix_lint(c):
107+
check_dependencies(c)
108+
c.run('ruff check --fix .')
109+
c.run('ruff format .')
115110

116111

117112
def remove_readonly(func, path, _):

0 commit comments

Comments
 (0)