@@ -55,6 +55,51 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
5555
5656AnalysisKey IR2VecVocabAnalysis::Key;
5757
58+ namespace llvm ::json {
59+ inline bool fromJSON (const llvm::json::Value &E, Embedding &Out,
60+ llvm::json::Path P) {
61+ std::vector<double > TempOut;
62+ if (!llvm::json::fromJSON (E, TempOut, P))
63+ return false ;
64+ Out = Embedding (std::move (TempOut));
65+ return true ;
66+ }
67+ } // namespace llvm::json
68+
69+ // ==----------------------------------------------------------------------===//
70+ // Embedding
71+ // ===----------------------------------------------------------------------===//
72+
73+ Embedding &Embedding::operator +=(const Embedding &RHS) {
74+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
75+ std::transform (this ->begin (), this ->end (), RHS.begin (), this ->begin (),
76+ std::plus<double >());
77+ return *this ;
78+ }
79+
80+ Embedding &Embedding::operator -=(const Embedding &RHS) {
81+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
82+ std::transform (this ->begin (), this ->end (), RHS.begin (), this ->begin (),
83+ std::minus<double >());
84+ return *this ;
85+ }
86+
87+ Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
88+ assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
89+ for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
90+ (*this )[Itr] += Src[Itr] * Factor;
91+ return *this ;
92+ }
93+
94+ bool Embedding::approximatelyEquals (const Embedding &RHS,
95+ double Tolerance) const {
96+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
97+ for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
98+ if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance)
99+ return false ;
100+ return true ;
101+ }
102+
58103// ==----------------------------------------------------------------------===//
59104// Embedder and its subclasses
60105// ===----------------------------------------------------------------------===//
@@ -73,20 +118,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
73118 return make_error<StringError>(" Unknown IR2VecKind" , errc::invalid_argument);
74119}
75120
76- void Embedder::addVectors (Embedding &Dst, const Embedding &Src) {
77- assert (Dst.size () == Src.size () && " Vectors must have the same dimension" );
78- std::transform (Dst.begin (), Dst.end (), Src.begin (), Dst.begin (),
79- std::plus<double >());
80- }
81-
82- void Embedder::addScaledVector (Embedding &Dst, const Embedding &Src,
83- float Factor) {
84- assert (Dst.size () == Src.size () && " Vectors must have the same dimension" );
85- for (size_t i = 0 ; i < Dst.size (); ++i) {
86- Dst[i] += Src[i] * Factor;
87- }
88- }
89-
90121// FIXME: Currently lookups are string based. Use numeric Keys
91122// for efficiency
92123Embedding Embedder::lookupVocab (const std::string &Key) const {
@@ -164,20 +195,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
164195 Embedding InstVector (Dimension, 0 );
165196
166197 const auto OpcVec = lookupVocab (I.getOpcodeName ());
167- addScaledVector ( InstVector, OpcVec, OpcWeight);
198+ InstVector. scaleAndAdd ( OpcVec, OpcWeight);
168199
169200 // FIXME: Currently lookups are string based. Use numeric Keys
170201 // for efficiency.
171202 const auto Type = I.getType ();
172203 const auto TypeVec = getTypeEmbedding (Type);
173- addScaledVector ( InstVector, TypeVec, TypeWeight);
204+ InstVector. scaleAndAdd ( TypeVec, TypeWeight);
174205
175206 for (const auto &Op : I.operands ()) {
176207 const auto OperandVec = getOperandEmbedding (Op.get ());
177- addScaledVector ( InstVector, OperandVec, ArgWeight);
208+ InstVector. scaleAndAdd ( OperandVec, ArgWeight);
178209 }
179210 InstVecMap[&I] = InstVector;
180- addVectors ( BBVector, InstVector) ;
211+ BBVector += InstVector;
181212 }
182213 BBVecMap[&BB] = BBVector;
183214}
@@ -187,7 +218,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
187218 return ;
188219 for (const auto &BB : F) {
189220 computeEmbeddings (BB);
190- addVectors ( FuncVector, BBVecMap[&BB]) ;
221+ FuncVector += BBVecMap[&BB];
191222 }
192223}
193224
0 commit comments