@@ -21,6 +21,10 @@ def living_in(the_database):
2121 return None
2222 return DatabaseDatabase (db )
2323
24+ def dict_factory (cursor , row ) -> dict :
25+ fields = [column [0 ] for column in cursor .description ]
26+ return dict (zip (fields , row ))
27+
2428# Database
2529
2630def db_error (f ):
@@ -50,9 +54,10 @@ def protected(*args, **kwargs):
5054 return decorate
5155
5256class DatabaseDatabase :
53- def __init__ (self , db ):
57+ def __init__ (self , db : sqlite3 . Connection ):
5458 self ._db = db
55- self .q = db .cursor ()
59+ self ._db .row_factory = dict_factory
60+ self .q = self ._db .cursor ()
5661
5762 # Set up collations
5863 self ._db .create_collation ("alphanum" , _collate_alphanum )
@@ -63,7 +68,7 @@ def __getattr__(self, attr):
6368 return getattr (self ._db , attr )
6469
6570 def get_count (self ):
66- return self .q .fetchone ()[0 ]
71+ return self .q .fetchone ()['count(*)' ]
6772
6873 def save (self ):
6974 self .commit ()
@@ -217,19 +222,19 @@ def get_service(self, id=None, key=None) -> Optional[Service]:
217222 error ("ID or key required to get service" )
218223 return None
219224 service = self .q .fetchone ()
220- return Service (* service )
225+ return Service (** service )
221226
222227 @db_error_default (list ())
223228 def get_services (self , enabled = True , disabled = False ) -> List [Service ]:
224229 services = list ()
225230 if enabled :
226231 self .q .execute ("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 1" )
227232 for service in self .q .fetchall ():
228- services .append (Service (* service ))
233+ services .append (Service (** service ))
229234 if disabled :
230235 self .q .execute ("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 0" )
231236 for service in self .q .fetchall ():
232- services .append (Service (* service ))
237+ services .append (Service (** service ))
233238 return services
234239
235240 @db_error_default (None )
@@ -242,7 +247,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
242247 if stream is None :
243248 error ("Stream {} not found" .format (id ))
244249 return None
245- stream = Stream (* stream )
250+ stream = Stream (** stream )
246251 elif service_tuple is not None :
247252 service , show_key = service_tuple
248253 debug ("Getting stream for {}/{}" .format (service , show_key ))
@@ -252,7 +257,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
252257 if stream is None :
253258 error ("Stream {} not found" .format (id ))
254259 return None
255- stream = Stream (* stream )
260+ stream = Stream (** stream )
256261 else :
257262 error ("Nothing provided to get stream" )
258263 return None
@@ -299,7 +304,7 @@ def get_streams(self, service=None, show=None, active=True, unmatched=False, mis
299304 return list ()
300305
301306 streams = self .q .fetchall ()
302- streams = [Stream (* stream ) for stream in streams ]
307+ streams = [Stream (** stream ) for stream in streams ]
303308 for stream in streams :
304309 stream .show = self .get_show (id = stream .show ) # convert show id to show model
305310 return streams
@@ -359,7 +364,7 @@ def get_lite_streams(self, service=None, show=None, missing_link=False) -> List[
359364 return list ()
360365
361366 lite_streams = self .q .fetchall ()
362- lite_streams = [LiteStream (* lite_stream ) for lite_stream in lite_streams ]
367+ lite_streams = [LiteStream (** lite_stream ) for lite_stream in lite_streams ]
363368 return lite_streams
364369
365370 @db_error
@@ -381,19 +386,19 @@ def get_link_site(self, id:str=None, key:str=None) -> Optional[LinkSite]:
381386 site = self .q .fetchone ()
382387 if site is None :
383388 return None
384- return LinkSite (* site )
389+ return LinkSite (** site )
385390
386391 @db_error_default (list ())
387392 def get_link_sites (self , enabled = True , disabled = False ) -> List [LinkSite ]:
388393 sites = list ()
389394 if enabled :
390395 self .q .execute ("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 1" )
391396 for link in self .q .fetchall ():
392- sites .append (LinkSite (* link ))
397+ sites .append (LinkSite (** link ))
393398 if disabled :
394399 self .q .execute ("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 0" )
395400 for link in self .q .fetchall ():
396- sites .append (LinkSite (* link ))
401+ sites .append (LinkSite (** link ))
397402 return sites
398403
399404 @db_error_default (list ())
@@ -404,7 +409,7 @@ def get_links(self, show:Show=None) -> List[Link]:
404409 # Get all streams with show ID
405410 self .q .execute ("SELECT site, show, site_key FROM Links WHERE show = ?" , (show .id ,))
406411 links = self .q .fetchall ()
407- links = [Link (* link ) for link in links ]
412+ links = [Link (** link ) for link in links ]
408413 return links
409414 else :
410415 error ("A show must be provided to get links" )
@@ -418,7 +423,7 @@ def get_link(self, show: Show, link_site: LinkSite) -> Optional[Link]:
418423 link = self .q .fetchone ()
419424 if link is None :
420425 return None
421- link = Link (* link )
426+ link = Link (** link )
422427 return link
423428
424429 @db_error_default (False )
@@ -449,15 +454,15 @@ def add_link(self, raw_show: UnprocessedShow, show_id, commit=True):
449454
450455 # Shows
451456 @db_error_default (list ())
452- def get_shows (self , missing_length = False , missing_stream = False , enabled = True , delayed = False ) -> [Show ]:
457+ def get_shows (self , missing_length = False , missing_stream = False , enabled = True , delayed = False ) -> list [Show ]:
453458 shows = list ()
454459 if missing_length :
455460 self .q .execute (
456- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
461+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows \
457462 WHERE (length IS NULL OR length = '' OR length = 0) AND enabled = ?" , (enabled ,))
458463 elif missing_stream :
459464 self .q .execute (
460- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows show\
465+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows show\
461466 WHERE (SELECT count(*) FROM Streams stream, Services service \
462467 WHERE stream.show = show.id \
463468 AND stream.active = 1 \
@@ -467,14 +472,14 @@ def get_shows(self, missing_length=False, missing_stream=False, enabled=True, de
467472 (enabled ,))
468473 elif delayed :
469474 self .q .execute (
470- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
475+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows \
471476 WHERE delayed = 1 AND enabled = ?" , (enabled ,))
472477 else :
473478 self .q .execute (
474- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
479+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows \
475480 WHERE enabled = ?" , (enabled ,))
476481 for show in self .q .fetchall ():
477- show = Show (* show )
482+ show = Show (** show )
478483 show .aliases = self .get_aliases (show )
479484 shows .append (show )
480485 return shows
@@ -492,12 +497,12 @@ def get_show(self, id=None, stream=None) -> Optional[Show]:
492497 error ("Show ID not provided to get_show" )
493498 return None
494499 self .q .execute (
495- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
500+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows \
496501 WHERE id = ?" , (id ,))
497502 show = self .q .fetchone ()
498503 if show is None :
499504 return None
500- show = Show (* show )
505+ show = Show (** show )
501506 show .aliases = self .get_aliases (show )
502507 return show
503508
@@ -506,19 +511,19 @@ def get_show_by_name(self, name) -> Optional[Show]:
506511 #debug("Getting show from database")
507512
508513 self .q .execute (
509- "SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
514+ "SELECT id, name, name_en, length, type AS show_type , has_source, is_nsfw, enabled, delayed FROM Shows \
510515 WHERE name = ?" , (name ,))
511516 show = self .q .fetchone ()
512517 if show is None :
513518 return None
514- show = Show (* show )
519+ show = Show (** show )
515520 show .aliases = self .get_aliases (show )
516521 return show
517522
518523 @db_error_default (list ())
519- def get_aliases (self , show : Show ) -> [str ]:
524+ def get_aliases (self , show : Show ) -> list [str ]:
520525 self .q .execute ("SELECT alias FROM Aliases where show = ?" , (show .id ,))
521- return [s for s , in self .q .fetchall ()]
526+ return [s [ "alias" ] for s in self .q .fetchall ()]
522527
523528 @db_error_default (None )
524529 def add_show (self , raw_show : UnprocessedShow , commit = True ) -> int :
@@ -556,7 +561,7 @@ def update_show(self, show_id: str, raw_show: UnprocessedShow, commit=True):
556561 is_nsfw = raw_show .is_nsfw
557562
558563 if name_en :
559- self .q .execute ("UPDATE Shows SET name_en = ? WHERE id = ?" , (name_en , show_id ))
564+ self .q .execute ("UPDATE Shows SET name_en = ? WHERE id = ?" , (name_en , show_id ))
560565 if length != 0 :
561566 self .q .execute ("UPDATE Shows SET length = ? WHERE id = ?" , (length , show_id ))
562567 self .q .execute ("UPDATE Shows SET type = ?, has_source = ?, is_nsfw = ? WHERE id = ?" , (show_type , has_source , is_nsfw , show_id ))
@@ -599,10 +604,10 @@ def stream_has_episode(self, stream: Stream, episode_num) -> bool:
599604
600605 @db_error_default (None )
601606 def get_latest_episode (self , show : Show ) -> Optional [Episode ]:
602- self .q .execute ("SELECT episode, post_url FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1" , (show .id ,))
607+ self .q .execute ("SELECT episode AS number , post_url AS link FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1" , (show .id ,))
603608 data = self .q .fetchone ()
604609 if data is not None :
605- return Episode (data [ 0 ], None , data [ 1 ], None )
610+ return Episode (** data )
606611 return None
607612
608613 @db_error
@@ -614,9 +619,9 @@ def add_episode(self, show, episode_num, post_url):
614619 @db_error_default (list ())
615620 def get_episodes (self , show , ensure_sorted = True ) -> List [Episode ]:
616621 episodes = list ()
617- self .q .execute ("SELECT episode, post_url FROM Episodes WHERE show = ?" , (show .id ,))
622+ self .q .execute ("SELECT episode AS number , post_url AS link FROM Episodes WHERE show = ?" , (show .id ,))
618623 for data in self .q .fetchall ():
619- episodes .append (Episode (data [ 0 ], None , data [ 1 ], None ))
624+ episodes .append (Episode (** data ))
620625
621626 if ensure_sorted :
622627 episodes = sorted (episodes , key = lambda e : e .number )
@@ -625,23 +630,23 @@ def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
625630 # Scores
626631 @db_error_default (list ())
627632 def get_show_scores (self , show : Show ) -> List [EpisodeScore ]:
628- self .q .execute ("SELECT episode, site, score FROM Scores WHERE show=?" , (show .id ,))
629- return [EpisodeScore (show .id , * s ) for s in self .q .fetchall ()]
633+ self .q .execute ("SELECT episode, site AS site_id , score FROM Scores WHERE show=?" , (show .id ,))
634+ return [EpisodeScore (show_id = show .id , * *s ) for s in self .q .fetchall ()]
630635
631636 @db_error_default (list ())
632637 def get_episode_scores (self , show : Show , episode : Episode ) -> List [EpisodeScore ]:
633- self .q .execute ("SELECT site, score FROM Scores WHERE show=? AND episode=?" , (show .id , episode .number ))
634- return [EpisodeScore (show .id , episode .number , * s ) for s in self .q .fetchall ()]
638+ self .q .execute ("SELECT site AS site_id , score FROM Scores WHERE show=? AND episode=?" , (show .id , episode .number ))
639+ return [EpisodeScore (show_id = show .id , episode = episode .number , * *s ) for s in self .q .fetchall ()]
635640
636641 @db_error_default (None )
637642 def get_episode_score_avg (self , show : Show , episode : Episode ) -> Optional [EpisodeScore ]:
638643 debug ("Calculating avg score for {} ({})" .format (show .name , show .id ))
639644 self .q .execute ("SELECT score FROM Scores WHERE show=? AND episode=?" , (show .id , episode .number ))
640- scores = [s [0 ] for s in self .q .fetchall ()]
645+ scores = [s ["score" ] for s in self .q .fetchall ()]
641646 if len (scores ) > 0 :
642647 score = sum (scores )/ len (scores )
643648 debug (" Score: {} (from {} scores)" .format (score , len (scores )))
644- return EpisodeScore (show .id , episode .number , None , score )
649+ return EpisodeScore (show_id = show .id , episode = episode .number , score = score )
645650 return None
646651
647652 @db_error
@@ -664,7 +669,7 @@ def get_poll_site(self, id:str=None, key:str=None) -> Optional[PollSite]:
664669 site = self .q .fetchone ()
665670 if site is None :
666671 return None
667- return PollSite (* site )
672+ return PollSite (** site )
668673
669674 @db_error
670675 def add_poll (self , show : Show , episode : Episode , site : PollSite , poll_id , commit = True ):
@@ -681,24 +686,24 @@ def update_poll_score(self, poll: Poll, score, commit=True):
681686
682687 @db_error_default (None )
683688 def get_poll (self , show : Show , episode : Episode ):
684- self .q .execute ("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ? AND episode = ?" , (show .id , episode .number ))
689+ self .q .execute ("SELECT show AS show_id , episode, poll_service AS service , poll_id AS id , timestamp AS date , score FROM Polls WHERE show = ? AND episode = ?" , (show .id , episode .number ))
685690 poll = self .q .fetchone ()
686691 if poll is None :
687692 return None
688- return Poll (* poll )
693+ return Poll (** poll )
689694
690695 @db_error_default (list ())
691696 def get_polls (self , show : Show = None , missing_score = False ):
692697 polls = list ()
693698 if show is not None :
694- self .q .execute ("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ?" , (show .id ,))
699+ self .q .execute ("SELECT show AS show_id , episode, poll_service AS service , poll_id AS id , timestamp AS date , score FROM Polls WHERE show = ?" , (show .id ,))
695700 elif missing_score :
696- self .q .execute ("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)" )
701+ self .q .execute ("SELECT show AS show_id , episode, poll_service AS service , poll_id AS id , timestamp AS date , score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)" )
697702 else :
698703 error ("Need to select a show to get polls" )
699704 return list ()
700705 for poll in self .q .fetchall ():
701- polls .append (Poll (* poll ))
706+ polls .append (Poll (** poll ))
702707 return polls
703708
704709 # Searching
@@ -713,8 +718,8 @@ def search_show_ids_by_names(self, *names, exact=False) -> Set[Show]:
713718 self .q .execute ("SELECT show, name FROM ShowNames WHERE name = ? COLLATE alphanum" , (name ,))
714719 matched = self .q .fetchall ()
715720 for match in matched :
716- debug (" Found match: {} | {}" .format (match [0 ], match [1 ]))
717- shows .add (match [0 ])
721+ debug (" Found match: {} | {}" .format (match ['show' ], match ['name' ]))
722+ shows .add (match ['show' ])
718723 return shows
719724
720725# Helper methods
0 commit comments