@@ -327,16 +327,19 @@ void RecordTagWer(vector<shared_ptr<Stitching>> stitches) {
327327 for (auto &stitch : stitches) {
328328 if (!stitch->nlpRow .wer_tags .empty ()) {
329329 for (auto wer_tag : stitch->nlpRow .wer_tags ) {
330- wer_results.insert (std::pair<std::string, WerResult>(wer_tag, {0 , 0 , 0 , 0 , 0 }));
330+ int tag_start = wer_tag.find_first_not_of (' #' );
331+ int tag_end = wer_tag.find (' _' );
332+ string wer_tag_id = wer_tag.substr (tag_start, tag_end - tag_start);
333+ wer_results.insert (std::pair<std::string, WerResult>(wer_tag_id, {0 , 0 , 0 , 0 , 0 }));
331334 // rfind(prefix, 0) == 0 is a starts-with test: only the beginning of the comment is matched, since the comment field can contain other text as well
332335 bool del = stitch->comment .rfind (" del" , 0 ) == 0 ;
333336 bool ins = stitch->comment .rfind (" ins" , 0 ) == 0 ;
334337 bool sub = stitch->comment .rfind (" sub" , 0 ) == 0 ;
335- wer_results[wer_tag ].insertions += ins;
336- wer_results[wer_tag ].deletions += del;
337- wer_results[wer_tag ].substitutions += sub;
338+ wer_results[wer_tag_id ].insertions += ins;
339+ wer_results[wer_tag_id ].deletions += del;
340+ wer_results[wer_tag_id ].substitutions += sub;
338341 if (!ins) {
339- wer_results[wer_tag ].numWordsInReference += 1 ;
342+ wer_results[wer_tag_id ].numWordsInReference += 1 ;
340343 }
341344 }
342345 }
@@ -503,7 +506,7 @@ void AddErrorGroup(ErrorGroups &groups, size_t &line, string &ref, string &hyp)
503506 hyp = " " ;
504507}
505508
506- void WriteSbs (spWERA topAlignment, string sbs_filename) {
509+ void WriteSbs (spWERA topAlignment, vector<shared_ptr<Stitching>> stitches, string sbs_filename) {
507510 auto logger = logger::GetOrCreateLogger (" wer" );
508511 logger->set_level (spdlog::level::info);
509512
@@ -514,7 +517,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
514517 triple *tk_pair = new triple ();
515518 string prev_tk_classLabel = " " ;
516519 logger->info (" Side-by-Side alignment info going into {}" , sbs_filename);
517- myfile << fmt::format (" {0:>20}\t {1:20}\t {2}\t {3}" , " ref_token" , " hyp_token" , " IsErr" , " Class" ) << endl;
520+ myfile << fmt::format (" {0:>20}\t {1:20}\t {2}\t {3}\t {4} " , " ref_token" , " hyp_token" , " IsErr" , " Class" , " Wer_Tag_Entities " ) << endl;
518521
519522 // keep track of error groupings
520523 ErrorGroups groups_err;
@@ -525,10 +528,15 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
525528 std::set<std::string> op_set = {" <ins>" , " <del>" , " <sub>" };
526529
527530 size_t offset = 2 ; // line number in output file where first triple starts
528- while (visitor.NextTriple (tk_pair)) {
529- string tk_classLabel = tk_pair->classLabel ;
530- string ref_tk = tk_pair->ref ;
531- string hyp_tk = tk_pair->hyp ;
531+ for (auto p_stitch: stitches) {
532+ string tk_classLabel = p_stitch->classLabel ;
533+ string tk_wer_tags = " " ;
534+ auto wer_tags = p_stitch->nlpRow .wer_tags ;
535+ for (auto wer_tag: wer_tags) {
536+ tk_wer_tags = tk_wer_tags + wer_tag + " |" ;
537+ }
538+ string ref_tk = p_stitch->reftk ;
539+ string hyp_tk = p_stitch->hyptk ;
532540 string tag = " " ;
533541
534542 if (ref_tk == NOOP) {
@@ -560,7 +568,7 @@ void WriteSbs(spWERA topAlignment, string sbs_filename) {
560568 eff_class = tk_classLabel;
561569 }
562570
563- myfile << fmt::format (" {0:>20}\t {1:20}\t {2}\t {3}" , ref_tk, hyp_tk, tag, eff_class) << endl;
571+ myfile << fmt::format (" {0:>20}\t {1:20}\t {2}\t {3}\t {4} " , ref_tk, hyp_tk, tag, eff_class, tk_wer_tags ) << endl;
564572 offset++;
565573 }
566574
0 commit comments