@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
14- #include < gflags/gflags.h>
15- #include < glog/logging.h>
16-
14+ #include " paddle/fluid/framework/operator.h"
1715#include < algorithm>
18-
16+ #include < sstream>
17+ #include < string>
18+ #include < vector>
19+ #include " gflags/gflags.h"
20+ #include " glog/logging.h"
1921#include " paddle/fluid/framework/data_transform.h"
2022#include " paddle/fluid/framework/executor.h"
2123#include " paddle/fluid/framework/lod_tensor.h"
22- #include " paddle/fluid/framework/operator .h"
24+ #include " paddle/fluid/framework/op_proto_maker .h"
2325#include " paddle/fluid/framework/shape_inference.h"
2426#include " paddle/fluid/framework/var_type.h"
2527#include " paddle/fluid/platform/profiler.h"
@@ -137,19 +139,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
137139}
138140
139141void OperatorBase::Run (const Scope& scope, const platform::Place& place) {
140- VLOG (4 ) << place << " " << DebugStringEx (&scope);
141- if (platform::is_gpu_place (place)) {
142+ try {
143+ if (VLOG_IS_ON (4 )) {
144+ VLOG (4 ) << place << " " << DebugStringEx (&scope);
145+ }
146+ if (platform::is_gpu_place (place)) {
142147#ifndef PADDLE_WITH_CUDA
143- PADDLE_THROW (" Cannot run operator on place %s" , place);
148+ PADDLE_THROW (" Cannot run operator on place %s" , place);
144149#else
145- auto dev_id = boost::get<platform::CUDAPlace>(place).device ;
146- platform::SetDeviceId (dev_id);
150+ auto dev_id = boost::get<platform::CUDAPlace>(place).device ;
151+ platform::SetDeviceId (dev_id);
147152#endif
153+ }
154+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
155+ platform::RecordEvent record_event (Type (), pool.Get (place));
156+ RunImpl (scope, place);
157+ if (VLOG_IS_ON (3 )) {
158+ VLOG (3 ) << place << " " << DebugStringEx (&scope);
159+ }
160+ } catch (platform::EnforceNotMet exception) {
161+ if (Attrs ().count (" sub_block" ) != 0 ) {
162+ throw exception;
163+ }
164+
165+ auto & callstack = Attr<std::vector<std::string>>(
166+ OpProtoAndCheckerMaker::OpCreationCallstackAttrName ());
167+
168+ if (callstack.empty ()) {
169+ throw exception;
170+ }
171+ std::ostringstream sout;
172+ sout << " Invoke operator " << Type () << " error.\n " ;
173+ sout << " Python Callstacks: \n " ;
174+ for (auto & line : callstack) {
175+ sout << line;
176+ }
177+ sout << " C++ Callstacks: \n " ;
178+ sout << exception.err_str_ ;
179+ exception.err_str_ = sout.str ();
180+ throw exception;
181+ } catch (...) {
182+ std::rethrow_exception (std::current_exception ());
148183 }
149- platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
150- platform::RecordEvent record_event (Type (), pool.Get (place));
151- RunImpl (scope, place);
152- VLOG (3 ) << place << " " << DebugStringEx (&scope);
153184}
154185
155186bool OperatorBase::HasInputs (const std::string& name) const {
@@ -177,7 +208,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
177208}
178209
179210bool OperatorBase::HasOutputs (const std::string& name) const {
180- if (outputs_.find (name ) != outputs_.end ( )) {
211+ if (outputs_.end ( ) != outputs_.find (name )) {
181212 return true ;
182213 } else {
183214 return false ;
0 commit comments