Commit e7b1159
torch_xla: Use new macros for throwing exceptions (part 2). (#9594)
Follow-up to #9588 and #9580. Target: the `torch_xla/csrc` directory.

In summary, this PR:

- Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (which throw an exception without source-location information for the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`.
- Corresponds to the fine-grained set of PRs that came from breaking down PR #9580.
- Focuses on the `torch_xla/csrc` directory, replacing every use of those now-deprecated functions with the newly introduced macros.

_Note: since many files in `torch_xla/csrc` needed updating, the changes were split into multiple parts._
1 parent f5a2218 commit e7b1159
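For context: the macros themselves are not shown in this diff. As a rough mental model, they follow the familiar `absl::StatusOr` assign-or-return pattern, except that they throw and capture the throw-site. The sketch below is illustrative only; the helper `ThrowIfErrorImpl`, the temporary naming, and the exact expansion are assumptions, not the actual torch_xla definitions.

```cpp
// Illustrative sketch only -- NOT the actual torch_xla macro definitions.
#include <sstream>
#include <stdexcept>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical helper: convert a failed absl::Status into an exception that
// carries the captured file/line of the throw-site.
inline void ThrowIfErrorImpl(const absl::Status& status, const char* file,
                             int line) {
  if (!status.ok()) {
    std::ostringstream oss;
    oss << file << ":" << line << ": " << status.ToString();
    throw std::runtime_error(oss.str());
  }
}

#define XLA_THROW_IF_ERROR(expr) ThrowIfErrorImpl((expr), __FILE__, __LINE__)

// Token pasting so each expansion gets its own temporary name. (A real
// implementation would likely use __COUNTER__; this is simplified.)
#define XLA_CONCAT_(a, b) a##b
#define XLA_CONCAT(a, b) XLA_CONCAT_(a, b)

// Evaluate `rexpr` (an absl::StatusOr<T>); throw with the call-site location
// on error, otherwise bind the value to `lhs`, which may be a declaration or
// an existing lvalue. Note: this expands to multiple statements.
#define XLA_ASSIGN_OR_THROW(lhs, rexpr)                          \
  auto XLA_CONCAT(statusor_, __LINE__) = (rexpr);                \
  XLA_THROW_IF_ERROR(XLA_CONCAT(statusor_, __LINE__).status());  \
  lhs = std::move(XLA_CONCAT(statusor_, __LINE__)).value()
```

Because `XLA_ASSIGN_OR_THROW` expands to several statements rather than a single expression, most of the mechanical changes in the hunks below follow directly from it: nested calls are hoisted into named locals, `return` expressions become bind-then-return, and `case` bodies gain braces.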

File tree

- torch_xla/csrc/aten_autograd_ops.cpp
- torch_xla/csrc/ir_dump_util.cpp
- torch_xla/csrc/xla_graph_executor.cpp
- torch_xla/csrc/xla_lower_util.cpp

4 files changed: +132 −115 lines changed

torch_xla/csrc/aten_autograd_ops.cpp

Lines changed: 25 additions & 22 deletions
```diff
@@ -34,8 +34,8 @@ torch::Tensor EinsumAutogradFunction::forward(
   }
   ctx->save_for_backward(vars);
-  std::vector<XLATensorPtr> xla_tensors =
-      GetValueOrThrow(bridge::GetXlaTensors(tensors));
+  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_tensors,
+                      bridge::GetXlaTensors(tensors));
   XLATensorPtr output = tensor_methods::einsum(eq_str, xla_tensors);
   return bridge::AtenFromXlaTensor(output);
 }
@@ -45,13 +45,12 @@ torch::autograd::variable_list EinsumAutogradFunction::backward(
     torch::autograd::variable_list grad_output) {
   std::string equation = ctx->saved_data["equation"].toString()->string();
   torch::autograd::variable_list tensors = ctx->get_saved_variables();
-  std::vector<XLATensorPtr> xla_tensors =
-      GetValueOrThrow(bridge::GetXlaTensors(tensors));
-
+  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_tensors,
+                      bridge::GetXlaTensors(tensors));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
+                      bridge::GetXlaTensor(grad_output[0]));
   std::tuple<XLATensorPtr, XLATensorPtr> outputs =
-      tensor_methods::einsum_backward(
-          GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), xla_tensors,
-          equation);
+      tensor_methods::einsum_backward(xla_grad_output_0, xla_tensors, equation);
 
   // For both einsum and max pool, we use "undef" as a placeholder for the
   // non-tensor grad inputs, in this case the equation string.
@@ -193,10 +192,10 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
     return std::get<0>(results);
   }
   ctx->save_for_backward({self});
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   auto outputs = tensor_methods::max_pool_nd(
-      GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode);
+      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
+      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
   return bridge::AtenFromXlaTensor(std::get<0>(outputs));
 }
 
@@ -221,11 +220,13 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
                                            padding, dilation,
                                            ceil_mode, indices);
   }
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
+                      bridge::GetXlaTensor(grad_output[0]));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])),
-      GetValueOrThrow(bridge::GetXlaTensor(self)),
-      /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
+      xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
+      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
+      XlaHelpers::I64List(padding), ceil_mode));
 
   torch::Tensor undef;
   torch::autograd::variable_list grad_inputs = {grad, undef, undef,
@@ -238,22 +239,24 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
                                  torch::IntArrayRef stride,
                                  torch::IntArrayRef padding,
                                  torch::IntArrayRef dilation, bool ceil_mode) {
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   auto outputs = tensor_methods::max_pool_nd(
-      GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode);
+      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
+      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
   return bridge::AtenFromXlaTensor(std::get<0>(outputs));
 }
 
 torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
                                   torch::IntArrayRef kernel_size,
                                   torch::IntArrayRef stride,
                                   torch::IntArrayRef padding, bool ceil_mode) {
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+                      bridge::GetXlaTensor(grad_output));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      GetValueOrThrow(bridge::GetXlaTensor(grad_output)),
-      GetValueOrThrow(bridge::GetXlaTensor(self)),
-      /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
+      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
+      XlaHelpers::I64List(padding), ceil_mode));
   return grad;
 }
```
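The hunks above all apply the same mechanical transformation: `GetValueOrThrow(...)` was an expression and could be nested inside an argument list, while `XLA_ASSIGN_OR_THROW(...)` is a statement, so each converted value is first bound to a named local and then passed along. A minimal self-contained illustration under the sketch macro above (`Produce` and `Consume` are hypothetical names, not part of the PR):

```cpp
#include <cstddef>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical producer/consumer pair, just to show the shape of the change.
absl::StatusOr<std::string> Produce(int x) {
  if (x < 0) return absl::InvalidArgumentError("negative size");
  return std::string(static_cast<std::size_t>(x), '*');
}

std::size_t Consume(const std::string& s) { return s.size(); }

std::size_t Example(int x) {
  // Old style: nested expression.
  //   return Consume(GetValueOrThrow(Produce(x)));
  // New style: hoist the value into a named local via the macro.
  XLA_ASSIGN_OR_THROW(std::string s, Produce(x));
  return Consume(s);
}
```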

torch_xla/csrc/ir_dump_util.cpp

Lines changed: 15 additions & 9 deletions
```diff
@@ -264,14 +264,15 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
   // Annotate HLO sharding selectively in the compuation.
   // This is no-op if an instruction doesn't have any sharding annotation.
   auto is_sharded = ShardingUtil::SetHloSharding(&lowering_ctx);
-  xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla());
 
   static bool dump_post_optimizations =
       runtime::sys_util::GetEnvBool("XLA_DUMP_POST_OPTIMIZATIONS", false);
   if (dump_post_optimizations) {
+    XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                        computation.GetProgramShape());
     xla::Shape shape = MakeShapeWithDeviceLayout(
-        GetValueOrThrow(computation.GetProgramShape()).result(),
-        static_cast<XlaDeviceType>(device.type()));
+        program_shape.result(), static_cast<XlaDeviceType>(device.type()));
     std::vector<runtime::ComputationClient::CompileInstance> instances;
     instances.push_back(
         {std::move(computation), device.toString(),
@@ -286,12 +287,17 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
   }
 
   switch (mode) {
-    case EmitMode::kHloReadable:
-      return GetValueOrThrow(runtime::util::GetComputationHloText(computation));
-    case EmitMode::kHloProto:
-      return GetValueOrThrow(
-          runtime::util::GetDeterministicSerializedModuleProto(
-              computation.proto()));
+    case EmitMode::kHloReadable: {
+      XLA_ASSIGN_OR_THROW(std::string hlo_text,
+                          runtime::util::GetComputationHloText(computation));
+      return hlo_text;
+    }
+    case EmitMode::kHloProto: {
+      XLA_ASSIGN_OR_THROW(std::string serialized_proto,
+                          runtime::util::GetDeterministicSerializedModuleProto(
+                              computation.proto()));
+      return serialized_proto;
+    }
     case EmitMode::kStableHloReadable:
       return hloToStablehlo(&computation.proto(),
                             /* emit_bytecode = */ false);
```
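Worth noting in the `switch` hunk above: each rewritten `case` body gained its own braces because `XLA_ASSIGN_OR_THROW` declares a local variable, and in C++ jumping to a later label past a declaration with a non-trivial initializer is ill-formed unless the declaration sits in its own block scope. A minimal sketch of the rule (illustrative, not from the PR):

```cpp
#include <string>

// A declaration with a non-trivial initializer directly under a `case`
// label is ill-formed, because another label could jump past its init.
int Demo(int mode) {
  switch (mode) {
    // case 0:
    //   std::string s = "oops";  // error: `default:` would bypass the init
    //   return static_cast<int>(s.size());
    case 0: {
      std::string s = "ok";  // fine: the declaration is scoped to this block
      return static_cast<int>(s.size());
    }
    default:
      return -1;
  }
}
```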

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 32 additions & 31 deletions
```diff
@@ -497,8 +497,8 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
 
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(ReleaseGilAndTransferData(tensors_data));
+  XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> literals,
+                      ReleaseGilAndTransferData(tensors_data));
 
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
@@ -846,12 +846,12 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
       // OutputHandler creates sharded data for sharded
       // tensor results. Both sharded and unsharded results should be
       // "Assign"ed to the corresponding data placeholders.
-      std::vector<runtime::ComputationClient::DataPtr> outputs =
-          GetValueOrThrow(
-              runtime::GetComputationClientOrDie()->ExecuteReplicated(
-                  *async->cached_computation->computation,
-                  UnwrapXlaData(async->parameters_data), devices,
-                  execute_options));
+      XLA_ASSIGN_OR_THROW(
+          std::vector<runtime::ComputationClient::DataPtr> outputs,
+          runtime::GetComputationClientOrDie()->ExecuteReplicated(
+              *async->cached_computation->computation,
+              UnwrapXlaData(async->parameters_data), devices,
+              execute_options));
       results = WrapXlaData(outputs);
       TF_VLOG(3) << "Executing Dynamo IR sharded graph hash "
                  << torch::lazy::HashToString(hash) << " on devices "
@@ -913,8 +913,8 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
 
   // Get program output shape.
   // TODO(lsy323): Get shape info from MLIR Module.
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());
   xla::Shape shape = MakeShapeWithDeviceLayout(
       program_shape.result(), static_cast<XlaDeviceType>(device.type()));
 
@@ -946,8 +946,9 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::ExecuteStablehlo(
     }
   }
 
-  std::vector<runtime::ComputationClient::DataPtr> result_data =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation(
+  XLA_ASSIGN_OR_THROW(
+      std::vector<runtime::ComputationClient::DataPtr> result_data,
+      runtime::GetComputationClientOrDie()->ExecuteComputation(
           *computations[0], UnwrapXlaData(arguments), device.toString()));
 
   return WrapXlaData(result_data);
@@ -1123,12 +1124,12 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
       // OutputHandler creates sharded data for sharded
       // tensor results. Both sharded and unsharded results should be
       // "Assign"ed to the corresponding data placeholders.
-      std::vector<runtime::ComputationClient::DataPtr> outputs =
-          GetValueOrThrow(
-              runtime::GetComputationClientOrDie()->ExecuteReplicated(
-                  *async->cached_computation->computation,
-                  UnwrapXlaData(async->parameters_data), devices,
-                  execute_options));
+      XLA_ASSIGN_OR_THROW(
+          std::vector<runtime::ComputationClient::DataPtr> outputs,
+          runtime::GetComputationClientOrDie()->ExecuteReplicated(
+              *async->cached_computation->computation,
+              UnwrapXlaData(async->parameters_data), devices,
+              execute_options));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteReplicated", 1);
       TF_VLOG(3) << "Executing IR graph hash "
@@ -1139,14 +1140,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
       TF_VLOG(3) << "Executing IR graph hash "
                  << torch::lazy::HashToString(hash) << " on device "
                 << async->device << " ...";
-      std::vector<runtime::ComputationClient::DataPtr> outputs =
-          GetValueOrThrow(
-              runtime::GetComputationClientOrDie()->ExecuteComputation(
-                  *async->cached_computation->computation,
-                  UnwrapXlaData(async->parameters_data),
-                  async->device.toString(),
-                  {/*explode_tuple=*/true,
-                   /*eager_mode=*/use_eager_mode}));
+      XLA_ASSIGN_OR_THROW(
+          std::vector<runtime::ComputationClient::DataPtr> outputs,
+          runtime::GetComputationClientOrDie()->ExecuteComputation(
+              *async->cached_computation->computation,
+              UnwrapXlaData(async->parameters_data), async->device.toString(),
+              {/*explode_tuple=*/true,
+               /*eager_mode=*/use_eager_mode}));
       results = WrapXlaData(outputs);
       TORCH_LAZY_COUNTER("ExecuteComputation", 1);
       TF_VLOG(3) << "Executing IR graph hash "
@@ -1416,9 +1416,9 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
 
   SetBufferDonors(&lowering_ctx, buffer_donor_indices);
 
-  xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla());
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla());
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());
 
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
@@ -1435,10 +1435,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       param_shardings = XlaHelpers::ExtractInputShardings(computation);
     }
 
-    computation = GetValueOrThrow(
+    XLA_ASSIGN_OR_THROW(
+        computation,
         XlaHelpers::WrapXlaComputation(computation, program_shape.parameters(),
                                        param_shardings, buffer_donor_indices));
-    program_shape = GetValueOrThrow(computation.GetProgramShape());
+    XLA_ASSIGN_OR_THROW(program_shape, computation.GetProgramShape());
   }
   xla::Shape shape = MakeShapeWithDeviceLayout(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
```
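One detail visible in the `Compile()` hunks above: the first argument of `XLA_ASSIGN_OR_THROW` may be either a fresh declaration or an already-declared variable, which is what allows `computation` and `program_shape` to be re-assigned in place. Under the sketch macro from the top of this page (with a hypothetical `Answer()`), both spellings expand the same way:

```cpp
#include "absl/status/statusor.h"

absl::StatusOr<int> Answer() { return 42; }  // hypothetical fallible call

void Demo() {
  XLA_ASSIGN_OR_THROW(int first, Answer());  // declares `first` and assigns
  XLA_ASSIGN_OR_THROW(first, Answer());      // re-assigns the existing local
  (void)first;
}
```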

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 60 additions & 53 deletions
```diff
@@ -62,8 +62,8 @@ ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) {
 
 xla::XlaOp GetPromotedMask(xla::XlaOp mask, const xla::Shape& input_shape) {
   const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask);
-  xla::Shape promoted_mask_shape =
-      GetValueOrThrow(XlaHelpers::GetPromotedShape(mask_shape, input_shape));
+  XLA_ASSIGN_OR_THROW(xla::Shape promoted_mask_shape,
+                      XlaHelpers::GetPromotedShape(mask_shape, input_shape));
   return XlaHelpers::ImplicitBroadcast(mask, mask_shape, promoted_mask_shape);
 }
 
@@ -150,7 +150,9 @@ xla::XlaComputation MakeScatterComputation(
   if (combiner != nullptr) {
     result = combiner(p0, result);
   }
-  return GetValueOrThrow(cb.Build(result));
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation scatter_computation,
+                      cb.Build(result));
+  return scatter_computation;
 }
 
 xla::XlaOp CreateIndexAlongDim(
@@ -543,8 +545,8 @@ std::vector<xla::XlaOp> CreateBroadcastTensors(
   for (const xla::XlaOp operand : operands) {
     const xla::Shape& operand_shape = ShapeHelper::ShapeOfXlaOp(operand);
     operand_shapes.push_back(operand_shape);
-    result_shape = GetValueOrThrow(
-        XlaHelpers::GetPromotedShape(result_shape, operand_shape));
+    XLA_ASSIGN_OR_THROW(result_shape, XlaHelpers::GetPromotedShape(
+                                          result_shape, operand_shape));
   }
   std::vector<xla::XlaOp> result;
   for (size_t i = 0; i < operands.size(); ++i) {
@@ -1366,54 +1368,59 @@ std::vector<xla::XlaOp> BuildBoxSelectionLoop(int64_t num_boxes,
   // 3. The actual IoU threshold matrix.
   init_values[2] = iou_threshold_mask;
 
-  return GetValueOrThrow(xla::WhileLoopHelper(
-      [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
-        xla::XlaOp box_index = values[0];
-        // Check: current loop counter is within bounds, i.e. has a
-        // corresponding box.
-        return xla::Lt(box_index,
-                       xla::ConstantR0<IndexType>(builder, num_boxes));
-      },
-      [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
-        const xla::XlaOp ONE = xla::One(builder, XLAIndexType);
-        const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType);
-
-        xla::XlaOp box_index = values[0];
-        xla::XlaOp state = values[1];
-        xla::XlaOp iou_threshold_mask = values[2];
-
-        // Retrieve the IoU mask row corresponding to this box.
-        xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice(
-            iou_threshold_mask, {box_index, ZERO}, {1, num_boxes});
-
-        // Update the current state with the IoU mask.
-        // Basically, sets to false every box X whose IoU with the current box
-        // is less-than or equal than the given threshold.
-        xla::XlaOp updated_state = xla::And(
-            state,
-            // Update the mask so that if we select this box
-            // (i.e. state[box] == true), we don't de-select it.
-            xla::DynamicUpdateSlice(
-                // Before that, we need to pre-process the mask.
-                // 1. Negate the mask: if this box is selected, we only want
-                //    those that have a low intersection ratio.
-                // 2. Reshape it to: [num_boxes].
-                xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}),
-                xla::ConstantR1<bool>(builder, {true}), {box_index}));
-
-        // Flag: should this box (loop counter) be included in the output?
-        xla::XlaOp should_include = xla::DynamicSlice(state, {box_index}, {1});
-        // Pick the new values of state, depending on whether we should include
-        // this box or not.
-        xla::XlaOp new_state =
-            xla::Select(xla::BroadcastInDim(should_include, {num_boxes}, {0}),
-                        updated_state, state);
-
-        xla::XlaOp next_box_index = box_index + ONE;
-        return std::vector<xla::XlaOp>{next_box_index, new_state,
-                                       iou_threshold_mask};
-      },
-      init_values, "BoxSelectionLoop", builder));
+  XLA_ASSIGN_OR_THROW(
+      std::vector<xla::XlaOp> result,
+      xla::WhileLoopHelper(
+          [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
+            xla::XlaOp box_index = values[0];
+            // Check: current loop counter is within bounds, i.e. has a
+            // corresponding box.
+            return xla::Lt(box_index,
+                           xla::ConstantR0<IndexType>(builder, num_boxes));
+          },
+          [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
+            const xla::XlaOp ONE = xla::One(builder, XLAIndexType);
+            const xla::XlaOp ZERO = xla::Zero(builder, XLAIndexType);
+
+            xla::XlaOp box_index = values[0];
+            xla::XlaOp state = values[1];
+            xla::XlaOp iou_threshold_mask = values[2];
+
+            // Retrieve the IoU mask row corresponding to this box.
+            xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice(
+                iou_threshold_mask, {box_index, ZERO}, {1, num_boxes});
+
+            // Update the current state with the IoU mask.
+            // Basically, sets to false every box X whose IoU with the current
+            // box is less-than or equal than the given threshold.
+            xla::XlaOp updated_state = xla::And(
+                state,
+                // Update the mask so that if we select this box
+                // (i.e. state[box] == true), we don't de-select it.
+                xla::DynamicUpdateSlice(
+                    // Before that, we need to pre-process the mask.
+                    // 1. Negate the mask: if this box is selected, we only
+                    //    want those that have a low intersection ratio.
+                    // 2. Reshape it to: [num_boxes].
+                    xla::Reshape(xla::Not(box_iou_threshold_mask), {num_boxes}),
+                    xla::ConstantR1<bool>(builder, {true}), {box_index}));
+
+            // Flag: should this box (loop counter) be included in the output?
+            xla::XlaOp should_include =
+                xla::DynamicSlice(state, {box_index}, {1});
+            // Pick the new values of state, depending on whether we should
+            // include this box or not.
+            xla::XlaOp new_state = xla::Select(
+                xla::BroadcastInDim(should_include, {num_boxes}, {0}),
+                updated_state, state);
+
+            xla::XlaOp next_box_index = box_index + ONE;
+            return std::vector<xla::XlaOp>{next_box_index, new_state,
+                                           iou_threshold_mask};
+          },
+          init_values, "BoxSelectionLoop", builder));
+  return result;
 }
 
 xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
```
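The `MakeScatterComputation` and `BuildBoxSelectionLoop` hunks show the other recurring consequence of the macro being a statement rather than an expression: a `return GetValueOrThrow(...)` becomes bind-then-return. A compact sketch under the same assumed macro (`Compute` is a hypothetical name):

```cpp
#include "absl/status/statusor.h"

absl::StatusOr<double> Compute();  // hypothetical fallible step

double BuildResult() {
  // Old: `return GetValueOrThrow(Compute());` -- a single expression.
  // New: the macro expands to statements, so bind first, then return.
  XLA_ASSIGN_OR_THROW(double result, Compute());
  return result;
}
```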
