Skip to content

Commit f9ceded

Browse files
committed
Clean up scalar special casing
1 parent 0ea1aa1 commit f9ceded

File tree

1 file changed

+39
-56
lines changed

1 file changed

+39
-56
lines changed

src/stan_math_backend/Expression_gen.ml

Lines changed: 39 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -200,82 +200,65 @@ and pp_logical_op ppf op lhs rhs =
200200
pp_expr lhs op pp_expr rhs
201201

202202
and pp_unary ppf fm es = pf ppf fm pp_expr (List.hd_exn es)
203-
and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
203+
204+
and pp_binary_op ppf op es =
205+
pf ppf "(%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
204206

205207
and pp_binary_f ppf f es =
206208
pf ppf "%s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
207209

208210
and first es = List.nth_exn es 0
209211
and second es = List.nth_exn es 1
210212

211-
and pp_scalar_binary ppf scalar_fmt generic_fmt es =
212-
pp_binary ppf
213-
( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
214-
else generic_fmt )
215-
es
213+
and pp_scalar_binary ppf op fn es =
214+
if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
215+
else pp_binary_f ppf fn es
216216

217-
and gen_operator_app = function
218-
| Operator.Plus ->
219-
fun ppf es ->
220-
pp_scalar_binary ppf "(%a@ +@ %a)" "stan::math::add(@,%a,@ %a)" es
217+
and gen_operator_app op ppf es =
218+
match op with
219+
| Operator.Plus -> pp_scalar_binary ppf "+" "stan::math::add" es
221220
| PMinus ->
222-
fun ppf es ->
223-
pp_unary ppf
224-
( if is_scalar (List.hd_exn es) then "-%a"
225-
else "stan::math::minus(@,%a)" )
226-
es
227-
| PPlus -> fun ppf es -> pp_unary ppf "%a" es
221+
pp_unary ppf
222+
(if is_scalar (List.hd_exn es) then "-%a" else "stan::math::minus(@,%a)")
223+
es
224+
| PPlus -> pp_unary ppf "%a" es
228225
| Transpose ->
229-
fun ppf es ->
230-
pp_unary ppf
231-
( if is_scalar (List.hd_exn es) then "%a"
232-
else "stan::math::transpose(@,%a)" )
233-
es
234-
| PNot -> fun ppf es -> pp_unary ppf "stan::math::logical_negation(@,%a)" es
235-
| Minus ->
236-
fun ppf es ->
237-
pp_scalar_binary ppf "(%a@ -@ %a)" "stan::math::subtract(@,%a,@ %a)" es
238-
| Times ->
239-
fun ppf es ->
240-
pp_scalar_binary ppf "(%a@ *@ %a)" "stan::math::multiply(@,%a,@ %a)" es
226+
pp_unary ppf
227+
( if is_scalar (List.hd_exn es) then "%a"
228+
else "stan::math::transpose(@,%a)" )
229+
es
230+
| PNot -> pp_unary ppf "stan::math::logical_negation(@,%a)" es
231+
| Minus -> pp_scalar_binary ppf "-" "stan::math::subtract" es
232+
| Times -> pp_scalar_binary ppf "*" "stan::math::multiply" es
241233
| Divide | IntDivide ->
242-
fun ppf es ->
243-
if
244-
is_matrix (second es)
245-
&& (is_matrix (first es) || is_row_vector (first es))
246-
then pp_binary_f ppf "stan::math::mdivide_right" es
247-
else
248-
pp_scalar_binary ppf "(%a@ /@ %a)" "stan::math::divide(@,%a,@ %a)" es
249-
| Modulo -> fun ppf es -> pp_binary_f ppf "stan::math::modulus" es
250-
| LDivide -> fun ppf es -> pp_binary_f ppf "stan::math::mdivide_left" es
234+
if
235+
is_matrix (second es)
236+
&& (is_matrix (first es) || is_row_vector (first es))
237+
then pp_binary_f ppf "stan::math::mdivide_right" es
238+
else pp_scalar_binary ppf "/" "stan::math::divide" es
239+
| Modulo -> pp_binary_f ppf "stan::math::modulus" es
240+
| LDivide -> pp_binary_f ppf "stan::math::mdivide_left" es
251241
| And | Or ->
252242
Common.FatalError.fatal_error_msg
253243
[%message "And/Or should have been converted to an expression"]
254-
| EltTimes ->
255-
fun ppf es ->
256-
pp_scalar_binary ppf "(%a@ *@ %a)" "stan::math::elt_multiply(@,%a,@ %a)"
257-
es
258-
| EltDivide ->
259-
fun ppf es ->
260-
pp_scalar_binary ppf "(%a@ /@ %a)" "stan::math::elt_divide(@,%a,@ %a)"
261-
es
262-
| Pow -> fun ppf es -> pp_binary_f ppf "stan::math::pow" es
263-
| EltPow -> fun ppf es -> pp_binary_f ppf "stan::math::pow" es
264-
| Equals -> fun ppf es -> pp_binary_f ppf "stan::math::logical_eq" es
265-
| NEquals -> fun ppf es -> pp_binary_f ppf "stan::math::logical_neq" es
266-
| Less -> fun ppf es -> pp_binary_f ppf "stan::math::logical_lt" es
267-
| Leq -> fun ppf es -> pp_binary_f ppf "stan::math::logical_lte" es
268-
| Greater -> fun ppf es -> pp_binary_f ppf "stan::math::logical_gt" es
269-
| Geq -> fun ppf es -> pp_binary_f ppf "stan::math::logical_gte" es
244+
| EltTimes -> pp_scalar_binary ppf "*" "stan::math::elt_multiply" es
245+
| EltDivide -> pp_scalar_binary ppf "/" "stan::math::elt_divide" es
246+
| Pow -> pp_binary_f ppf "stan::math::pow" es
247+
| EltPow -> pp_binary_f ppf "stan::math::pow" es
248+
| Equals -> pp_binary_f ppf "stan::math::logical_eq" es
249+
| NEquals -> pp_binary_f ppf "stan::math::logical_neq" es
250+
| Less -> pp_binary_f ppf "stan::math::logical_lt" es
251+
| Leq -> pp_binary_f ppf "stan::math::logical_lte" es
252+
| Greater -> pp_binary_f ppf "stan::math::logical_gt" es
253+
| Geq -> pp_binary_f ppf "stan::math::logical_gte" es
270254

271255
and gen_misc_special_math_app f =
272256
match f with
273257
| "lmultiply" ->
274-
Some (fun ppf es -> pp_binary ppf "stan::math::multiply_log(@,%a,@ %a)" es)
258+
Some (fun ppf es -> pp_binary_f ppf "stan::math::multiply_log" es)
275259
| "lchoose" ->
276260
Some
277-
(fun ppf es ->
278-
pp_binary ppf "stan::math::binomial_coefficient_log(@,%a,@ %a)" es )
261+
(fun ppf es -> pp_binary_f ppf "stan::math::binomial_coefficient_log" es)
279262
| "target" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
280263
| "get_lp" -> Some (fun ppf _ -> pf ppf "stan::math::get_lp(lp__, lp_accum__)")
281264
| f when Map.mem fn_renames f ->

0 commit comments

Comments
 (0)