Skip to content

Commit b08843b

Browse files
committed
Refactor: change row_factory from tuple to dict
Make arguments in class initialisation not order-sensitive Adjust existing calls to match the new format Assign some default values based on existing calls Fix inconsistent indentation
1 parent 3950b59 commit b08843b

File tree

11 files changed

+79
-82
lines changed

11 files changed

+79
-82
lines changed

src/data/database.py

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def living_in(the_database):
2121
return None
2222
return DatabaseDatabase(db)
2323

24+
def dict_factory(cursor, row) -> dict:
25+
fields = [column[0] for column in cursor.description]
26+
return dict(zip(fields, row))
27+
2428
# Database
2529

2630
def db_error(f):
@@ -50,9 +54,10 @@ def protected(*args, **kwargs):
5054
return decorate
5155

5256
class DatabaseDatabase:
53-
def __init__(self, db):
57+
def __init__(self, db: sqlite3.Connection):
5458
self._db = db
55-
self.q = db.cursor()
59+
self._db.row_factory = dict_factory
60+
self.q = self._db.cursor()
5661

5762
# Set up collations
5863
self._db.create_collation("alphanum", _collate_alphanum)
@@ -63,7 +68,7 @@ def __getattr__(self, attr):
6368
return getattr(self._db, attr)
6469

6570
def get_count(self):
66-
return self.q.fetchone()[0]
71+
return self.q.fetchone()['count(*)']
6772

6873
def save(self):
6974
self.commit()
@@ -217,19 +222,19 @@ def get_service(self, id=None, key=None) -> Optional[Service]:
217222
error("ID or key required to get service")
218223
return None
219224
service = self.q.fetchone()
220-
return Service(*service)
225+
return Service(**service)
221226

222227
@db_error_default(list())
223228
def get_services(self, enabled=True, disabled=False) -> List[Service]:
224229
services = list()
225230
if enabled:
226231
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 1")
227232
for service in self.q.fetchall():
228-
services.append(Service(*service))
233+
services.append(Service(**service))
229234
if disabled:
230235
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 0")
231236
for service in self.q.fetchall():
232-
services.append(Service(*service))
237+
services.append(Service(**service))
233238
return services
234239

235240
@db_error_default(None)
@@ -242,7 +247,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
242247
if stream is None:
243248
error("Stream {} not found".format(id))
244249
return None
245-
stream = Stream(*stream)
250+
stream = Stream(**stream)
246251
elif service_tuple is not None:
247252
service, show_key = service_tuple
248253
debug("Getting stream for {}/{}".format(service, show_key))
@@ -252,7 +257,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
252257
if stream is None:
253258
error("Stream {} not found".format(id))
254259
return None
255-
stream = Stream(*stream)
260+
stream = Stream(**stream)
256261
else:
257262
error("Nothing provided to get stream")
258263
return None
@@ -299,7 +304,7 @@ def get_streams(self, service=None, show=None, active=True, unmatched=False, mis
299304
return list()
300305

301306
streams = self.q.fetchall()
302-
streams = [Stream(*stream) for stream in streams]
307+
streams = [Stream(**stream) for stream in streams]
303308
for stream in streams:
304309
stream.show = self.get_show(id=stream.show) # convert show id to show model
305310
return streams
@@ -359,7 +364,7 @@ def get_lite_streams(self, service=None, show=None, missing_link=False) -> List[
359364
return list()
360365

361366
lite_streams = self.q.fetchall()
362-
lite_streams = [LiteStream(*lite_stream) for lite_stream in lite_streams]
367+
lite_streams = [LiteStream(**lite_stream) for lite_stream in lite_streams]
363368
return lite_streams
364369

365370
@db_error
@@ -381,19 +386,19 @@ def get_link_site(self, id:str=None, key:str=None) -> Optional[LinkSite]:
381386
site = self.q.fetchone()
382387
if site is None:
383388
return None
384-
return LinkSite(*site)
389+
return LinkSite(**site)
385390

386391
@db_error_default(list())
387392
def get_link_sites(self, enabled=True, disabled=False) -> List[LinkSite]:
388393
sites = list()
389394
if enabled:
390395
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 1")
391396
for link in self.q.fetchall():
392-
sites.append(LinkSite(*link))
397+
sites.append(LinkSite(**link))
393398
if disabled:
394399
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 0")
395400
for link in self.q.fetchall():
396-
sites.append(LinkSite(*link))
401+
sites.append(LinkSite(**link))
397402
return sites
398403

399404
@db_error_default(list())
@@ -404,7 +409,7 @@ def get_links(self, show:Show=None) -> List[Link]:
404409
# Get all streams with show ID
405410
self.q.execute("SELECT site, show, site_key FROM Links WHERE show = ?", (show.id,))
406411
links = self.q.fetchall()
407-
links = [Link(*link) for link in links]
412+
links = [Link(**link) for link in links]
408413
return links
409414
else:
410415
error("A show must be provided to get links")
@@ -418,7 +423,7 @@ def get_link(self, show: Show, link_site: LinkSite) -> Optional[Link]:
418423
link = self.q.fetchone()
419424
if link is None:
420425
return None
421-
link = Link(*link)
426+
link = Link(**link)
422427
return link
423428

424429
@db_error_default(False)
@@ -449,15 +454,15 @@ def add_link(self, raw_show: UnprocessedShow, show_id, commit=True):
449454

450455
# Shows
451456
@db_error_default(list())
452-
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> [Show]:
457+
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> list[Show]:
453458
shows = list()
454459
if missing_length:
455460
self.q.execute(
456-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
461+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
457462
WHERE (length IS NULL OR length = '' OR length = 0) AND enabled = ?", (enabled,))
458463
elif missing_stream:
459464
self.q.execute(
460-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows show\
465+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows show\
461466
WHERE (SELECT count(*) FROM Streams stream, Services service \
462467
WHERE stream.show = show.id \
463468
AND stream.active = 1 \
@@ -467,14 +472,14 @@ def get_shows(self, missing_length=False, missing_stream=False, enabled=True, de
467472
(enabled,))
468473
elif delayed:
469474
self.q.execute(
470-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
475+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
471476
WHERE delayed = 1 AND enabled = ?", (enabled,))
472477
else:
473478
self.q.execute(
474-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
479+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
475480
WHERE enabled = ?", (enabled,))
476481
for show in self.q.fetchall():
477-
show = Show(*show)
482+
show = Show(**show)
478483
show.aliases = self.get_aliases(show)
479484
shows.append(show)
480485
return shows
@@ -492,12 +497,12 @@ def get_show(self, id=None, stream=None) -> Optional[Show]:
492497
error("Show ID not provided to get_show")
493498
return None
494499
self.q.execute(
495-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
500+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
496501
WHERE id = ?", (id,))
497502
show = self.q.fetchone()
498503
if show is None:
499504
return None
500-
show = Show(*show)
505+
show = Show(**show)
501506
show.aliases = self.get_aliases(show)
502507
return show
503508

@@ -506,19 +511,19 @@ def get_show_by_name(self, name) -> Optional[Show]:
506511
#debug("Getting show from database")
507512

508513
self.q.execute(
509-
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
514+
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
510515
WHERE name = ?", (name,))
511516
show = self.q.fetchone()
512517
if show is None:
513518
return None
514-
show = Show(*show)
519+
show = Show(**show)
515520
show.aliases = self.get_aliases(show)
516521
return show
517522

518523
@db_error_default(list())
519-
def get_aliases(self, show: Show) -> [str]:
524+
def get_aliases(self, show: Show) -> list[str]:
520525
self.q.execute("SELECT alias FROM Aliases where show = ?", (show.id,))
521-
return [s for s, in self.q.fetchall()]
526+
return [s["alias"] for s in self.q.fetchall()]
522527

523528
@db_error_default(None)
524529
def add_show(self, raw_show: UnprocessedShow, commit=True) -> int:
@@ -556,7 +561,7 @@ def update_show(self, show_id: str, raw_show: UnprocessedShow, commit=True):
556561
is_nsfw = raw_show.is_nsfw
557562

558563
if name_en:
559-
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
564+
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
560565
if length != 0:
561566
self.q.execute("UPDATE Shows SET length = ? WHERE id = ?", (length, show_id))
562567
self.q.execute("UPDATE Shows SET type = ?, has_source = ?, is_nsfw = ? WHERE id = ?", (show_type, has_source, is_nsfw, show_id))
@@ -599,10 +604,10 @@ def stream_has_episode(self, stream: Stream, episode_num) -> bool:
599604

600605
@db_error_default(None)
601606
def get_latest_episode(self, show: Show) -> Optional[Episode]:
602-
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
607+
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
603608
data = self.q.fetchone()
604609
if data is not None:
605-
return Episode(data[0], None, data[1], None)
610+
return Episode(**data)
606611
return None
607612

608613
@db_error
@@ -614,9 +619,9 @@ def add_episode(self, show, episode_num, post_url):
614619
@db_error_default(list())
615620
def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
616621
episodes = list()
617-
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ?", (show.id,))
622+
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ?", (show.id,))
618623
for data in self.q.fetchall():
619-
episodes.append(Episode(data[0], None, data[1], None))
624+
episodes.append(Episode(**data))
620625

621626
if ensure_sorted:
622627
episodes = sorted(episodes, key=lambda e: e.number)
@@ -625,23 +630,23 @@ def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
625630
# Scores
626631
@db_error_default(list())
627632
def get_show_scores(self, show: Show) -> List[EpisodeScore]:
628-
self.q.execute("SELECT episode, site, score FROM Scores WHERE show=?", (show.id,))
629-
return [EpisodeScore(show.id, *s) for s in self.q.fetchall()]
633+
self.q.execute("SELECT episode, site AS site_id, score FROM Scores WHERE show=?", (show.id,))
634+
return [EpisodeScore(show_id=show.id, **s) for s in self.q.fetchall()]
630635

631636
@db_error_default(list())
632637
def get_episode_scores(self, show: Show, episode: Episode) -> List[EpisodeScore]:
633-
self.q.execute("SELECT site, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
634-
return [EpisodeScore(show.id, episode.number, *s) for s in self.q.fetchall()]
638+
self.q.execute("SELECT site AS site_id, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
639+
return [EpisodeScore(show_id=show.id, episode=episode.number, **s) for s in self.q.fetchall()]
635640

636641
@db_error_default(None)
637642
def get_episode_score_avg(self, show: Show, episode: Episode) -> Optional[EpisodeScore]:
638643
debug("Calculating avg score for {} ({})".format(show.name, show.id))
639644
self.q.execute("SELECT score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
640-
scores = [s[0] for s in self.q.fetchall()]
645+
scores = [s["score"] for s in self.q.fetchall()]
641646
if len(scores) > 0:
642647
score = sum(scores)/len(scores)
643648
debug(" Score: {} (from {} scores)".format(score, len(scores)))
644-
return EpisodeScore(show.id, episode.number, None, score)
649+
return EpisodeScore(show_id=show.id, episode=episode.number, score=score)
645650
return None
646651

647652
@db_error
@@ -664,7 +669,7 @@ def get_poll_site(self, id:str=None, key:str=None) -> Optional[PollSite]:
664669
site = self.q.fetchone()
665670
if site is None:
666671
return None
667-
return PollSite(*site)
672+
return PollSite(**site)
668673

669674
@db_error
670675
def add_poll(self, show: Show, episode: Episode, site: PollSite, poll_id, commit=True):
@@ -681,24 +686,24 @@ def update_poll_score(self, poll: Poll, score, commit=True):
681686

682687
@db_error_default(None)
683688
def get_poll(self, show: Show, episode: Episode):
684-
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
689+
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
685690
poll = self.q.fetchone()
686691
if poll is None:
687692
return None
688-
return Poll(*poll)
693+
return Poll(**poll)
689694

690695
@db_error_default(list())
691696
def get_polls(self, show: Show=None, missing_score=False):
692697
polls = list()
693698
if show is not None:
694-
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ?", (show.id,))
699+
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ?", (show.id,))
695700
elif missing_score:
696-
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
701+
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
697702
else:
698703
error("Need to select a show to get polls")
699704
return list()
700705
for poll in self.q.fetchall():
701-
polls.append(Poll(*poll))
706+
polls.append(Poll(**poll))
702707
return polls
703708

704709
# Searching
@@ -713,8 +718,8 @@ def search_show_ids_by_names(self, *names, exact=False) -> Set[Show]:
713718
self.q.execute("SELECT show, name FROM ShowNames WHERE name = ? COLLATE alphanum", (name,))
714719
matched = self.q.fetchall()
715720
for match in matched:
716-
debug(" Found match: {} | {}".format(match[0], match[1]))
717-
shows.add(match[0])
721+
debug(" Found match: {} | {}".format(match['show'], match['name']))
722+
shows.add(match['show'])
718723
return shows
719724

720725
# Helper methods

0 commit comments

Comments
 (0)