Skip to content

Please add a .pyi stub #12

@panda7281

Description

@panda7281

Thanks for your language binding, it works great!
But without proper code-completion support, writing code against it can be quite tricky.

A .pyi stub can be generated automatically with a bit of mypy and ast magic; the following code generates about 95% of the symbols correctly:

I hope you can add a fully functional stub to this package. Thanks!

import mypy
import mypy.stubgen

from PyQt6.QtCore import Qt
from PyQt6.QtGui import QPen, QBrush, QColor
from QCustomPlot_PyQt6 import *
import sys
import ast
import astor

class RemoveUnderscoreVars(ast.NodeTransformer):
    """Remove class-level attributes whose names start with an underscore.

    Both plain (``_x = ...``) and annotated (``_x: T``) assignments are
    dropped; an ``Assign`` with several targets is dropped if *any* target is
    underscore-prefixed. Methods and module-level names are left alone.
    Nested classes are processed recursively.
    """

    @staticmethod
    def _is_private_var(stmt):
        """Return True if *stmt* assigns to a plain name beginning with '_'."""
        if isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        elif isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            return False
        return any(
            isinstance(target, ast.Name) and target.id.startswith('_')
            for target in targets
        )

    def visit_ClassDef(self, node):
        node.body = [stmt for stmt in node.body if not self._is_private_var(stmt)]
        # A class that held only private attributes would otherwise end up
        # with an empty body, which unparses to invalid Python; use ``...``.
        if not node.body:
            node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
        # Recurse so nested classes are cleaned as well.
        self.generic_visit(node)
        return node

class RemoveFailedImports(ast.NodeTransformer):
    """Drop import statements, or individual imported names, that cannot be resolved.

    Each module is actually imported with ``__import__`` in the current
    interpreter, so any import-time side effects of those modules run.
    A statement is deleted (the visitor returns ``None``) only when none of
    its names survive; otherwise it is kept with the failing names removed.
    """

    def visit_Import(self, node):
        """Keep only the aliases of ``import a, b.c`` whose modules import."""
        # Filter per alias (as visit_ImportFrom does) so one bad module does
        # not take valid siblings in the same statement down with it.
        valid_names = []
        for alias in node.names:
            try:
                __import__(alias.name)
            except ImportError:
                continue
            valid_names.append(alias)
        if not valid_names:
            return None
        node.names = valid_names
        return node

    def visit_ImportFrom(self, node):
        """Keep ``from m import a, b`` names that exist on an importable ``m``."""
        # Relative imports cannot be checked without knowing the stub's
        # package: ``__import__(None)`` raises TypeError for ``from . import x``
        # and ``from .m import x`` would import an unrelated top-level ``m``.
        # Leave them untouched.
        if node.level or node.module is None:
            return node
        try:
            module = __import__(node.module, fromlist=[alias.name for alias in node.names])
        except (ImportError, AttributeError):
            return None
        # '*' is never an attribute, but a star import from an importable
        # module is valid and must be kept.
        valid_names = [
            alias for alias in node.names
            if alias.name == '*' or hasattr(module, alias.name)
        ]
        if not valid_names:
            return None
        node.names = valid_names
        return node

class SymbolCollector(ast.NodeVisitor):
    """Collect names that annotations in the stub may legitimately refer to.

    After ``visit(tree)``, ``symbols`` holds:
      * names bound by imports (the ``as`` alias, else the imported name; for
        ``import a.b`` only the top-level package ``a``),
      * every class and (async) function name at any nesting depth,
      * names bound by plain-name assignments at module level (type aliases,
        TypeVars, constants). Class-level attributes are not collected.
    """

    def __init__(self):
        self.symbols = set()

    def visit_Module(self, node):
        # Only module-level assignments bind names usable in any annotation,
        # e.g. ``T = TypeVar('T')`` or ``Alias: TypeAlias = int``.
        for stmt in node.body:
            if isinstance(stmt, ast.Assign):
                targets = stmt.targets
            elif isinstance(stmt, ast.AnnAssign):
                targets = [stmt.target]
            else:
                continue
            for target in targets:
                if isinstance(target, ast.Name):
                    self.symbols.add(target.id)
        self.generic_visit(node)

    def visit_Import(self, node):
        for alias in node.names:
            if alias.asname:
                self.symbols.add(alias.asname)
            else:
                # ``import a.b`` binds only ``a`` in the importing namespace.
                self.symbols.add(alias.name.split('.')[0])

    def visit_ImportFrom(self, node):
        for alias in node.names:
            self.symbols.add(alias.asname or alias.name)

    def visit_ClassDef(self, node):
        self.symbols.add(node.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.symbols.add(node.name)
        self.generic_visit(node)

    # Stubs may contain ``async def``; treat it exactly like ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef

class RemoveUnknownAnnotations(ast.NodeTransformer):
    """Strip type annotations that reference names not bound in the stub.

    ``known_symbols`` is a collection of names (e.g. from ``SymbolCollector``)
    supporting ``in``. Annotated attributes (``x: T``) whose annotation is
    unknown are deleted outright; unknown parameter and return annotations
    are set to ``None`` so the signature itself survives.
    """

    def __init__(self, known_symbols):
        self.known_symbols = known_symbols

    def visit_ClassDef(self, node):
        self.generic_visit(node)
        # Deleting every AnnAssign can leave the class body empty, which is
        # invalid once unparsed; keep the class valid with a bare ``...``.
        if not node.body:
            node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
        return node

    def visit_AnnAssign(self, node):
        if self._is_unknown(node.annotation):
            return None
        return node

    def visit_FunctionDef(self, node):
        if node.returns is not None and self._is_unknown(node.returns):
            node.returns = None

        arguments = node.args
        # Every parameter kind can carry an annotation, not just the plain
        # positional-or-keyword ones in ``args.args``.
        params = [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]
        params.extend(p for p in (arguments.vararg, arguments.kwarg) if p is not None)
        for arg in params:
            if arg.annotation is not None and self._is_unknown(arg.annotation):
                arg.annotation = None

        return node

    # ``async def`` signatures are cleaned the same way as ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _check_subscript_slice(self, slice_node: ast.AST) -> bool:
        """Return True if any type argument inside a subscript slice is unknown."""
        # Before Python 3.9 the slice is wrapped in ast.Index; that class is
        # deprecated, so look it up defensively.
        index_cls = getattr(ast, 'Index', None)
        if index_cls is not None and isinstance(slice_node, index_cls):
            return self._is_unknown(slice_node.value)
        # Tuples (``Dict[K, V]``) are handled by _is_unknown itself.
        return self._is_unknown(slice_node)

    def _is_unknown(self, node: ast.AST) -> bool:
        """Return True if the annotation expression *node* uses an unbound name.

        Constants (including string forward references) and expression kinds
        not handled below are treated as known.
        """
        if isinstance(node, ast.BinOp):  # PEP 604 unions: A | B
            return self._is_unknown(node.left) or self._is_unknown(node.right)

        if isinstance(node, ast.Subscript):  # Generic[Args]
            return (self._is_unknown(node.value) or
                    self._check_subscript_slice(node.slice))

        if isinstance(node, (ast.Tuple, ast.List)):  # Dict[K, V], Callable[[A, B], R]
            return any(self._is_unknown(elt) for elt in node.elts)

        if isinstance(node, ast.Attribute):
            parts = []
            current = node
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if not isinstance(current, ast.Name):
                return True
            parts.append(current.id)
            parts.reverse()
            dotted = ".".join(parts)
            # Known if the root alias is bound (``QtCore.QPointF`` after
            # ``from PyQt6 import QtCore``; ``np.ndarray`` after
            # ``import numpy as np``), or the full dotted path or the final
            # attribute is itself a known symbol.
            return not (parts[0] in self.known_symbols
                        or dotted in self.known_symbols
                        or parts[-1] in self.known_symbols)

        if isinstance(node, ast.Name):
            return node.id not in self.known_symbols

        return False

def process_pyi(input_file, output_file):
    """Clean up a stubgen-generated ``.pyi`` file and write the result.

    Passes, in order: drop underscore-prefixed class attributes, drop imports
    that fail in this interpreter, collect the names still bound, then strip
    annotations that reference anything else. The cleaned tree is rendered
    with ``astor``.

    Args:
        input_file: path of the ``.pyi`` produced by ``mypy.stubgen``.
        output_file: path to write the cleaned stub to (overwritten).
    """
    # Explicit UTF-8: stubs can contain non-ASCII text, and the locale default
    # (e.g. cp1252 on Windows) would mis-decode or fail to encode it.
    with open(input_file, 'r', encoding='utf-8') as f:
        # filename= makes any SyntaxError point at the stub file.
        tree = ast.parse(f.read(), filename=input_file)

    # remove underscore vars
    tree = RemoveUnderscoreVars().visit(tree)
    ast.fix_missing_locations(tree)

    # remove failed imports (must run before symbol collection so names
    # from unresolvable imports are not treated as known)
    tree = RemoveFailedImports().visit(tree)
    ast.fix_missing_locations(tree)

    # collect symbols
    collector = SymbolCollector()
    collector.visit(tree)
    known_symbols = collector.symbols

    # remove unknown annotations
    tree = RemoveUnknownAnnotations(known_symbols).visit(tree)
    ast.fix_missing_locations(tree)

    new_code = astor.to_source(tree)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(new_code)

if __name__ == '__main__':
    import importlib.util
    import os

    module_name = 'QCustomPlot_PyQt6'
    stub_dir = 'tmp'

    # Generate the raw stub with mypy's stubgen into stub_dir; for a
    # top-level module it is written as <stub_dir>/<module_name>.pyi.
    options = mypy.stubgen.parse_options(['-m', module_name, '-o', stub_dir])
    mypy.stubgen.generate_stubs(options)

    # Write the cleaned stub next to the installed module, so IDEs and type
    # checkers pick it up on any OS / venv layout. The previous hard-coded
    # 'venv/Lib/site-packages' only worked for a Windows venv named 'venv'
    # in the current directory. Assumes a single-file (non-package) module.
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        sys.exit(f'cannot locate installed module {module_name!r}')
    target_dir = os.path.dirname(spec.origin)

    process_pyi(
        os.path.join(stub_dir, module_name + '.pyi'),
        os.path.join(target_dir, module_name + '.pyi'),
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement — New feature or request

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions