|
1 | 1 | import itertools |
| 2 | +import warnings |
2 | 3 | from abc import ABC |
3 | 4 | from datetime import datetime |
4 | 5 | from typing import Any, Iterable, List, Optional, Tuple, Union |
5 | 6 |
|
6 | 7 | import numpy as np |
| 8 | +from PIL import Image |
7 | 9 | from rich.console import Console |
8 | 10 | from rich.table import Table |
9 | 11 |
|
@@ -76,6 +78,92 @@ def _top_terms( |
76 | 78 | terms.append(list(vocab[highest])) |
77 | 79 | return terms |
78 | 80 |
|
def get_top_words(
    self, top_k: int = 10, positive: bool = True
) -> list[list[str]]:
    """Collect the most characteristic words of every topic.

    Parameters
    ----------
    top_k: int, default 10
        How many words to retrieve.
    positive: bool, default True
        When True the highest scoring terms are returned,
        otherwise the lowest scoring ones.

    Returns
    -------
    list[list[str]]
        Whatever ``_top_terms`` produces for the given arguments:
        one list of words per topic.
    """
    # Pure delegation: the ranking itself lives in `_top_terms`.
    words_per_topic = self._top_terms(top_k, positive)
    return words_per_topic
| 95 | + |
def get_top_documents(
    self,
    raw_documents=None,
    document_topic_matrix=None,
    top_k: int = 10,
    positive: bool = True,
) -> list[list[str]]:
    """Returns list of top documents for each topic.

    Parameters
    ----------
    raw_documents: optional
        Documents to pick from. When None, the model's ``corpus``
        attribute is used if it exists.
    document_topic_matrix: optional
        Importance of each topic in each document; it is transposed and
        iterated topic by topic, so rows are expected to be documents.
        When None, the model's ``document_topic_matrix`` attribute is used
        if it exists, otherwise ``self.transform(raw_documents)`` is called.
    top_k: int, default 10
        Number of documents to return per topic.
    positive: bool, default True
        Indicates whether the highest
        or lowest scoring documents should be returned.

    Returns
    -------
    list[list[str]]
        For every topic, the ``top_k`` selected documents, best match first.

    Raises
    ------
    ValueError
        If no corpus is available, or if no document-topic matrix is
        available and calling ``transform`` raises ``AttributeError``.
    """
    # NOTE: explicit `is None` checks are required here. Using `x or default`
    # would evaluate the truth value of the argument, which raises for NumPy
    # arrays and silently discards explicitly passed empty collections.
    if raw_documents is None:
        raw_documents = getattr(self, "corpus", None)
    if raw_documents is None:
        raise ValueError(
            "No corpus was passed, can't search for representative documents."
        )
    if document_topic_matrix is None:
        document_topic_matrix = getattr(self, "document_topic_matrix", None)
    if document_topic_matrix is None:
        try:
            document_topic_matrix = self.transform(raw_documents)
        except AttributeError as err:
            raise ValueError(
                "Transductive methods cannot "
                "infer topical content in documents.\n"
                "Please pass a document_topic_matrix."
            ) from err
    docs = []
    for topic_doc_vec in document_topic_matrix.T:
        # argsort is ascending, so negate the scores to get the highest first.
        ranking = np.argsort(-topic_doc_vec if positive else topic_doc_vec)
        docs.append([raw_documents[i_doc] for i_doc in ranking[:top_k]])
    return docs
| 137 | + |
def get_top_images(self, top_k: int = 10, positive: bool = True):
    """Returns list of top images for each topic.

    Parameters
    ----------
    top_k: int, default 10
        Number of images to return.
    positive: bool, default True
        Indicates whether the highest
        or lowest scoring images should be returned.

    Returns
    -------
    list
        For every topic, the first ``top_k`` entries of the images stored
        on the model (``top_images`` or ``negative_images``).

    Raises
    ------
    ValueError
        If the model has no ``top_images`` attribute, or if
        ``positive`` is False and no ``negative_images`` are stored.

    Warns
    -----
    UserWarning
        If any topic stores fewer than ``top_k`` images.
    """
    if not hasattr(self, "top_images"):
        raise ValueError(
            "Model either has not been fit or was fit without images. top_images property missing."
        )
    # A `negative_images` attribute set to None counts as missing too.
    if not positive and getattr(self, "negative_images", None) is None:
        raise ValueError(
            "Model either has not been fit or was fit without images. negative_images property missing."
        )
    top_images = self.top_images if positive else self.negative_images
    # Warn a single time instead of once per topic.
    if any(len(topic_images) < top_k for topic_images in top_images):
        warnings.warn(
            "Number of images stored in the topic model is smaller than the specified top_k, returning all that the model has."
        )
    return [topic_images[:top_k] for topic_images in top_images]
| 166 | + |
def _rename_automatic(self, namer: TopicNamer) -> list[str]:
    """Name the topics with `namer`, store the names and return them.

    The namer receives the output of ``_top_terms()`` called with its
    default arguments; the result is kept in ``topic_names_``.
    """
    top_terms = self._top_terms()
    self.topic_names_ = namer.name_topics(top_terms)
    return self.topic_names_
@@ -928,3 +1016,218 @@ def plot_topics_over_time( |
928 | 1016 | fig.update_xaxes(title="Time Slice Start") |
929 | 1017 | fig.update_yaxes(title="Topic Importance") |
930 | 1018 | return fig |
| 1019 | + |
@staticmethod
def _image_grid(
    images: list[Image.Image],
    final_size=(1200, 1200),
    grid_size: tuple[int, int] = (4, 4),
):
    """Tiles images into one RGB image with a white background.

    Images are placed left to right, top to bottom, each one resized to
    exactly one grid cell. Cells without an image stay white.

    Parameters
    ----------
    images: list[Image.Image]
        Images to tile. Only the first ``n_rows * n_cols`` are used.
    final_size: tuple[int, int], default (1200, 1200)
        Size of the resulting image, passed to ``Image.new``
        (width first, then height).
    grid_size: tuple[int, int], default (4, 4)
        Number of rows and number of columns of the grid.

    Returns
    -------
    Image.Image
        The assembled grid image.

    Raises
    ------
    ValueError
        If the grid has fewer than one row or column.
    """
    n_rows, n_cols = grid_size
    if n_rows < 1 or n_cols < 1:
        raise ValueError("grid_size must contain positive integers.")
    width, height = final_size
    # Columns split the width and rows split the height; keeping these
    # paired matters for non-square grids.
    cell_width = width // n_cols
    cell_height = height // n_rows
    grid_img = Image.new("RGB", (width, height), (255, 255, 255))
    for idx, img in enumerate(images[: n_rows * n_cols]):
        row, col = divmod(idx, n_cols)
        tile = img.resize(
            (cell_width, cell_height), resample=Image.Resampling.LANCZOS
        )
        grid_img.paste(tile, (col * cell_width, row * cell_height))
    return grid_img
| 1038 | + |
def plot_topics_with_images(self, n_cols: int = 3, grid_size: int = 4):
    """Plots the most important images for each topic, along with keywords.

    Note that you will need to `pip install plotly` to use plots in Turftopic.

    If the model stores ``negative_images``, every topic gets its own row,
    with the positive image grid and keywords on the left and the negative
    ones on the right; ``n_cols`` is not used in that case.
    Otherwise topics are laid out row by row in ``n_cols`` columns.

    Parameters
    ----------
    n_cols: int, default 3
        Number of columns you want to have in the grid of topics.
        Ignored when the model has negative images.
    grid_size: int, default 4
        The square root of the number of images you want to display for a given topic.
        For instance if grid_size==4, all topics will have 16 images displayed,
        since the joint image will have 4 columns and 4 rows.

    Returns
    -------
    go.Figure
        Plotly figure containing top images and keywords for topics.

    Raises
    ------
    ValueError
        If the model has no ``top_images`` attribute,
        or if ``n_cols`` is smaller than 1 when it is needed for the layout.
    ModuleNotFoundError
        If plotly cannot be imported.
    """
    if not hasattr(self, "top_images"):
        raise ValueError(
            "Model either has not been fit or was fit without images. top_images property missing."
        )
    try:
        import plotly.graph_objects as go
    except (ImportError, ModuleNotFoundError) as e:
        raise ModuleNotFoundError(
            "Please install plotly if you intend to use plots in Turftopic."
        ) from e
    negative_images = getattr(self, "negative_images", None)
    has_negative = negative_images is not None
    if not has_negative and n_cols < 1:
        raise ValueError("n_cols must be at least 1.")

    n_keywords = 7
    # Image grids are rendered at full resolution and displayed scaled down.
    width, height = 1200, 1200
    scale_factor = 0.25
    w, h = width * scale_factor, height * scale_factor
    padding = 10
    n_components = self.components_.shape[0]
    if has_negative:
        # One topic per row: positive tile on the left, negative on the right.
        layout_cols, layout_rows = 2, n_components
    else:
        layout_cols = n_cols
        # Ceiling division: a partially filled last row still needs space.
        layout_rows = n_components // n_cols + int(bool(n_components % n_cols))
    figure_height = (h + padding) * layout_rows
    figure_width = (w + padding) * layout_cols

    fig = go.Figure()
    # Invisible markers in two opposite corners make the axes span the canvas.
    fig.add_trace(
        go.Scatter(
            x=[0, figure_width],
            y=[0, figure_height],
            mode="markers",
            marker_opacity=0,
        )
    )

    def add_tile(row, col, images, words, bgcolor):
        # Draws one image grid with its keyword label centred on top of it.
        tile = self._image_grid(
            images, (width, height), grid_size=(grid_size, grid_size)
        )
        x0 = (w + padding) * col
        # Layout images are anchored at their top edge and the y axis
        # points upwards, so the first row sits at the top of the figure.
        y0 = (h + padding) * (layout_rows - row)
        fig.add_layout_image(
            dict(
                x=x0,
                sizex=w,
                y=y0,
                sizey=h,
                xref="x",
                yref="y",
                opacity=1.0,
                layer="below",
                sizing="stretch",
                source=tile,
            ),
        )
        fig.add_annotation(
            x=x0 + (w / 2),
            y=y0 - (h / 2),
            text="<b> " + "<br> ".join(words),
            font=dict(
                size=16,
                family="Times New Roman",
                color="white",
            ),
            bgcolor=bgcolor,
        )

    vocab = self.get_vocab()
    for i, component in enumerate(self.components_):
        top_words = vocab[np.argsort(-component)[:n_keywords]]
        if has_negative:
            bottom_words = vocab[np.argsort(component)[:n_keywords]]
            add_tile(i, 0, self.top_images[i], top_words, "rgba(0,0,255, 0.5)")
            add_tile(i, 1, negative_images[i], bottom_words, "rgba(255,0,0, 0.5)")
        else:
            add_tile(
                i // n_cols,
                i % n_cols,
                self.top_images[i],
                top_words,
                "rgba(0,0,0, 0.5)",
            )

    fig.update_xaxes(visible=False, range=[0, figure_width])
    fig.update_yaxes(
        visible=False,
        range=[0, figure_height],
        # the scaleanchor attribute ensures that the aspect ratio stays constant
        scaleanchor="x",
    )
    fig.update_layout(
        width=figure_width,
        height=figure_height,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    return fig
0 commit comments