@@ -14,15 +14,17 @@ limitations under the License. */
1414#define GLOG_NO_ABBREVIATED_SEVERITIES
1515#define GOOGLE_GLOG_DLL_DECL
1616
17+ #include " paddle/fluid/framework/operator.h"
1718#include < gflags/gflags.h>
1819#include < glog/logging.h>
19-
2020#include < algorithm>
21-
21+ #include < sstream>
22+ #include < string>
23+ #include < vector>
2224#include " paddle/fluid/framework/data_transform.h"
2325#include " paddle/fluid/framework/executor.h"
2426#include " paddle/fluid/framework/lod_tensor.h"
25- #include " paddle/fluid/framework/operator .h"
27+ #include " paddle/fluid/framework/op_proto_maker .h"
2628#include " paddle/fluid/framework/shape_inference.h"
2729#include " paddle/fluid/framework/var_type.h"
2830#include " paddle/fluid/platform/profiler.h"
@@ -140,19 +142,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
140142}
141143
142144void OperatorBase::Run (const Scope& scope, const platform::Place& place) {
143- VLOG (4 ) << place << " " << DebugStringEx (&scope);
144- if (platform::is_gpu_place (place)) {
145+ try {
146+ if (VLOG_IS_ON (4 )) {
147+ VLOG (4 ) << place << " " << DebugStringEx (&scope);
148+ }
149+ if (platform::is_gpu_place (place)) {
145150#ifndef PADDLE_WITH_CUDA
146- PADDLE_THROW (" Cannot run operator on place %s" , place);
151+ PADDLE_THROW (" Cannot run operator on place %s" , place);
147152#else
148- auto dev_id = boost::get<platform::CUDAPlace>(place).device ;
149- platform::SetDeviceId (dev_id);
153+ auto dev_id = boost::get<platform::CUDAPlace>(place).device ;
154+ platform::SetDeviceId (dev_id);
150155#endif
156+ }
157+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
158+ platform::RecordEvent record_event (Type (), pool.Get (place));
159+ RunImpl (scope, place);
160+ if (VLOG_IS_ON (3 )) {
161+ VLOG (3 ) << place << " " << DebugStringEx (&scope);
162+ }
163+ } catch (platform::EnforceNotMet exception) {
164+ if (Attrs ().count (" sub_block" ) != 0 ) {
165+ throw exception;
166+ }
167+
168+ auto & callstack = Attr<std::vector<std::string>>(
169+ OpProtoAndCheckerMaker::OpCreationCallstackAttrName ());
170+
171+ if (callstack.empty ()) {
172+ throw exception;
173+ }
174+ std::ostringstream sout;
175+ sout << " Invoke operator " << Type () << " error.\n " ;
176+ sout << " Python Callstacks: \n " ;
177+ for (auto & line : callstack) {
178+ sout << line;
179+ }
180+ sout << " C++ Callstacks: \n " ;
181+ sout << exception.err_str_ ;
182+ exception.err_str_ = sout.str ();
183+ throw exception;
184+ } catch (...) {
185+ std::rethrow_exception (std::current_exception ());
151186 }
152- platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
153- platform::RecordEvent record_event (Type (), pool.Get (place));
154- RunImpl (scope, place);
155- VLOG (3 ) << place << " " << DebugStringEx (&scope);
156187}
157188
158189bool OperatorBase::HasInputs (const std::string& name) const {
@@ -180,7 +211,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
180211}
181212
182213bool OperatorBase::HasOutputs (const std::string& name) const {
183- if (outputs_.find (name ) != outputs_.end ( )) {
214+ if (outputs_.end ( ) != outputs_.find (name )) {
184215 return true ;
185216 } else {
186217 return false ;
0 commit comments