|
| 1 | +import os |
| 2 | +import random |
| 3 | +import spacy |
| 4 | +from spacy.util import minibatch, compounding |
| 5 | + |
| 6 | + |
| 7 | +TEST_REVIEW = """ |
| 8 | +Transcendently beautiful in moments outside the office, it seems almost |
| 9 | +sitcom-like in those scenes. When Toni Colette walks out and ponders |
| 10 | +life silently, it's gorgeous.<br /><br />The movie doesn't seem to decide |
| 11 | +whether it's slapstick, farce, magical realism, or drama, but the best of it |
| 12 | +doesn't matter. (The worst is sort of tedious - like Office Space with less humor.) |
| 13 | +""" |
| 14 | + |
| 15 | + |
| 16 | +def train_model(training_data: list, test_data: list, iterations: int = 20): |
| 17 | + # Build pipeline |
| 18 | + nlp = spacy.load("en_core_web_sm") |
| 19 | + if "textcat" not in nlp.pipe_names: |
| 20 | + textcat = nlp.create_pipe("textcat", config={"architecture": "simple_cnn"}) |
| 21 | + nlp.add_pipe(textcat, last=True) |
| 22 | + else: |
| 23 | + textcat = nlp.get_pipe("textcat") |
| 24 | + |
| 25 | + # Add labels |
| 26 | + textcat.add_label("pos") |
| 27 | + textcat.add_label("neg") |
| 28 | + |
| 29 | + # Train only textcat |
| 30 | + training_excluded_pipes = [pipe for pipe in nlp.pipe_names if pipe != "textcat"] |
| 31 | + with nlp.disable_pipes(training_excluded_pipes): |
| 32 | + optimizer = nlp.begin_training() |
| 33 | + # Training loop |
| 34 | + print("Beginning training") |
| 35 | + print("Loss\tPrecision\tRecall\tF-score") |
| 36 | + batch_sizes = compounding( |
| 37 | + 4.0, 32.0, 1.001 |
| 38 | + ) # A generator that yields infinite series of input numbers |
| 39 | + for i in range(iterations): |
| 40 | + print(f"Training iteration {i}") |
| 41 | + loss = {} |
| 42 | + random.shuffle(training_data) |
| 43 | + batches = minibatch(training_data, size=batch_sizes) |
| 44 | + for batch in batches: |
| 45 | + text, labels = zip(*batch) |
| 46 | + nlp.update(text, labels, drop=0.2, sgd=optimizer, losses=loss) |
| 47 | + with textcat.model.use_params(optimizer.averages): |
| 48 | + evaluation_results = evaluate_model( |
| 49 | + tokenizer=nlp.tokenizer, textcat=textcat, test_data=test_data |
| 50 | + ) |
| 51 | + print( |
| 52 | + f"{loss['textcat']}\t{evaluation_results['precision']}\t{evaluation_results['recall']}\t{evaluation_results['f-score']}" |
| 53 | + ) |
| 54 | + |
| 55 | + # Save model |
| 56 | + with nlp.use_params(optimizer.averages): |
| 57 | + nlp.to_disk("model_artifacts") |
| 58 | + |
| 59 | + |
| 60 | +def evaluate_model(tokenizer, textcat, test_data: list) -> dict: |
| 61 | + reviews, labels = zip(*test_data) |
| 62 | + reviews = (tokenizer(review) for review in reviews) |
| 63 | + true_positives = 0 |
| 64 | + false_positives = 1e-8 # Can't be 0 because of presence in denominator |
| 65 | + true_negatives = 0 |
| 66 | + false_negatives = 1e-8 |
| 67 | + for i, review in enumerate(textcat.pipe(reviews)): |
| 68 | + true_label = labels[i]["cats"] |
| 69 | + for predicted_label, score in review.cats.items(): |
| 70 | + if ( |
| 71 | + predicted_label == "neg" |
| 72 | + ): # Every `cats` dictionary includes both labels, you can get all the info we need with just the pos label |
| 73 | + continue |
| 74 | + if score >= 0.5 and true_label["pos"]: |
| 75 | + true_positives += 1 |
| 76 | + elif score >= 0.5 and true_label["neg"]: |
| 77 | + false_positives += 1 |
| 78 | + elif score < 0.5 and true_label["neg"]: |
| 79 | + true_negatives += 1 |
| 80 | + elif score < 0.5 and true_label["pos"]: |
| 81 | + false_negatives += 1 |
| 82 | + precision = true_positives / (true_positives + false_positives) |
| 83 | + recall = true_positives / (true_positives + false_negatives) |
| 84 | + |
| 85 | + if precision + recall == 0: |
| 86 | + f_score = 0 |
| 87 | + else: |
| 88 | + f_score = 2 * (precision * recall) / (precision + recall) |
| 89 | + return {"precision": precision, "recall": recall, "f-score": f_score} |
| 90 | + |
| 91 | + |
| 92 | +def test_model(input_data: str = TEST_REVIEW): |
| 93 | + # Load saved trained model |
| 94 | + loaded_model = spacy.load("model_artifacts") |
| 95 | + parsed_text = loaded_model(input_data) |
| 96 | + prediction = ( |
| 97 | + "Positive" if parsed_text.cats["pos"] > parsed_text.cats["neg"] else "Negative" |
| 98 | + ) |
| 99 | + score = ( |
| 100 | + parsed_text.cats["pos"] if prediction == "Positive" else parsed_text.cats["neg"] |
| 101 | + ) |
| 102 | + print( |
| 103 | + f"Review text: {input_data}\nPredicted sentiment: {prediction}\tScore: {score}" |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +def load_training_data( |
| 108 | + data_directory: str = "aclImdb/train", split: float = 0.8, limit: int = 0 |
| 109 | +) -> list: |
| 110 | + # Load from files |
| 111 | + reviews = [] |
| 112 | + for label in ["pos", "neg"]: |
| 113 | + labeled_directory = f"{data_directory}/{label}" |
| 114 | + for review in os.listdir(labeled_directory): |
| 115 | + if review.endswith(".txt"): |
| 116 | + with open(f"{labeled_directory}/{review}") as f: |
| 117 | + text = f.read() |
| 118 | + text = text.replace("<br />", "\n\n") |
| 119 | + if text.strip(): |
| 120 | + spacy_label = { |
| 121 | + "cats": {"pos": "pos" == label, "neg": "neg" == label} |
| 122 | + } |
| 123 | + reviews.append((text, spacy_label)) |
| 124 | + # Shuffle |
| 125 | + random.shuffle(reviews) |
| 126 | + if limit: |
| 127 | + reviews = reviews[:limit] |
| 128 | + split = int(len(reviews) * split) |
| 129 | + return reviews[:split], reviews[split:] |
| 130 | + |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + train, test = load_training_data(limit=2500) |
| 134 | + print("Training model") |
| 135 | + train_model(train, test) |
| 136 | + print("Testing model") |
| 137 | + test_model() |
0 commit comments