diff --git a/pyproject.toml b/pyproject.toml index caa8b91..05006bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ profile = "black" [project] name = "turftopic" -version = "0.18.0" +version = "0.19.0" description = "Topic modeling with contextual representations from sentence transformers." authors = [ { name = "Márton Kardos ", email = "martonkardos@cas.au.dk" } diff --git a/turftopic/_concept_browser.py b/turftopic/_concept_browser.py new file mode 100644 index 0000000..8173c6d --- /dev/null +++ b/turftopic/_concept_browser.py @@ -0,0 +1,443 @@ +import tempfile +import time +import webbrowser +from random import shuffle +from typing import Literal + +import numpy as np + +from turftopic._figure import HTMLFigure +from turftopic.utils import sanitize_for_html + + +def open_html(html: str): + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".html") as f: + url = "file://" + f.name + f.write(html) + time.sleep(1.0) + webbrowser.open(url) + + +COLOR_PALETTE = [ + "rgba(116,221,201, 0.25)", + "rgba(235,171,204, 0.25)", + "rgba(181,217,160, 0.25)", + "rgba(197,179,239, 0.25)", + "rgba(222,202,142, 0.25)", + "rgba(160,188,233, 0.25)", + "rgba(234,177,158, 0.25)", + "rgba(129,209,231, 0.25)", + "rgba(207,198,216, 0.25)", + "rgba(186,213,195, 0.25)", +] + + +def create_bar_plot(seed_id, topic_names, topic_sizes, colors: list[str]): + topic_sizes = [int(size) for size in topic_sizes] + plot_js = """ + const yValue{seed_id} = {topic_sizes}; + const data{seed_id} = [ + {{ + x: {topic_names}, + y: yValue{seed_id}, + type: "bar", + hoverinfo: 'none', + textposition: 'auto', + text: yValue{seed_id}.map(n => `n=${{n}}`).map(String), + marker: {{ + color: {colors}, + line: {{ + color: 'black', + width: 1 + }} + }} + }} + ]; + const layout{seed_id} = {{ + font: {{family: "Merriweather, sans-serif", size: 16}}, + margin: {{t: 0, l: 30, b:15, r: 30}}, + height: 200, + width: 600, + yaxis: {{ + showticklabels: false + }}, + xaxis: {{ + showticklabels: 
false + }}, + }}; + Plotly.newPlot("plot-{seed_id}", data{seed_id}, layout{seed_id}) + """.format( + topic_names=str(topic_names), + topic_sizes=str(topic_sizes), + colors=str(colors), + seed_id=str(seed_id), + ) + res = """ +
+ + """.format( + plot_js=plot_js, + seed_id=seed_id, + ) + return res + + +def render_cards( + seed_id, + topic_names, + keywords, + topic_descriptions, + topic_sizes, + colors: list[str], +): + # Adding the ID if it doesn't already start like that + topic_names = [ + # Names that already carry their ID are kept as-is rather than dropped + name if name.startswith(str(i)) else f"{i} - {name}" + for i, name in enumerate(topic_names) + ] + res = create_bar_plot(seed_id, topic_names, topic_sizes, colors) + for i, (name, keys, desc, color) in enumerate( + zip( + topic_names, + keywords, + topic_descriptions, + colors, + ), + ): + desc = sanitize_for_html(desc) + name = sanitize_for_html(name) + res += """ +
+

{name}

+

Keywords: {keywords}

+

Description: {description}

+

+ """.format( + name=name, + keywords=", ".join(keys), + description=desc, + i=i, + color=color, + seed_id=seed_id, + ) + return res + + +def prep_document(doc: str, max_chars: int = 900): + if len(doc) > max_chars: + doc = doc[: max_chars - 3] + "..." + doc = sanitize_for_html(doc) + return doc + + +def render_documents(top_documents: list[list[str]], colors: list[str]): + documents = [] + labels = [] + # Flattening the array out + for topic_id, _top in enumerate(top_documents): + documents.extend(_top) + labels.extend([topic_id] * len(_top)) + indices = list(range(len(documents))) + shuffle(indices) + res = "" + for i_doc in indices: + res += """ +
+ {document} +
+ """.format( + document=prep_document(documents[i_doc]), + bgcolor=colors[labels[i_doc]], + ) + return res + + +def render_widget( + seeds: list[str], + topic_names: list[list[str]], + keywords: list[list[list[str]]], + topic_descriptions: list[list[str]], + topic_sizes: list[np.ndarray], + top_documents: list[list[str]], +) -> str: + n_topics = 0 + for names in topic_names: + n_topics += len(names) + colors = list(COLOR_PALETTE) + # Colors will loop back when we run out of them + while len(colors) <= n_topics: + colors.extend(COLOR_PALETTE) + res = """
\n""" + res += """ +
+
Seed Phrases
+ """ + for seed_id, seed in enumerate(seeds): + res += """ + + """.format( + bgcolor="rgba(160,188,233, 0.4)" if seed_id == 0 else "white", + seed_id=seed_id, + seed=str(seed), + ) + res += "
\n" + color_start = 0 + for seed_id, seed in enumerate(seeds): + n_topics = len(topic_names[seed_id]) + seed_colors = colors[color_start : color_start + n_topics] + color_start += n_topics + container = """ +
+
Concepts
+ {content} +
+ """.format( + content=render_cards( + seed_id, + topic_names[seed_id], + keywords[seed_id], + topic_descriptions[seed_id], + topic_sizes[seed_id], + seed_colors, + ), + visibility="block" if seed_id == 0 else "none", + seed_id=seed_id, + ) + res += container + res += "\n
\n" + color_start = 0 + for seed_id, seed in enumerate(seeds): + n_topics = len(topic_names[seed_id]) + seed_colors = colors[color_start : color_start + n_topics] + color_start += n_topics + res += """ +
+
Example Documents
+ """.format( + seed_id=seed_id, + visibility="flex" if seed_id == 0 else "none", + ) + res += render_documents(top_documents[seed_id], seed_colors) + res += "\n
\n" + res += """ + + """ + return res + + +STYLE = """ +body { + font-family: "Merriweather", Times New Roman; +} +.topic-container { + max-width: 600px; + overflow-y: auto; + overflow-x: hidden; +} +.floating-label { + padding: 10px; + position: fixed; + color: white; + background-color: black; + border-radius: 10px; + z-index: 30; + width: fit-content; + margin-top: -30px; + margin-left: -30px; +} +.card { + padding-top: 2px; + padding-bottom: 2px; + padding-left: 15px; + padding-right: 15px; + margin: 5px; + margin-top: 10px; + margin-bottom: 10px; + background-color: #E6F3FF; + border: solid; + box-shadow: 0px 0px 1px 1px rgba(0,0,0,0.1); + border-color: #999999; + border-width: 1px; + border-radius: 5px; + border-color: black; +} +.button-container { + padding: 10px; + margin: 30px; + margin-bottom: 5px; + background-color: "white"; + border-radius: 5px; + box-shadow: 0px 0px 1px 1px rgba(0,0,0,0.1); + align-self: stretch; + flex-grow: 1; + flex-shrink: 1; + display: flex; + flex-direction: row; + max-width: 600px; +} +.box { + padding: 10px; + margin: 30px; + background-color: "white"; + border-radius: 5px; + box-shadow: 0px 0px 1px 1px rgba(0,0,0,0.1); + align-self: stretch; + flex-grow: 0; + flex-shrink: 1; +} +.document { + padding: 20px; + margin: 5px; + border-radius: 5px; + border: solid; + box-shadow: 0px 0px 1px 1px rgba(0,0,0,0.1); + border-color: #999999; + border-width: 1px; + max-height: 150px; + flex-shrink: 0; + text-align: left; + text-overflow: ellipsis; + font-style: italic; + overflow: hidden; +} +.model-switcher { + font-size: 16px; + font-family: "Merriweather", Times New Roman; + display: flex; + flex-grow: 1; + align-items: center; + align-content: center; + border: solid; + border-color: #999999; + border-width: 1px; + border-radius: 5px; + background-color: white; + margin: 5px; + color: black; + padding: 10px; + justify-content: left; + text-decoration: none; +} +.model-switcher:hover { + border-color: black; +} +.document-viewer { + 
flex-basis: 600px; + display: flex; + justify-content: flex-start; + flex-direction: column; + overflow-y: scroll; + overflow-x: hidden; +} +#container { + display: flex; + flex-direction: row; + flex-grow: 0; + justify-content: center; + align-items: stretch; + align-content: stretch; + max-height: 1000px; +} +.column { + display: flex; + flex-direction: column; + align-items: stretch; + flex-basis: fit-content; +} +""" +HTML_WRAPPER = """ + + + + + + + + + + +
+ {body_content} +
+ + +""" + + +def create_browser( + seeds: list[str], + topic_names: list[list[str]], + keywords: list[list[list[str]]], + topic_descriptions: list[list[str]], + topic_sizes: np.ndarray, + top_documents: list[list[str]], +) -> HTMLFigure: + """Creates a concept browser figure with which you can investigate concepts related to different seeds. + + Parameters + ---------- + seeds: list[str] + Seed phrases used for the analysis. + topic_names: list[list[str]] + Names of the topics for each of the seed phrases. + keywords: list[list[list[str]]] + Keywords for each of the topics for each seed. + topic_descriptions: list[list[str]] + Descriptions of the topics for each of the seed phrases. + topic_sizes: np.ndarray + Sizes of the topics for each seed, preferably number of documents. + top_documents: list[list[str]] + Top documents for each of the topics for each seed. + + Returns + ------- + HTMLFigure + Interactive HTML figure that you can either display or save. + """ + html = HTML_WRAPPER.format( + style=STYLE, + body_content=render_widget( + seeds, + topic_names, + keywords, + topic_descriptions, + topic_sizes, + top_documents, + ), + ) + return HTMLFigure(html) diff --git a/turftopic/_datamapplot.py b/turftopic/_datamapplot.py index 62acc98..759ad56 100644 --- a/turftopic/_datamapplot.py +++ b/turftopic/_datamapplot.py @@ -1,3 +1,4 @@ +import re import tempfile import time import webbrowser @@ -7,6 +8,8 @@ import numpy as np from sklearn.preprocessing import scale +from turftopic.utils import sanitize_for_html + CUSTOM_CSS = """ .row { display : flex; @@ -110,7 +113,7 @@ def build_datamapplot( topic_tree_kwds={ "color_bullets": True, }, - cluster_boundary_polygons=True, + cluster_boundary_polygons=False, cluster_boundary_line_width=6, polygon_alpha=2, **kwargs, @@ -149,7 +152,7 @@ def build_datamapplot( for label in topic_names: percentages.append(100 * np.sum(labels == label) / len(labels)) # Sanitizing the names so they don't mess up the HTML - topic_names = 
[name.replace('"', "'") for name in topic_names] + topic_names = [sanitize_for_html(name) for name in topic_names] custom_js = "" custom_js += "const nameToPercent = new Map();\n" for name, percent in zip(topic_names, percentages): @@ -160,7 +163,7 @@ def build_datamapplot( custom_js += "const nameToDesc = new Map();\n" if topic_descriptions is not None: topic_descriptions = [ - desc.replace('"', "'") for desc in topic_descriptions + sanitize_for_html(desc) for desc in topic_descriptions ] for topic_id, name, desc in zip( classes, topic_names, topic_descriptions diff --git a/turftopic/_figure.py b/turftopic/_figure.py new file mode 100644 index 0000000..9668f9f --- /dev/null +++ b/turftopic/_figure.py @@ -0,0 +1,25 @@ +import tempfile +import time +import webbrowser +from pathlib import Path +from typing import Union + + +class HTMLFigure: + def __init__(self, html: str): + self.html = html + + def show(self): + with tempfile.TemporaryDirectory() as temp_dir: + file_name = Path(temp_dir).joinpath("fig.html") + self.write_html(file_name) + webbrowser.open("file://" + str(file_name.absolute()), new=2) + time.sleep(2) + + def write_html(self, path: Union[str, Path]): + path = Path(path) + with path.open("w") as out_file: + out_file.write(self.html) + + def _repr_html_(self): + return self.html diff --git a/turftopic/models/_keynmf.py b/turftopic/models/_keynmf.py index afbcfd5..3a27628 100644 --- a/turftopic/models/_keynmf.py +++ b/turftopic/models/_keynmf.py @@ -8,8 +8,12 @@ import numpy as np import scipy.sparse as spr from sklearn.base import clone -from sklearn.decomposition._nmf import (NMF, MiniBatchNMF, _initialize_nmf, - _update_coordinate_descent) +from sklearn.decomposition._nmf import ( + NMF, + MiniBatchNMF, + _initialize_nmf, + _update_coordinate_descent, +) from sklearn.exceptions import NotFittedError from sklearn.feature_extraction.text import CountVectorizer from sklearn.metrics.pairwise import cosine_similarity @@ -169,6 +173,7 @@ def 
batch_extract_keywords( documents: list[str], embeddings: Optional[np.ndarray] = None, seed_embedding: Optional[np.ndarray] = None, + seed_exponent: float = 4.0, fitting: bool = True, ) -> list[dict[str, float]]: if not len(documents): @@ -199,6 +204,7 @@ def batch_extract_keywords( else: document_relevance = np.dot(embeddings, seed_embedding) document_relevance[document_relevance < 0] = 0 + document_relevance = np.power(document_relevance, seed_exponent) for i in range(total): terms = document_term_matrix[i, :].todense() embedding = embeddings[i].reshape(1, -1) diff --git a/turftopic/models/keynmf.py b/turftopic/models/keynmf.py index c8b2d8c..3d26b9b 100644 --- a/turftopic/models/keynmf.py +++ b/turftopic/models/keynmf.py @@ -60,6 +60,8 @@ class KeyNMF(ContextualModel, DynamicTopicModel, MultimodalModel): Describes an aspect of the corpus that the model should explore. It can be a free-text query, such as "Christian Denominations: Protestantism and Catholicism" + seed_exponent: float, default 2.0 + Exponent that is applied to document weight in relation to the provided seed phrase. cross_lingual: bool, default False Indicates whether KeyNMF should match terms across languages. This is useful when you have a corpus containing multiple languages. 
@@ -78,12 +80,14 @@ def __init__( random_state: Optional[int] = None, metric: Literal["cosine", "dot"] = "cosine", seed_phrase: Optional[str] = None, + seed_exponent: float = 2.0, cross_lingual: bool = False, term_match_threshold: float = 0.9, ): self.random_state = random_state self.n_components = n_components self.top_n = top_n + self.seed_exponent = seed_exponent self.metric = metric self.encoder = encoder self._has_custom_vectorizer = vectorizer is not None @@ -140,6 +144,7 @@ def extract_keywords( batch_or_document, embeddings=embeddings, seed_embedding=self.seed_embedding, + seed_exponent=self.seed_exponent, fitting=fitting, ) if self.cross_lingual: diff --git a/turftopic/utils.py b/turftopic/utils.py index 1417589..25d99b6 100644 --- a/turftopic/utils.py +++ b/turftopic/utils.py @@ -56,3 +56,18 @@ def export_table( raise ValueError( f"Format '{format}' not supported for tables, please use 'markdown', 'latex' or 'csv'" ) + + +def sanitize_for_html(text: str) -> str: + """Sanitizes strings so they can be put into JS or HTML strings""" + # Escaping special characters + text = ( + text.replace("&", "&amp;") + .replace("<", "&lt;") + .replace(">", "&gt;") + .replace('"', "&quot;") + .replace("'", "&#x27;") + ) + # Removing unnecessary whitespace + text = " ".join(text.split()) + return text