Skip to content

Commit 5974826

Browse files
committed
added somemore similarity measurement method #1333
1 parent 4e2c4a2 commit 5974826

File tree

1 file changed

+39
-19
lines changed

1 file changed

+39
-19
lines changed

NLP/Sentence_Similarity.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import numpy as np
23
from sentence_transformers import SentenceTransformer, util
34

45
MODEL_NAME = 'all-MiniLM-L6-v2'
@@ -21,23 +22,34 @@ def load_or_download_model():
2122
print(f"Model saved to {model_path}")
2223
return model
2324

24-
def find_similar_sentences(query, file_path, top_n=5):
25-
# Load the pre-trained model
26-
model = load_or_download_model()
25+
def cosine_similarity(query_embedding, sentence_embeddings):
26+
return util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
2727

28-
# Load and encode the sentences from the file
29-
sentences = load_file(file_path)
30-
sentence_embeddings = model.encode(sentences)
28+
def euclidean_distance(query_embedding, sentence_embeddings):
29+
return -np.linalg.norm(query_embedding - sentence_embeddings, axis=1)
3130

32-
# Encode the query
33-
query_embedding = model.encode([query])
31+
def manhattan_distance(query_embedding, sentence_embeddings):
32+
return -np.sum(np.abs(query_embedding - sentence_embeddings), axis=1)
3433

35-
# Calculate cosine similarities
36-
cosine_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]
34+
def dot_product(query_embedding, sentence_embeddings):
35+
return np.dot(sentence_embeddings, query_embedding.T).flatten()
3736

38-
# Get top N results
39-
top_results = sorted(zip(sentences, cosine_scores), key=lambda x: x[1], reverse=True)[:top_n]
37+
similarity_functions = {
38+
'1': ('Cosine Similarity', cosine_similarity),
39+
'2': ('Euclidean Distance', euclidean_distance),
40+
'3': ('Manhattan Distance', manhattan_distance),
41+
'4': ('Dot Product', dot_product)
42+
}
4043

44+
def find_similar_sentences(query, file_path, similarity_func, top_n=5):
45+
model = load_or_download_model()
46+
sentences = load_file(file_path)
47+
sentence_embeddings = model.encode(sentences)
48+
query_embedding = model.encode([query])
49+
50+
similarity_scores = similarity_func(query_embedding, sentence_embeddings)
51+
top_results = sorted(zip(sentences, similarity_scores), key=lambda x: x[1], reverse=True)[:top_n]
52+
4153
return top_results
4254

4355
def validate_file_path(file_path):
@@ -48,26 +60,34 @@ def validate_file_path(file_path):
4860
return file_path
4961

5062
def main():
51-
print("Welcome to the Sentence Similarity Search Tool!")
63+
print("Welcome to the Enhanced Sentence Similarity Search Tool!")
5264

53-
# Get user input for query
5465
query = input("Enter your query: ")
5566

56-
# Get user input for file path and validate it
5767
while True:
5868
file_path = input("Enter the path to your text file without extension: ")
5969
try:
6070
file_path = validate_file_path(file_path)
6171
break
6272
except FileNotFoundError as e:
6373
print(f"Error: {str(e)} Please try again.")
74+
75+
print("\nChoose a similarity measurement method:")
76+
for key, (name, _) in similarity_functions.items():
77+
print(f"{key}. {name}")
78+
79+
while True:
80+
choice = input("Enter the number of your choice: ")
81+
if choice in similarity_functions:
82+
similarity_name, similarity_func = similarity_functions[choice]
83+
break
84+
print("Invalid choice. Please try again.")
6485

6586
try:
66-
results = find_similar_sentences(query, file_path)
67-
68-
print(f"\nTop 5 similar sentences for query: '{query}'\n")
87+
results = find_similar_sentences(query, file_path, similarity_func)
88+
print(f"\nTop 5 similar sentences for query: '{query}' using {similarity_name}\n")
6989
for sentence, score in results:
70-
print(f"Similarity: {score:.4f}")
90+
print(f"Similarity Score: {score:.4f}")
7191
print(f"Sentence: {sentence}\n")
7292
except Exception as e:
7393
print(f"An error occurred: {str(e)}")

0 commit comments

Comments
 (0)