import os
import random

import matplotlib.pyplot as plt
import pandas as pd
import spacy
from spacy.util import minibatch, compounding


TEST_REVIEW = """
Transcendently beautiful in moments outside the office, it seems almost
sitcom-like in those scenes. When Toni Colette walks out and ponders
life silently, it's gorgeous.<br /><br />The movie doesn't seem to decide
whether it's slapstick, farce, magical realism, or drama, but the best of it
doesn't matter. (The worst is sort of tedious - like Office Space with less
humor.)
"""


# Evaluation metrics gathered after each training iteration, used for plotting
eval_list = []


def train_model(
    training_data: list, test_data: list, iterations: int = 20
) -> None:
    # Build pipeline
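    # Assumes the small English model has been downloaded first:
    # python -m spacy download en_core_web_sm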
    nlp = spacy.load("en_core_web_sm")
    if "textcat" not in nlp.pipe_names:
        textcat = nlp.create_pipe(
            "textcat", config={"architecture": "simple_cnn"}
        )
        nlp.add_pipe(textcat, last=True)
    else:
        textcat = nlp.get_pipe("textcat")

    textcat.add_label("pos")
    textcat.add_label("neg")

    # Train only textcat
    training_excluded_pipes = [
        pipe for pipe in nlp.pipe_names if pipe != "textcat"
    ]
    with nlp.disable_pipes(training_excluded_pipes):
        optimizer = nlp.begin_training()
        # Training loop
        print("Beginning training")
        print("Loss\tPrecision\tRecall\tF-score")
        batch_sizes = compounding(
            4.0, 32.0, 1.001
        )  # A generator that yields an infinite series of increasing batch sizes
        for i in range(iterations):
            print(f"Training iteration {i}")
            loss = {}
            random.shuffle(training_data)
            batches = minibatch(training_data, size=batch_sizes)
            for batch in batches:
                text, labels = zip(*batch)
                nlp.update(text, labels, drop=0.2, sgd=optimizer, losses=loss)
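            # Evaluate after each pass over the data, temporarily swapping in
            # the optimizer's averaged weights for a more stable measurement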
            with textcat.model.use_params(optimizer.averages):
                evaluation_results = evaluate_model(
                    tokenizer=nlp.tokenizer,
                    textcat=textcat,
                    test_data=test_data,
                )
                # Record this iteration's loss and metrics so they can be
                # plotted after training
                eval_list.append(
                    {"loss": loss["textcat"], **evaluation_results}
                )
                print(
                    f"{loss['textcat']}\t{evaluation_results['precision']}"
                    f"\t{evaluation_results['recall']}"
                    f"\t{evaluation_results['f-score']}"
                )

    # Save model
    with nlp.use_params(optimizer.averages):
        nlp.to_disk("model_artifacts")


def evaluate_model(tokenizer, textcat, test_data: list) -> dict:
    reviews, labels = zip(*test_data)
    reviews = (tokenizer(review) for review in reviews)
    true_positives = 0
    false_positives = 1e-8  # Can't be 0 because of presence in denominator
    true_negatives = 0
    false_negatives = 1e-8
    for i, review in enumerate(textcat.pipe(reviews)):
        true_label = labels[i]["cats"]
        for predicted_label, score in review.cats.items():
            # Every cats dictionary includes both labels; you can get all
            # the info you need from the pos label alone
            if predicted_label == "neg":
                continue
            if score >= 0.5 and true_label["pos"]:
                true_positives += 1
            elif score >= 0.5 and true_label["neg"]:
                false_positives += 1
            elif score < 0.5 and true_label["neg"]:
                true_negatives += 1
            elif score < 0.5 and true_label["pos"]:
                false_negatives += 1
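    # Precision = TP / (TP + FP); recall = TP / (TP + FN)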
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)

    if precision + recall == 0:
        f_score = 0
    else:
        f_score = 2 * (precision * recall) / (precision + recall)
    return {"precision": precision, "recall": recall, "f-score": f_score}


def test_model(input_data: str = TEST_REVIEW):
    # Load saved trained model
    loaded_model = spacy.load("model_artifacts")
    # Generate prediction
    parsed_text = loaded_model(input_data)
    # Determine prediction to return
    if parsed_text.cats["pos"] > parsed_text.cats["neg"]:
        prediction = "Positive"
        score = parsed_text.cats["pos"]
    else:
        prediction = "Negative"
        score = parsed_text.cats["neg"]
    print(
        f"Review text: {input_data}\nPredicted sentiment: {prediction}"
        f"\tScore: {score}"
    )


def load_training_data(
    data_directory: str = "aclImdb/train", split: float = 0.8, limit: int = 0
) -> tuple:
    # Load from files
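    # (the default path assumes the unpacked IMDB Large Movie Review Dataset,
    # with pos/ and neg/ subdirectories of .txt reviews)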
    reviews = []
    for label in ["pos", "neg"]:
        labeled_directory = f"{data_directory}/{label}"
        for review in os.listdir(labeled_directory):
            if review.endswith(".txt"):
                with open(f"{labeled_directory}/{review}") as f:
                    text = f.read()
                    text = text.replace("<br />", "\n\n")
                    if text.strip():
                        spacy_label = {
                            "cats": {
                                "pos": "pos" == label,
                                "neg": "neg" == label,
                            }
                        }
                        reviews.append((text, spacy_label))
    random.shuffle(reviews)

    if limit:
        reviews = reviews[:limit]
    split = int(len(reviews) * split)
    return reviews[:split], reviews[split:]


if __name__ == "__main__":
    train, test = load_training_data(limit=25)
    print("Training model")
    train_model(train, test)
    # Plot the per-iteration metrics collected in eval_list during training
    df = pd.DataFrame(eval_list)
    df.plot()
    plt.show()
    print("Testing model")
    test_model()