diff --git a/process_sql.py b/process_sql.py index 839612e..af098da 100644 --- a/process_sql.py +++ b/process_sql.py @@ -26,6 +26,7 @@ import json import sqlite3 +from typing import List from nltk import word_tokenize CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') @@ -112,11 +113,24 @@ def get_schema_from_json(fpath): return schema +def find_quoted_strings(text: str) -> List[int]: + quote_idxs = [] + idx = 0 + escaped = False + while idx < len(text): + if escaped: + idx, escaped = idx + 1, False + c = text[idx] + if c == "\\": + escaped = True + elif c == "\'" or c == "\"": + quote_idxs.append(idx) + idx += 1 + return quote_idxs + def tokenize(string): - string = str(string) - string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? - quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] + quote_idxs = find_quoted_strings(string) assert len(quote_idxs) % 2 == 0, "Unexpected quote" # keep string value as token