diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py
index fb1ae4f91..878a3bb4e 100644
--- a/sqlite_utils/db.py
+++ b/sqlite_utils/db.py
@@ -144,6 +144,10 @@ class InvalidColumns(Exception):
     pass
 
 
+class ExpandError(Exception):
+    "Raised when an expand function returns something other than a list or dict"
+
+
 _COUNTS_TABLE_CREATE_SQL = """
 CREATE TABLE IF NOT EXISTS [{}](
    [table] TEXT PRIMARY KEY,
@@ -1156,6 +1160,8 @@ def extract_expand(
     ):
         "Use expand function to transform values in column and extract them into a new table"
         table = table or column
+        # Track whether we are creating a many-to-many or many-to-one relation
+        m2m, m21 = False, False
         fk_column = fk_column or "{}_id".format(table)
         self.add_column(fk_column, fk_column_type)
         for row_pk, row in self.pks_and_rows_where():
@@ -1164,10 +1170,42 @@ def extract_expand(
             if isinstance(expanded, dict):
                 new_pk = self.db[table].insert(expanded, pk="id", replace=True).last_pk
                 self.update(row_pk, {fk_column: new_pk})
-        # Can drop the original column now
-        self.transform(drop=[column])
-        # And add that foreign key
-        self.add_foreign_key(fk_column, table, "id")
+            elif isinstance(expanded, list):
+                if not expanded:
+                    # Nothing to insert, but fk_column is still unused for lists
+                    m21 = True
+                    continue
+                elif isinstance(expanded[0], dict):
+                    m2m = True
+                    self.m2m(table, expanded, pk="id", our_id=row_pk, alter=True)
+                else:
+                    m21 = True
+                    pk_column = "{}_id".format(self.name)
+                    new_rows = [
+                        {
+                            "id": index,
+                            pk_column: row_pk,
+                            "value": value,
+                        }
+                        for index, value in enumerate(expanded, start=1)
+                    ]
+                    self.db[table].insert_all(
+                        new_rows,
+                        pk=("id", pk_column),
+                        foreign_keys=[(pk_column, self.name)],
+                        replace=True,
+                    )
+            else:
+                raise ExpandError("expanded value needs to be list or dict")
+
+        if m21 or m2m:
+            # fk_column is only populated for dict values, so drop it as well
+            self.transform(drop=[column, fk_column])
+        else:
+            # Can drop the original column now
+            self.transform(drop=[column])
+            # And add that foreign key
+            self.add_foreign_key(fk_column, table, "id")
         return self
 
     def create_index(self, columns, index_name=None, unique=False, if_not_exists=False):
@@ -2081,10 +2119,13 @@ def m2m(
         lookup=None,
         m2m_table=None,
         alter=False,
+        our_id=None,
     ):
         if isinstance(other_table, str):
             other_table = self.db.table(other_table, pk=pk)
-        our_id = self.last_pk
+        # Fall back to last_pk only when no id was passed; 0 is a valid pk
+        if our_id is None:
+            our_id = self.last_pk
         if lookup is not None:
             assert record_or_iterable is None, "Provide lookup= or record, not both"
         else:
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 2507c758b..9396be4ff 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -193,3 +193,53 @@ def test_extract_expand(fresh_db):
             table="trees", column="species_id", other_table="species", other_column="id"
         )
     ]
+
+
+def test_extract_expand_m21(fresh_db):
+    fresh_db["trees"].insert(
+        {"id": 1, "names": '["Palm", "Arecaceae"]'},
+        pk="id",
+    )
+    assert fresh_db.table_names() == ["trees"]
+    fresh_db["trees"].extract_expand("names", expand=json.loads, table="names", pk="id")
+    assert set(fresh_db.table_names()) == {"trees", "names"}
+    assert list(fresh_db["trees"].rows) == [
+        {"id": 1},
+    ]
+    assert list(fresh_db["names"].rows) == [
+        {"id": 1, "trees_id": 1, "value": "Palm"},
+        {"id": 2, "trees_id": 1, "value": "Arecaceae"},
+    ]
+    assert fresh_db["names"].foreign_keys == [
+        ForeignKey(
+            table="names", column="trees_id", other_table="trees", other_column="id"
+        )
+    ]
+
+
+def test_extract_expand_m2m(fresh_db):
+    tags = [{"id": 1, "name": "warm-climate"}, {"id": 2, "name": "green-leaves"}]
+    fresh_db["trees"].insert({"id": 1, "tags": json.dumps(tags)}, pk="id")
+    assert fresh_db.table_names() == ["trees"]
+    fresh_db["trees"].extract_expand("tags", expand=json.loads, table="tags", pk="id")
+    assert set(fresh_db.table_names()) == {"trees", "tags", "tags_trees"}
+    assert list(fresh_db["trees"].rows) == [{"id": 1}]
+    assert list(fresh_db["tags"].rows) == [
+        {"id": 1, "name": "warm-climate"},
+        {"id": 2, "name": "green-leaves"},
+    ]
+    assert list(fresh_db["tags_trees"].rows) == [
+        {"trees_id": 1, "tags_id": 1},
+        {"trees_id": 1, "tags_id": 2},
+    ]
+    assert fresh_db["tags_trees"].foreign_keys == [
+        ForeignKey(
+            table="tags_trees",
+            column="trees_id",
+            other_table="trees",
+            other_column="id",
+        ),
+        ForeignKey(
+            table="tags_trees", column="tags_id", other_table="tags", other_column="id"
+        ),
+    ]