diff --git a/docs/change_log.md b/docs/change_log.md
index dbdc25b5..6691e836 100644
--- a/docs/change_log.md
+++ b/docs/change_log.md
@@ -7,6 +7,9 @@
 - deprecate result_row_t::as_tuple
 - new as_tuple(const result_row_t&)
 - new get_sql_name_tuple(const result_row_t&), #72
+- sqlpp23-ddl2cpp: command-line option `--path-to-datatype-file` was renamed to `--path-to-custom-types`
+- sqlpp23-ddl2cpp: if an error is found in the custom types file, the program exits with error code 30
+- sqlpp23-ddl2cpp: base types in the custom types file should use `snake_case`, although the old `CamelCase` is still supported for backwards compatibility

 ## 0.67
diff --git a/docs/ddl2cpp.md b/docs/ddl2cpp.md
index 5fd84fa2..45584091 100644
--- a/docs/ddl2cpp.md
+++ b/docs/ddl2cpp.md
@@ -40,7 +40,7 @@ For detailed instructions refer to the documentation of your database.
 | --path-to-header PATH_TO_HEADER | No[^3] | | Output pathname of the generated C++ header. | Second command-line argument |
 | --path-to-header-directory PATH_TO_HEADER_DIRECTORY | No[^3] | | Output Directory for the generated C++ headers | Second command-line argument + -split-tables. |
 | --path-to-module PATH_TO_MODULE | No[^2][^3] | | Output pathname of the generated C++ module. | N/A |
-| --path-to-datatype-file PATH_TO_DATATYPE_FILE | No | | Input pathname of a CSV file defining aliases of existing data types. | Same |
+| --path-to-custom-types PATH_TO_CUSTOM_TYPES | No | | Input pathname of a CSV file defining aliases of existing data types. | Same |

 **Additional options**
 | Option | Required | Default | Description | Before v0.67 |
diff --git a/scripts/sqlpp23-ddl2cpp b/scripts/sqlpp23-ddl2cpp
index e37c4a7b..0bdef98b 100755
--- a/scripts/sqlpp23-ddl2cpp
+++ b/scripts/sqlpp23-ddl2cpp
@@ -36,6 +36,7 @@ from enum import IntEnum
 class ExitCode(IntEnum):
     SUCCESS = 0
     BAD_ARGS = 1
+    BAD_CUSTOM_TYPES = 30
     BAD_DATA_TYPE = 10
     STRANGE_PARSING = 20

@@ -51,11 +52,11 @@ class DdlParser:
     """

     # Names of data types that can be updated by the custom data types
-    ddlBooleanTypes = [
+    ddl_boolean_types = [
         "bool",
         "boolean",
     ]
-    ddlIntegerTypes = [
+    ddl_integer_types = [
         "bigint",
         "int",
         "int2", # PostgreSQL
         "int4",
         "int8", # PostgreSQL
         "integer",
         "mediumint", # MYSQL
@@ -66,7 +67,7 @@ class DdlParser:
         "smallint",
         "tinyint",
     ]
-    ddlSerialTypes = [
+    ddl_serial_types = [
         "bigserial", # PostgreSQL
         "serial", # PostgreSQL
         "serial2", # PostgreSQL
         "serial4", # PostgreSQL
@@ -74,7 +75,7 @@ class DdlParser:
         "serial8", # PostgreSQL
         "smallserial", # PostgreSQL
     ]
-    ddlFloatingPointTypes = [
+    ddl_floating_point_types = [
         "decimal", # MYSQL
         "double",
         "float8", # PostgreSQL
         "float",
@@ -83,7 +84,7 @@ class DdlParser:
         "money", # PostgreSQL
         "numeric", # PostgreSQL
         "real",
     ]
-    ddlTextTypes = [
+    ddl_text_types = [
         "char",
         "varchar",
         "character varying", # PostgreSQL
         "citext", # PostgreSQL
         "enum", # MYSQL
         "set", # MYSQL
         "longtext", # MYSQL
         "text",
         "tinytext", # MYSQL
@@ -98,7 +99,7 @@ class DdlParser:
         "mediumtext", # MYSQL
         "rational", # PostgreSQL pg_rationale extension
     ]
-    ddlBlobTypes = [
+    ddl_blob_types = [
         "bytea",
         "tinyblob",
         "blob",
         "mediumblob",
         "longblob",
@@ -107,120 +108,120 @@ class DdlParser:
         "binary", # MYSQL
         "varbinary", # MYSQL
     ]
-    ddlDateTypes = [
+    ddl_date_types = [
         "date",
     ]
-    ddlDateTimeTypes = [
+    ddl_date_time_types = [
         "datetime",
         "timestamp",
         "timestamp without time zone", # PostgreSQL
         "timestamp with time zone", # PostgreSQL
         "timestamptz", # PostgreSQL
     ]
-    ddlTimeTypes = [
+    ddl_time_types = [
         "time",
         "time without time zone", # PostgreSQL
         "time with time zone", # PostgreSQL
     ]

     # Parsers that are initialized later
-    ddlExpression = None
-    ddlType = None
-    ddlColumn = None
-    ddlConstraint = None
-    ddlCtWithSql = None
+    ddl_expression = None
+    ddl_type = None
+    ddl_column = None
+    ddl_constraint = None
+    ddl_ct_with_sql = None
     ddl = None

     @classmethod
-    def initialize(cls, customTypes=None):
+    def initialize(cls, custom_types=None):
         """Initialize the DDL parser"""
         # Basic parsers
-        ddlLeft = pp.Suppress("(")
-        ddlRight = pp.Suppress(")")
-        ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
-        ddlString = (
-            pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`")
+        ddl_left = pp.Suppress("(")
+        ddl_right = pp.Suppress(")")
+        ddl_number = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee")
+        ddl_string = (
+            pp.QuotedString("'") | pp.QuotedString('"', esc_quote='""') | pp.QuotedString("`")
         )
-        # ddlString.setDebug(True) #uncomment to debug pyparsing
-        ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$")
-        ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString), pp.Combine(ddlTerm + ddlString)])
-        ddlOperator = pp.Or(
+        # ddl_string.set_debug(True) #uncomment to debug pyparsing
+        ddl_term = pp.Word(pp.alphas + "_", pp.alphanums + "_.$")
+        ddl_name = pp.Or([ddl_term, ddl_string, pp.Combine(ddl_string + "." + ddl_string), pp.Combine(ddl_term + ddl_string)])
+        ddl_operator = pp.Or(
             map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]),
             pp.CaselessKeyword("DIV")
         )
-        ddlBracedExpression = pp.Forward()
-        ddlFunctionCall = pp.Forward()
-        ddlCastEnd = "::" + ddlTerm
-        ddlCast = ddlString + ddlCastEnd
-        ddlBracedArguments = pp.Forward()
-        cls.ddlExpression = pp.OneOrMore(
-            ddlBracedExpression
-            | ddlFunctionCall
-            | ddlCastEnd
-            | ddlCast
-            | ddlOperator
-            | ddlString
-            | ddlTerm
-            | ddlNumber
-            | ddlBracedArguments
+        ddl_braced_expression = pp.Forward()
+        ddl_function_call = pp.Forward()
+        ddl_cast_end = "::" + ddl_term
+        ddl_cast = ddl_string + ddl_cast_end
+        ddl_braced_arguments = pp.Forward()
+        cls.ddl_expression = pp.OneOrMore(
+            ddl_braced_expression
+            | ddl_function_call
+            | ddl_cast_end
+            | ddl_cast
+            | ddl_operator
+            | ddl_string
+            | ddl_term
+            | ddl_number
+            | ddl_braced_arguments
         )
-        ddlBracedArguments << ddlLeft + pp.delimitedList(cls.ddlExpression) + ddlRight
-        ddlBracedExpression << ddlLeft + cls.ddlExpression + ddlRight
-        ddlArguments = pp.Suppress(pp.delimitedList(cls.ddlExpression))
-        ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight
+        ddl_braced_arguments << ddl_left + pp.DelimitedList(cls.ddl_expression) + ddl_right
+        ddl_braced_expression << ddl_left + cls.ddl_expression + ddl_right
+        ddl_arguments = pp.Suppress(pp.DelimitedList(cls.ddl_expression))
+        ddl_function_call << ddl_name + ddl_left + pp.Optional(ddl_arguments) + ddl_right

         # Data type parsers
-        def get_type_parser(key, data_type):
-            type_names = getattr(cls, f"ddl{key}Types")
-            if customTypes and (key in customTypes):
-                type_names.extend(customTypes[key])
+        def get_type_parser(base_type, data_type):
+            type_names = getattr(cls, f"ddl_{base_type}_types")
+            if custom_types and (base_type in custom_types):
+                type_names.extend(custom_types[base_type])
             return pp.Or(
                 map(pp.CaselessKeyword, sorted(type_names, reverse=True))
-            ).setParseAction(pp.replaceWith(data_type))
-
-        ddlBoolean = get_type_parser("Boolean", "boolean")
-        ddlInteger = get_type_parser("Integer", "integral")
-        ddlSerial = get_type_parser("Serial", "integral").setResultsName("hasSerialValue")
-        ddlFloatingPoint = get_type_parser("FloatingPoint", "floating_point")
-        ddlText = get_type_parser("Text", "text")
-        ddlBlob = get_type_parser("Blob", "blob")
-        ddlDate = get_type_parser("Date", "date")
-        ddlDateTime = get_type_parser("DateTime", "timestamp")
-        ddlTime = get_type_parser("Time", "time")
-        ddlUnknown = pp.Word(pp.alphanums).setParseAction(pp.replaceWith("UNKNOWN"))
-        cls.ddlType = (
-            ddlBoolean
-            | ddlInteger
-            | ddlSerial
-            | ddlFloatingPoint
-            | ddlText
-            | ddlBlob
-            | ddlDateTime
-            | ddlDate
-            | ddlTime
-            | ddlUnknown
+            ).set_parse_action(pp.replace_with(data_type))
+
+        ddl_boolean = get_type_parser("boolean", "boolean")
+        ddl_integer = get_type_parser("integer", "integral")
+        ddl_serial = get_type_parser("serial", "integral").set_results_name("has_serial_value")
+        ddl_floating_point = get_type_parser("floating_point", "floating_point")
+        ddl_text = get_type_parser("text", "text")
+        ddl_blob = get_type_parser("blob", "blob")
+        ddl_date = get_type_parser("date", "date")
+        ddl_date_time = get_type_parser("date_time", "timestamp")
+        ddl_time = get_type_parser("time", "time")
+        ddl_unknown = pp.Word(pp.alphanums).set_parse_action(pp.replace_with("UNKNOWN"))
+        cls.ddl_type = (
+            ddl_boolean
+            | ddl_integer
+            | ddl_serial
+            | ddl_floating_point
+            | ddl_text
+            | ddl_blob
+            | ddl_date_time
+            | ddl_date
+            | ddl_time
+            | ddl_unknown
         )

         # Constraints parser
-        ddlUnsigned = pp.CaselessKeyword("UNSIGNED").setResultsName("isUnsigned")
-        ddlDigits = "," + pp.Word(pp.nums)
-        ddlWidth = ddlLeft + pp.Word(pp.nums) + pp.Optional(ddlDigits) + ddlRight
-        ddlTimezone = (
+        ddl_unsigned = pp.CaselessKeyword("UNSIGNED").set_results_name("is_unsigned")
+        ddl_digits = "," + pp.Word(pp.nums)
+        ddl_width = ddl_left + pp.Word(pp.nums) + pp.Optional(ddl_digits) + ddl_right
+        ddl_timezone = (
             (pp.CaselessKeyword("with") | pp.CaselessKeyword("without"))
             + pp.CaselessKeyword("time")
             + pp.CaselessKeyword("zone")
         )
-        ddlNotNull = (pp.CaselessKeyword("NOT") + pp.CaselessKeyword("NULL")).setResultsName("notNull")
-        ddlDefaultValue = pp.CaselessKeyword("DEFAULT").setResultsName("hasDefaultValue")
-        ddlGeneratedValue = pp.CaselessKeyword("GENERATED").setResultsName("hasGeneratedValue")
-        ddlAutoKeywords = [
+        ddl_not_null = (pp.CaselessKeyword("NOT") + pp.CaselessKeyword("NULL")).set_results_name("not_null")
+        ddl_default_value = pp.CaselessKeyword("DEFAULT").set_results_name("has_default_value")
+        ddl_generated_value = pp.CaselessKeyword("GENERATED").set_results_name("has_generated_value")
+        ddl_auto_keywords = [
             "AUTO_INCREMENT",
             "AUTOINCREMENT"
         ]
-        ddlAutoValue = pp.Or(map(pp.CaselessKeyword, sorted(ddlAutoKeywords, reverse=True))).setResultsName("hasAutoValue")
-        ddlPrimaryKey = (pp.CaselessKeyword("PRIMARY") + pp.CaselessKeyword("KEY")).setResultsName("isPrimaryKey")
-        ddlIgnoredKeywords = [
+        ddl_auto_value = pp.Or(map(pp.CaselessKeyword, sorted(ddl_auto_keywords, reverse=True))).set_results_name("has_auto_value")
+        ddl_primary_key = (pp.CaselessKeyword("PRIMARY") + pp.CaselessKeyword("KEY")).set_results_name("is_primary_key")
+        ddl_ignored_keywords = [
             "CONSTRAINT",
             "FOREIGN",
             "KEY",
@@ -230,61 +231,65 @@ class DdlParser:
             "CHECK",
             "PERIOD",
         ]
-        cls.ddlConstraint = (
+        cls.ddl_constraint = (
             pp.Or(map(
                 pp.CaselessKeyword,
-                sorted(ddlIgnoredKeywords + ["PRIMARY"], reverse=True)
+                sorted(ddl_ignored_keywords + ["PRIMARY"], reverse=True)
             ))
-            + cls.ddlExpression
-        ).setResultsName("isConstraint")
+            + cls.ddl_expression
+        ).set_results_name("is_constraint")

         # Column parser
-        cls.ddlColumn = pp.Group(
-            ddlName.setResultsName("name")
-            + cls.ddlType.setResultsName("type")
-            + pp.Suppress(pp.Optional(ddlWidth))
-            + pp.Suppress(pp.Optional(ddlTimezone))
+        cls.ddl_column = pp.Group(
+            ddl_name.set_results_name("name")
+            + cls.ddl_type.set_results_name("type")
+            + pp.Suppress(pp.Optional(ddl_width))
+            + pp.Suppress(pp.Optional(ddl_timezone))
             + pp.ZeroOrMore(
-                ddlUnsigned
-                | ddlNotNull
+                ddl_unsigned
+                | ddl_not_null
                 | pp.Suppress(pp.CaselessKeyword("NULL"))
-                | ddlAutoValue
-                | ddlDefaultValue
-                | ddlGeneratedValue
-                | ddlPrimaryKey
-                | pp.Suppress(pp.OneOrMore(pp.Or(map(pp.CaselessKeyword, sorted(ddlIgnoredKeywords, reverse=True)))))
-                | pp.Suppress(cls.ddlExpression)
+                | ddl_auto_value
+                | ddl_default_value
+                | ddl_generated_value
+                | ddl_primary_key
+                | pp.Suppress(pp.OneOrMore(pp.Or(map(pp.CaselessKeyword, sorted(ddl_ignored_keywords, reverse=True)))))
+                | pp.Suppress(cls.ddl_expression)
             )
         )

         # CREATE TABLE parser
-        ddlCtBasic = (
+        ddl_ct_basic = (
             pp.Suppress(pp.CaselessKeyword("CREATE"))
             + pp.Suppress(pp.Optional(pp.CaselessKeyword("OR") + pp.CaselessKeyword("REPLACE")))
             + pp.Suppress(pp.CaselessKeyword("TABLE"))
             + pp.Suppress(pp.Optional(pp.CaselessKeyword("IF") + pp.CaselessKeyword("NOT") + pp.CaselessKeyword("EXISTS")))
-            + ddlName.setResultsName("tableName")
-            + ddlLeft
-            + pp.Group(pp.delimitedList(pp.Suppress(cls.ddlConstraint) | cls.ddlColumn)).setResultsName("columns")
-            + ddlRight
+            + ddl_name.set_results_name("table_name")
+            + ddl_left
+            + pp.Group(pp.DelimitedList(pp.Suppress(cls.ddl_constraint) | cls.ddl_column)).set_results_name("columns")
+            + ddl_right
         )

-        def addCreateSql(text, loc, tokens):
+        def add_create_sql(text, loc, tokens):
             create = tokens.create
-            create.value["createSql"] = text[create.locn_start:create.locn_end]
-        cls.ddlCtWithSql = pp.Located(ddlCtBasic).setResultsName("create").setParseAction(addCreateSql)
+            create.value["create_sql"] = text[create.locn_start:create.locn_end]
+        cls.ddl_ct_with_sql = pp.Located(ddl_ct_basic).set_results_name("create").set_parse_action(add_create_sql)

         # Main DDL parser
-        cls.ddl = pp.OneOrMore(pp.Group(pp.Suppress(pp.SkipTo(ddlCtBasic, False)) + cls.ddlCtWithSql)).setResultsName("tables")
-        ddlComment = pp.oneOf(["--", "#"]) + pp.restOfLine
-        cls.ddl.ignore(ddlComment)
-        cls.ddl.parseWithTabs()
+        cls.ddl = pp.OneOrMore(pp.Group(pp.Suppress(pp.SkipTo(ddl_ct_basic, False)) + cls.ddl_ct_with_sql)).set_results_name("tables")
+        ddl_comment = pp.one_of(["--", "#"]) + pp.rest_of_line
+        cls.ddl.ignore(ddl_comment)
+        cls.ddl.parse_with_tabs()
+
+    @classmethod
+    def is_base_type(cls, name):
+        return hasattr(cls, f"ddl_{name}_types")

     @classmethod
     def parse_ddls(cls, ddl_paths):
         try:
-            return [cls.ddl.parseFile(path) for path in ddl_paths]
+            return [cls.ddl.parse_file(path) for path in ddl_paths]
         except pp.ParseException as e:
-            print("ERROR: failed to parse " + path)
+            print(f"ERROR: Failed to parse {path}")
             print(e.explain(1))
             sys.exit(ExitCode.STRANGE_PARSING)

@@ -296,99 +301,99 @@ class SelfTest:
     def run(cls):
         print("Running self-test")
         DdlParser.initialize()
-        cls.testBoolean()
-        cls.testInteger()
-        cls.testSerial()
-        cls.testFloatingPoint()
-        cls.testText()
-        cls.testBlob()
-        cls.testDate()
-        cls.testTime()
-        cls.testUnknown()
-        cls.testDateTime()
-        cls.testColumn()
-        cls.testConstraint()
-        cls.testMathExpression()
-        cls.testRational()
-        cls.testTable()
-        cls.testPrimaryKeyAutoIncrement()
+        cls._test_boolean()
+        cls._test_integer()
+        cls._test_serial()
+        cls._test_floating_point()
+        cls._test_text()
+        cls._test_blob()
+        cls._test_date()
+        cls._test_time()
+        cls._test_unknown()
+        cls._test_date_time()
+        cls._test_column()
+        cls._test_constraint()
+        cls._test_math_expression()
+        cls._test_rational()
+        cls._test_table()
+        cls._test_primary_key_auto_increment()

     @staticmethod
-    def testBoolean():
-        for t in DdlParser.ddlBooleanTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_boolean():
+        for t in DdlParser.ddl_boolean_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "boolean"

     @staticmethod
-    def testInteger():
-        for t in DdlParser.ddlIntegerTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_integer():
+        for t in DdlParser.ddl_integer_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "integral"

     @staticmethod
-    def testSerial():
-        for t in DdlParser.ddlSerialTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_serial():
+        for t in DdlParser.ddl_serial_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "integral"
-            assert result.hasSerialValue
+            assert result.has_serial_value

     @staticmethod
-    def testFloatingPoint():
-        for t in DdlParser.ddlFloatingPointTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_floating_point():
-        for t in DdlParser.ddl_floating_point_types:
+        for t in DdlParser.ddl_floating_point_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "floating_point"

     @staticmethod
-    def testText():
-        for t in DdlParser.ddlTextTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_text():
+        for t in DdlParser.ddl_text_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "text"

     @staticmethod
-    def testBlob():
-        for t in DdlParser.ddlBlobTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_blob():
+        for t in DdlParser.ddl_blob_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "blob"

     @staticmethod
-    def testDate():
-        for t in DdlParser.ddlDateTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_date():
+        for t in DdlParser.ddl_date_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "date"

     @staticmethod
-    def testDateTime():
-        for t in DdlParser.ddlDateTimeTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_date_time():
+        for t in DdlParser.ddl_date_time_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "timestamp"

     @staticmethod
-    def testTime():
-        for t in DdlParser.ddlTimeTypes:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+    def _test_time():
+        for t in DdlParser.ddl_time_types:
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "time"

     @staticmethod
-    def testUnknown():
+    def _test_unknown():
         for t in ["cheesecake", "blueberry"]:
-            result = DdlParser.ddlType.parseString(t, parseAll=True)
+            result = DdlParser.ddl_type.parse_string(t, parse_all=True)
             assert result[0] == "UNKNOWN"

     @staticmethod
-    def testColumn():
-        testData = [
+    def _test_column():
+        test_data = [
             {
                 "text": "\"id\" int(8) unsigned NOT NULL DEFAULT nextval('dk_id_seq'::regclass)",
                 "expected": {
                     "name": "id",
                     "type": "integral",
-                    "isUnsigned": True,
-                    "notNull": True,
-                    "hasAutoValue": False,
-                    "hasDefaultValue": True,
-                    "hasGeneratedValue": False,
-                    "hasSerialValue": False,
-                    "isPrimaryKey": False
+                    "is_unsigned": True,
+                    "not_null": True,
+                    "has_auto_value": False,
+                    "has_default_value": True,
+                    "has_generated_value": False,
+                    "has_serial_value": False,
+                    "is_primary_key": False
                 }
             },
             {
@@ -396,13 +401,13 @@ class SelfTest:
                 "expected": {
"name": "fld", "type": "integral", - "isUnsigned": False, - "notNull": False, - "hasAutoValue": True, - "hasDefaultValue": False, - "hasGeneratedValue": False, - "hasSerialValue": False, - "isPrimaryKey": False + "is_unsigned": False, + "not_null": False, + "has_auto_value": True, + "has_default_value": False, + "has_generated_value": False, + "has_serial_value": False, + "is_primary_key": False } }, { @@ -410,61 +415,61 @@ class SelfTest: "expected": { "name": "fld2", "type": "integral", - "isUnsigned": False, - "notNull": True, - "hasAutoValue": False, - "hasDefaultValue": False, - "hasGeneratedValue": True, - "hasSerialValue": False, - "isPrimaryKey": False + "is_unsigned": False, + "not_null": True, + "has_auto_value": False, + "has_default_value": False, + "has_generated_value": True, + "has_serial_value": False, + "is_primary_key": False } } ] - for td in testData: - result = DdlParser.ddlColumn.parseString(td["text"], parseAll=True)[0] + for td in test_data: + result = DdlParser.ddl_column.parse_string(td["text"], parse_all=True)[0] expected = td["expected"] assert result.name == expected["name"] assert result.type == expected["type"] - assert bool(result.isUnsigned) == expected["isUnsigned"] - assert bool(result.notNull) == expected["notNull"] - assert bool(result.hasAutoValue) == expected["hasAutoValue"] - assert bool(result.hasDefaultValue) == expected["hasDefaultValue"] - assert bool(result.hasGeneratedValue) == expected["hasGeneratedValue"] - assert bool(result.hasSerialValue) == expected["hasSerialValue"] - assert bool(result.isPrimaryKey) == expected["isPrimaryKey"] + assert bool(result.is_unsigned) == expected["is_unsigned"] + assert bool(result.not_null) == expected["not_null"] + assert bool(result.has_auto_value) == expected["has_auto_value"] + assert bool(result.has_default_value) == expected["has_default_value"] + assert bool(result.has_generated_value) == expected["has_generated_value"] + assert bool(result.has_serial_value) == expected["has_serial_value"] + assert bool(result.is_primary_key) == expected["is_primary_key"] @staticmethod - def testConstraint(): + def _test_constraint(): for text in [ "CONSTRAINT unique_person UNIQUE (first_name, last_name)", "UNIQUE (id)", "UNIQUE (first_name,last_name)" ]: - result = DdlParser.ddlConstraint.parseString(text, parseAll=True) - assert result.isConstraint + result = DdlParser.ddl_constraint.parse_string(text, parse_all=True) + assert result.is_constraint @staticmethod - def testMathExpression(): + def _test_math_expression(): text = "2 DIV 2" - result = DdlParser.ddlExpression.parseString(text, parseAll=True) + result = DdlParser.ddl_expression.parse_string(text, parse_all=True) assert len(result) == 3 assert result[0] == "2" assert result[1] == "DIV" assert result[2] == "2" @staticmethod - def testRational(): + def _test_rational(): for text in [ "pos RATIONAL NOT NULL DEFAULT nextval('rational_seq')::integer", ]: - result = DdlParser.ddlColumn.parseString(text, parseAll=True) + result = DdlParser.ddl_column.parse_string(text, parse_all=True) column = result[0] assert column.name == "pos" assert column.type == "text" - assert column.notNull + assert column.not_null @staticmethod - def testTable(): + def _test_table(): text = """ CREATE TABLE "public"."dk" ( "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass), @@ -473,62 +478,62 @@ class SelfTest: PRIMARY KEY (id) ) """ - result = DdlParser.ddlCtWithSql.parseString(text, parseAll=True) + result = DdlParser.ddl_ct_with_sql.parse_string(text, parse_all=True) 

     @staticmethod
-    def testPrimaryKeyAutoIncrement():
+    def _test_primary_key_auto_increment():
         for text in [
             "CREATE TABLE tab (col INTEGER NOT NULL AUTO_INCREMENT PRIMARY KEY)", # mysql
             "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT)", # mysql
             "CREATE TABLE tab (col INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT)", # sqlite
         ]:
-            result = DdlParser.ddlCtWithSql.parseString(text, parseAll=True)
+            result = DdlParser.ddl_ct_with_sql.parse_string(text, parse_all=True)
             assert len(result) == 1
             table = result.create.value
-            assert table.tableName == "tab"
+            assert table.table_name == "tab"
             assert len(table.columns) == 1
             column = table.columns[0]
-            assert not column.isConstraint
+            assert not column.is_constraint
             assert column.name == "col"
             assert column.type == "integral"
-            assert column.notNull
-            assert column.hasAutoValue
-            assert column.isPrimaryKey
-            assert table.createSql == text
+            assert column.not_null
+            assert column.has_auto_value
+            assert column.is_primary_key
+            assert table.create_sql == text


 class ModelWriter:
     """This class uses the parsed DDL definitions to generate and write the C++ database model file(s)"""

     @classmethod
-    def write(cls, parsedDdls, args):
+    def write(cls, parsed_ddls, args):
         if args.path_to_header:
-            cls.createHeader(parsedDdls, args)
+            cls._create_header(parsed_ddls, args)
         if args.path_to_header_directory:
-            cls.createSplitHeaders(parsedDdls, args)
+            cls._create_split_headers(parsed_ddls, args)
         if args.path_to_module:
-            cls.createModule(parsedDdls, args)
+            cls._create_module(parsed_ddls, args)

     @classmethod
-    def createHeader(cls, parsedDdls, args):
-        header = cls.beginHeader(args.path_to_header, args)
-        for pd in parsedDdls:
+    def _create_header(cls, parsed_ddls, args):
+        header = cls._begin_header(args.path_to_header, args)
+        for pd in parsed_ddls:
             for table in pd.tables:
-                cls.writeTable(table, header, args)
-        cls.endHeader(header, args)
+                cls._write_table(table, header, args)
+        cls._end_header(header, args)

     @classmethod
-    def createSplitHeaders(cls, parsedDdls, args):
-        for pd in parsedDdls:
+    def _create_split_headers(cls, parsed_ddls, args):
+        for pd in parsed_ddls:
             for table in pd.tables:
-                sqlTableName = table.create.value.tableName
-                header = cls.beginHeader(os.path.join(args.path_to_header_directory, cls.toClassName(sqlTableName, args) + ".h"), args)
-                cls.writeTable(table, header, args)
-                cls.endHeader(header, args)
+                sql_table_name = table.create.value.table_name
+                header = cls._begin_header(os.path.join(args.path_to_header_directory, cls._to_class_name(sql_table_name, args) + ".h"), args)
+                cls._write_table(table, header, args)
+                cls._end_header(header, args)

     @staticmethod
-    def beginHeader(pathToHeader, args):
-        header = open(pathToHeader, "w")
+    def _begin_header(path_to_header, args):
+        header = open(path_to_header, "w")
         print("#pragma once", file=header)
         print("", file=header)
         print("// clang-format off", file=header)
@@ -553,21 +558,21 @@ class ModelWriter:
         return header

     @staticmethod
-    def endHeader(header, args):
+    def _end_header(header, args):
         print("} // namespace " + args.namespace, file=header)
         header.close()

     @classmethod
-    def createModule(cls, parsedDdls, args):
-        module = cls.beginModule(args.path_to_module, args)
-        for pd in parsedDdls:
+    def _create_module(cls, parsed_ddls, args):
+        module = cls._begin_module(args.path_to_module, args)
+        for pd in parsed_ddls:
             for table in pd.tables:
-                cls.writeTable(table, module, args)
-        cls.endModule(module, args)
+                cls._write_table(table, module, args)
+        cls._end_module(module, args)

     @staticmethod
-    def beginModule(pathToModule, args):
-        module = open(pathToModule, "w")
+    def _begin_module(path_to_module, args):
+        module = open(path_to_module, "w")
         print("module;", file=module)
         print("", file=module)
         print("// clang-format off", file=module)
@@ -588,130 +593,123 @@ class ModelWriter:
         return module

     @staticmethod
-    def endModule(module, args):
+    def _end_module(module, args):
         print("} // namespace " + args.namespace, file=module)
         module.close()

     @classmethod
-    def writeTable(cls, table, header, args):
+    def _write_table(cls, table, header, args):
         export = "export " if args.path_to_module else ""
-        DataTypeError = False
+        data_type_error = False
         create = table.create.value
-        sqlTableName = create.tableName
-        tableClass = cls.toClassName(sqlTableName, args)
-        tableMember = cls.toMemberName(sqlTableName, args)
-        tableSpec = tableClass + "_"
-        tableTemplateParameters = ""
-        tableRequiredInsertColumns = ""
+        sql_table_name = create.table_name
+        table_class = cls._to_class_name(sql_table_name, args)
+        table_member = cls._to_member_name(sql_table_name, args)
+        table_spec = table_class + "_"
+        table_template_parameters = ""
+        table_required_insert_columns = ""
         if args.generate_table_creation_helper:
-            creationHelperFunc = "create" + ("" if args.naming_style == "camel-case" else "_") + tableClass
+            creation_helper_func = "create" + ("" if args.naming_style == "camel-case" else "_") + table_class
             print("  " + export + "template<typename Db>", file=header)
-            print("  void " + creationHelperFunc + "(Db& db) {", file=header)
-            print("    db(R\"+++(DROP TABLE IF EXISTS " + sqlTableName + ")+++\");", file=header)
-            print("    db(R\"+++(" + create.createSql + ")+++\");", file=header)
+            print("  void " + creation_helper_func + "(Db& db) {", file=header)
+            print("    db(R\"+++(DROP TABLE IF EXISTS " + sql_table_name + ")+++\");", file=header)
+            print("    db(R\"+++(" + create.create_sql + ")+++\");", file=header)
             print("  }", file=header)
             print("", file=header)
-        print("  " + export + "struct " + tableSpec + " {", file=header)
+        print("  " + export + "struct " + table_spec + " {", file=header)
         for column in create.columns:
-            if column.isConstraint:
+            if column.is_constraint:
                 continue
-            sqlColumnName = column.name
-            columnClass = cls.toClassName(sqlColumnName, args)
-            columnMember = cls.toMemberName(sqlColumnName, args)
-            columnType = column.type
-            if columnType == "UNKNOWN":
-                print(
-                    "Error: datatype of %s.%s is not supported."
-                    % (sqlTableName, sqlColumnName)
-                )
-                DataTypeError = True
-            if columnType == "integral" and column.isUnsigned:
-                columnType = "unsigned_" + columnType
-            if columnType == "timestamp" and not args.suppress_timestamp_warning:
+            sql_column_name = column.name
+            column_class = cls._to_class_name(sql_column_name, args)
+            column_member = cls._to_member_name(sql_column_name, args)
+            column_type = column.type
+            if column_type == "UNKNOWN":
+                print(f"ERROR: SQL data type of {sql_table_name}.{sql_column_name} is not supported.")
+                data_type_error = True
+            if column_type == "integral" and column.is_unsigned:
+                column_type = "unsigned_" + column_type
+            if column_type == "timestamp" and not args.suppress_timestamp_warning:
                 args.suppress_timestamp_warning = True
-                print(
-                    "Warning: date and time values are assumed to be without timezone."
-                )
-                print(
-                    "Warning: If you are using types WITH timezones, your code has to deal with that."
-                )
+                print("WARNING: date and time values are assumed to be without timezone.")
+                print("WARNING: If you are using types WITH timezones, your code has to deal with that.")
                 print("You can disable this warning using --suppress-timestamp-warning")
-            print("    struct " + columnClass + " {", file=header)
+            print("    struct " + column_class + " {", file=header)
             print("      SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP("
-                + cls.escape_if_reserved(sqlColumnName) + ", " + columnMember + ");"
+                + cls._escape_if_reserved(sql_column_name) + ", " + column_member + ");"
                 , file=header)
-            columnIsConst = column.hasGeneratedValue
-            constPrefix = "const " if columnIsConst else ""
-            columnCanBeNull = not column.notNull and not column.isPrimaryKey and not column.hasSerialValue
-            if columnCanBeNull:
-                print("      using data_type = " + constPrefix + "std::optional<::sqlpp::" + columnType + ">;", file=header)
+            column_is_const = column.has_generated_value
+            const_prefix = "const " if column_is_const else ""
+            column_can_be_null = not column.not_null and not column.is_primary_key and not column.has_serial_value
+            if column_can_be_null:
+                print("      using data_type = " + const_prefix + "std::optional<::sqlpp::" + column_type + ">;", file=header)
             else:
-                print("      using data_type = " + constPrefix + "::sqlpp::" + columnType + ";", file=header)
-            columnHasDefault = column.hasDefaultValue or \
-                column.hasSerialValue or \
-                column.hasAutoValue or \
-                column.hasGeneratedValue or \
-                (args.assume_auto_id and sqlColumnName == "id") or \
-                columnCanBeNull
-            if columnHasDefault:
+                print("      using data_type = " + const_prefix + "::sqlpp::" + column_type + ";", file=header)
+            column_has_default = column.has_default_value or \
+                column.has_serial_value or \
+                column.has_auto_value or \
+                column.has_generated_value or \
+                (args.assume_auto_id and sql_column_name == "id") or \
+                column_can_be_null
+            if column_has_default:
                 print("      using has_default = std::true_type;", file=header)
             else:
                 print("      using has_default = std::false_type;", file=header)
             print("    };", file=header)
-            if tableTemplateParameters:
-                tableTemplateParameters += ","
-            tableTemplateParameters += "\n      " + columnClass
-            if not columnHasDefault:
-                if tableRequiredInsertColumns:
-                    tableRequiredInsertColumns += ","
-                tableRequiredInsertColumns += "\n      sqlpp::column_t<sqlpp::table_t<" + tableSpec + ">, " + columnClass + ">";
+            if table_template_parameters:
+                table_template_parameters += ","
+            table_template_parameters += "\n      " + column_class
+            if not column_has_default:
+                if table_required_insert_columns:
+                    table_required_insert_columns += ","
+                table_required_insert_columns += "\n      sqlpp::column_t<sqlpp::table_t<" + table_spec + ">, " + column_class + ">";
         print("    SQLPP_CREATE_NAME_TAG_FOR_SQL_AND_CPP("
-            + cls.escape_if_reserved(sqlTableName) + ", " + tableMember + ");"
+            + cls._escape_if_reserved(sql_table_name) + ", " + table_member + ");"
             , file=header)
         print("    template<typename T>", file=header)
-        print("    using _table_columns = sqlpp::table_columns<T," + tableTemplateParameters + ">;", file=header)
+        print("    using _table_columns = sqlpp::table_columns<T," + table_template_parameters + ">;", file=header)
         print("    using _required_insert_columns = sqlpp::detail::type_set<"
-              + tableRequiredInsertColumns
+              + table_required_insert_columns
               + ">;", file=header)
         print("  };", file=header)
         print(
-            "  " + export + "using " + tableClass + " = ::sqlpp::table_t<" + tableSpec + ">;", file=header)
+            "  " + export + "using " + table_class + " = ::sqlpp::table_t<" + table_spec + ">;", file=header)
         print("", file=header)

-        if DataTypeError:
-            print("Error: unsupported SQL data type(s).")
+        if data_type_error:
+            print("ERROR: Unsupported SQL data type(s).")
             print("Possible solutions:")
-            print("A) Use the '--path-to-datatype-file' command line argument to map the SQL data type to a known sqlpp23 data type (example: README)")
+            print("A) Use the '--path-to-custom-types' command line argument to map the SQL data type to a known sqlpp23 data type (example: README)")
             print("B) Implement this data type in sqlpp23 (examples: sqlpp23/data_types) and in sqlpp23-ddl2cpp")
             print("C) Raise an issue on github")
             sys.exit(ExitCode.BAD_DATA_TYPE) # return non-zero error code, we might need it for automation

     @classmethod
-    def toClassName(cls, name, args):
+    def _to_class_name(cls, name, args):
         if args.naming_style == "camel-case":
             name = name.replace(".", "_")
-            return re.sub(r"(^|\s|[_0-9])(\S)", cls.repl_camel_case_func, name)
+            return re.sub(r"(^|\s|[_0-9])(\S)", cls._repl_camel_case_func, name)
         # otherwise return identity
         return name

     @classmethod
-    def toMemberName(cls, name, args):
+    def _to_member_name(cls, name, args):
         if args.naming_style == "camel-case":
             name = name.replace(".", "_")
-            return re.sub(r"(\s|_|[0-9])(\S)", cls.repl_camel_case_func, name)
+            return re.sub(r"(\s|_|[0-9])(\S)", cls._repl_camel_case_func, name)
         # otherwise return identity
         return name

     @staticmethod
-    def repl_camel_case_func(m):
+    def _repl_camel_case_func(m):
         if m.group(1) == "_":
             return m.group(2).upper()
         else:
             return m.group(1) + m.group(2).upper()

     @staticmethod
-    def escape_if_reserved(name):
+    def _escape_if_reserved(name):
         reserved_names = [
             "BEGIN",
             "END",
@@ -723,19 +721,19 @@ class ModelWriter:
         return name


-def parseCommandlineArgs():
-    argParser = argparse.ArgumentParser(prog="sqlpp23-ddl2cpp")
-    required = argParser.add_argument_group("Required parameters for code generation")
+def parse_commandline_args():
+    arg_parser = argparse.ArgumentParser(prog="sqlpp23-ddl2cpp")
+    required = arg_parser.add_argument_group("Required parameters for code generation")
     required.add_argument("--path-to-ddl", nargs="*", help="one or more path(s) to DDL input file(s)")
     required.add_argument("--namespace", help="namespace for generated table classes")

-    paths = argParser.add_argument_group("Paths", "Choose one or more paths for code generation:")
+    paths = arg_parser.add_argument_group("Paths", "Choose one or more paths for code generation:")
     paths.add_argument("--path-to-module", help="path to generated module file (also requires --module-name)")
     paths.add_argument("--path-to-header", help="path to generated header file (one file for all tables)")
     paths.add_argument("--path-to-header-directory", help="path to directory for generated header files (one file per table)")
-    paths.add_argument("--path-to-datatype-file", help="path to csv file containing additional sql2cpp file type mappings")
+    paths.add_argument("--path-to-custom-types", help="path to csv file defining aliases of existing SQL data types")

-    options = argParser.add_argument_group("Additional options")
+    options = arg_parser.add_argument_group("Additional options")
     options.add_argument("--module-name", help="name of the generated module (to be used with --path-to-module)")
     options.add_argument("--suppress-timestamp-warning", action="store_true", help="suppress show warning about date / time data types")
     options.add_argument("--assume-auto-id", action="store_true", help="assume column 'id' to have an automatic value as if AUTO_INCREMENT was specified (e.g. implicit for SQLite ROWID (default: False)")
@@ -745,56 +743,70 @@ def parseCommandlineArgs():
     options.add_argument("--use-import-std", action="store_true", help="import std as module instead of including the respective standard header files (default: False)")
     options.add_argument("--self-test", action="store_true", help="run parser self-test (this ignores all other arguments)")

-    args = argParser.parse_args()
+    args = arg_parser.parse_args()
     if args.self_test:
         return args

     if not args.path_to_ddl or not len(args.path_to_ddl):
         print("Missing argument --path-to-ddl")
-        argParser.print_help()
+        arg_parser.print_help()
         sys.exit(ExitCode.BAD_ARGS)

     if not args.namespace:
         print("Missing argument --namespace")
-        argParser.print_help()
+        arg_parser.print_help()
        sys.exit(ExitCode.BAD_ARGS)

     if not args.path_to_module and not args.path_to_header and not args.path_to_header_directory:
         print("Missing argument(s): at least one path for code generation")
-        argParser.print_help()
+        arg_parser.print_help()
         sys.exit(ExitCode.BAD_ARGS)

     if args.path_to_module and not args.module_name:
         print("Missing argument --module-name")
-        argParser.print_help()
+        arg_parser.print_help()
         sys.exit(ExitCode.BAD_ARGS)

     return args


-def get_extended_types(filename):
+def get_custom_types(filename):
     if not filename:
         return None

     import csv
-    with open(filename, newline="") as csvfile:
-        reader = csv.DictReader(csvfile, fieldnames=["baseType"], restkey="customTypes", delimiter=",")
+    with open(filename, newline="") as csv_file:
+        reader = csv.DictReader(csv_file, fieldnames=["base_type"], restkey="custom_types", delimiter=",")
         types = {}
+        def strip_garbage(name):
+            return name.strip(" \"'")
+        def clean_custom_type(name):
+            return strip_garbage(name).lower()
+        def clean_base_type(name):
+            name_ident = strip_garbage(name)
+            if DdlParser.is_base_type(name_ident):
+                return name_ident
+            name_from_camel = re.sub(r"[A-Z]", lambda m : ("_" if m.start() else "") + m[0].lower(), name_ident)
+            if DdlParser.is_base_type(name_from_camel):
+                return name_from_camel
+            print(f"ERROR: Custom types file uses an unknown base type {name_ident}")
+            sys.exit(ExitCode.BAD_CUSTOM_TYPES)
         for row in reader:
-            var_values = [clean_val for value in row["customTypes"] if (clean_val := value.strip(" \"'").lower())]
-            if var_values:
-                types[row["baseType"]] = var_values
+            values = [cleaned for value in row["custom_types"] if (cleaned := clean_custom_type(value)) != ""]
+            if values:
+                key = clean_base_type(row["base_type"])
+                types[key] = values
     return types


 if __name__ == "__main__":
-    args = parseCommandlineArgs()
+    args = parse_commandline_args()
     if args.self_test:
         SelfTest.run()
     else:
-        customTypes = get_extended_types(args.path_to_datatype_file)
-        DdlParser.initialize(customTypes)
-        parsedDdls = DdlParser.parse_ddls(args.path_to_ddl)
-        ModelWriter.write(parsedDdls, args)
+        custom_types = get_custom_types(args.path_to_custom_types)
+        DdlParser.initialize(custom_types)
+        parsed_ddls = DdlParser.parse_ddls(args.path_to_ddl)
+        ModelWriter.write(parsed_ddls, args)
     sys.exit(ExitCode.SUCCESS)
diff --git a/tests/scripts/CMakeLists.txt b/tests/scripts/CMakeLists.txt
index 412bd5f3..ccf0ba5f 100644
--- a/tests/scripts/CMakeLists.txt
+++ b/tests/scripts/CMakeLists.txt
@@ -55,7 +55,7 @@ if (${Python3_Interpreter_FOUND})
         "--path-to-header" "${CMAKE_CURRENT_BINARY_DIR}/fail.h"
         "--namespace" "test")
     set_tests_properties(sqlpp23.scripts.ddl2cpp.bad_has_parse_error PROPERTIES
-        PASS_REGULAR_EXPRESSION "ERROR: failed to parse.*")
+        PASS_REGULAR_EXPRESSION "ERROR: Failed to parse.*")

     add_test(NAME sqlpp23.scripts.ddl2cpp.good_succeeds
         COMMAND "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_LIST_DIR}/../../scripts/sqlpp23-ddl2cpp"
@@ -100,7 +100,7 @@ if (${Python3_Interpreter_FOUND})
             "--namespace" "test"
             "--suppress-timestamp-warning")
         set_tests_properties("${bad_type_test_name}" PROPERTIES
-            PASS_REGULAR_EXPRESSION "Error: unsupported SQL data type\\(s\\).")
+            PASS_REGULAR_EXPRESSION "ERROR: Unsupported SQL data type\\(s\\).")
     endforeach()

     # Custom types defined in a CSV file
@@ -109,7 +109,7 @@ if (${Python3_Interpreter_FOUND})
     add_custom_command(
         OUTPUT "${sqlpp.scripts.generated.custom_type_sql.include}"
         COMMAND "${Python3_EXECUTABLE}" "${CMAKE_CURRENT_LIST_DIR}/../../scripts/sqlpp23-ddl2cpp"
-            "--path-to-datatype-file=${CMAKE_CURRENT_LIST_DIR}/custom_types.csv"
+            "--path-to-custom-types=${CMAKE_CURRENT_LIST_DIR}/custom_types.csv"
            "--path-to-ddl" "${CMAKE_CURRENT_LIST_DIR}/${custom_type_sql}.sql"
            "--path-to-header" "${sqlpp.scripts.generated.custom_type_sql.include}"
            "--namespace" "test"