import os
import numpy as np
from sentence_transformers import SentenceTransformer, util

# Embedding model to download/load and the local folder it is cached in.
MODEL_NAME = 'all-MiniLM-L6-v2'
MODEL_FOLDER = 'model'

def load_file(file_path):
    """Read a text file and return its non-empty lines, stripped of whitespace."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file if line.strip()]

def load_or_download_model():
    """Load the model from the local cache if present, otherwise download and save it."""
    model_path = os.path.join(MODEL_FOLDER, MODEL_NAME)
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        return SentenceTransformer(model_path)
    else:
        print(f"Downloading model {MODEL_NAME}")
        model = SentenceTransformer(MODEL_NAME)
        os.makedirs(MODEL_FOLDER, exist_ok=True)
        model.save(model_path)
        print(f"Model saved to {model_path}")
        return model

def cosine_similarity(query_embedding, sentence_embeddings):
    return util.pytorch_cos_sim(query_embedding, sentence_embeddings)[0]

def euclidean_distance(query_embedding, sentence_embeddings):
    # Negated so that a larger value always means "more similar".
    return -np.linalg.norm(query_embedding - sentence_embeddings, axis=1)

def manhattan_distance(query_embedding, sentence_embeddings):
    # Negated so that a larger value always means "more similar".
    return -np.sum(np.abs(query_embedding - sentence_embeddings), axis=1)

def dot_product(query_embedding, sentence_embeddings):
    return np.dot(sentence_embeddings, query_embedding.T).flatten()

# Menu of similarity measures: option number -> (display name, scoring function).
similarity_functions = {
    '1': ('Cosine Similarity', cosine_similarity),
    '2': ('Euclidean Distance', euclidean_distance),
    '3': ('Manhattan Distance', manhattan_distance),
    '4': ('Dot Product', dot_product)
}

def find_similar_sentences(query, file_path, similarity_func, top_n=5):
    model = load_or_download_model()
    sentences = load_file(file_path)
    sentence_embeddings = model.encode(sentences)
    query_embedding = model.encode([query])

    # Score every sentence against the query and keep the top_n highest-scoring ones.
    similarity_scores = similarity_func(query_embedding, sentence_embeddings)
    top_results = sorted(zip(sentences, similarity_scores), key=lambda x: x[1], reverse=True)[:top_n]

    return top_results

def validate_file_path(file_path):
    if not file_path.endswith('.txt'):
        file_path += '.txt'
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file '{file_path}' does not exist.")
    return file_path

def main():
    print("Welcome to the Enhanced Sentence Similarity Search Tool!")

    query = input("Enter your query: ")

    # Keep asking until the user supplies a path to an existing .txt file.
    while True:
        file_path = input("Enter the path to your text file without extension: ")
        try:
            file_path = validate_file_path(file_path)
            break
        except FileNotFoundError as e:
            print(f"Error: {e} Please try again.")

    print("\nChoose a similarity measurement method:")
    for key, (name, _) in similarity_functions.items():
        print(f"{key}. {name}")

    # Keep asking until the user picks one of the listed options.
    while True:
        choice = input("Enter the number of your choice: ")
        if choice in similarity_functions:
            similarity_name, similarity_func = similarity_functions[choice]
            break
        print("Invalid choice. Please try again.")

    try:
        results = find_similar_sentences(query, file_path, similarity_func)
        print(f"\nTop 5 similar sentences for query: '{query}' using {similarity_name}\n")
        for sentence, score in results:
            print(f"Similarity Score: {score:.4f}")
            print(f"Sentence: {sentence}\n")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    main()
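
# Example of calling the search programmatically instead of via the interactive
# prompts. This is a minimal sketch, not part of the tool itself: the query text
# and 'sentences.txt' are hypothetical and the file must exist for it to run.
#
#     results = find_similar_sentences(
#         query="How do transformers encode sentences?",
#         file_path="sentences.txt",
#         similarity_func=cosine_similarity,
#         top_n=3,
#     )
#     for sentence, score in results:
#         print(f"{score:.4f}  {sentence}")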