@@ -200,82 +200,65 @@ and pp_logical_op ppf op lhs rhs =
200200 pp_expr lhs op pp_expr rhs
201201
202202and pp_unary ppf fm es = pf ppf fm pp_expr (List. hd_exn es)
203- and pp_binary ppf fm es = pf ppf fm pp_expr (first es) pp_expr (second es)
203+
204+ and pp_binary_op ppf op es =
205+ pf ppf " (%a@ %s@ %a)" pp_expr (first es) op pp_expr (second es)
204206
205207and pp_binary_f ppf f es =
206208 pf ppf " %s(@,%a,@ %a)" f pp_expr (first es) pp_expr (second es)
207209
208210and first es = List. nth_exn es 0
209211and second es = List. nth_exn es 1
210212
211- and pp_scalar_binary ppf scalar_fmt generic_fmt es =
212- pp_binary ppf
213- ( if is_scalar (first es) && is_scalar (second es) then scalar_fmt
214- else generic_fmt )
215- es
213+ and pp_scalar_binary ppf op fn es =
214+ if is_scalar (first es) && is_scalar (second es) then pp_binary_op ppf op es
215+ else pp_binary_f ppf fn es
216216
217- and gen_operator_app = function
218- | Operator. Plus ->
219- fun ppf es ->
220- pp_scalar_binary ppf " (%a@ +@ %a)" " stan::math::add(@,%a,@ %a)" es
217+ and gen_operator_app op ppf es =
218+ match op with
219+ | Operator. Plus -> pp_scalar_binary ppf " +" " stan::math::add" es
221220 | PMinus ->
222- fun ppf es ->
223- pp_unary ppf
224- ( if is_scalar (List. hd_exn es) then " -%a"
225- else " stan::math::minus(@,%a)" )
226- es
227- | PPlus -> fun ppf es -> pp_unary ppf " %a" es
221+ pp_unary ppf
222+ (if is_scalar (List. hd_exn es) then " -%a" else " stan::math::minus(@,%a)" )
223+ es
224+ | PPlus -> pp_unary ppf " %a" es
228225 | Transpose ->
229- fun ppf es ->
230- pp_unary ppf
231- ( if is_scalar (List. hd_exn es) then " %a"
232- else " stan::math::transpose(@,%a)" )
233- es
234- | PNot -> fun ppf es -> pp_unary ppf " stan::math::logical_negation(@,%a)" es
235- | Minus ->
236- fun ppf es ->
237- pp_scalar_binary ppf " (%a@ -@ %a)" " stan::math::subtract(@,%a,@ %a)" es
238- | Times ->
239- fun ppf es ->
240- pp_scalar_binary ppf " (%a@ *@ %a)" " stan::math::multiply(@,%a,@ %a)" es
226+ pp_unary ppf
227+ ( if is_scalar (List. hd_exn es) then " %a"
228+ else " stan::math::transpose(@,%a)" )
229+ es
230+ | PNot -> pp_unary ppf " stan::math::logical_negation(@,%a)" es
231+ | Minus -> pp_scalar_binary ppf " -" " stan::math::subtract" es
232+ | Times -> pp_scalar_binary ppf " *" " stan::math::multiply" es
241233 | Divide | IntDivide ->
242- fun ppf es ->
243- if
244- is_matrix (second es)
245- && (is_matrix (first es) || is_row_vector (first es))
246- then pp_binary_f ppf " stan::math::mdivide_right" es
247- else
248- pp_scalar_binary ppf " (%a@ /@ %a)" " stan::math::divide(@,%a,@ %a)" es
249- | Modulo -> fun ppf es -> pp_binary_f ppf " stan::math::modulus" es
250- | LDivide -> fun ppf es -> pp_binary_f ppf " stan::math::mdivide_left" es
234+ if
235+ is_matrix (second es)
236+ && (is_matrix (first es) || is_row_vector (first es))
237+ then pp_binary_f ppf " stan::math::mdivide_right" es
238+ else pp_scalar_binary ppf " /" " stan::math::divide" es
239+ | Modulo -> pp_binary_f ppf " stan::math::modulus" es
240+ | LDivide -> pp_binary_f ppf " stan::math::mdivide_left" es
251241 | And | Or ->
252242 Common.FatalError. fatal_error_msg
253243 [% message " And/Or should have been converted to an expression" ]
254- | EltTimes ->
255- fun ppf es ->
256- pp_scalar_binary ppf " (%a@ *@ %a)" " stan::math::elt_multiply(@,%a,@ %a)"
257- es
258- | EltDivide ->
259- fun ppf es ->
260- pp_scalar_binary ppf " (%a@ /@ %a)" " stan::math::elt_divide(@,%a,@ %a)"
261- es
262- | Pow -> fun ppf es -> pp_binary_f ppf " stan::math::pow" es
263- | EltPow -> fun ppf es -> pp_binary_f ppf " stan::math::pow" es
264- | Equals -> fun ppf es -> pp_binary_f ppf " stan::math::logical_eq" es
265- | NEquals -> fun ppf es -> pp_binary_f ppf " stan::math::logical_neq" es
266- | Less -> fun ppf es -> pp_binary_f ppf " stan::math::logical_lt" es
267- | Leq -> fun ppf es -> pp_binary_f ppf " stan::math::logical_lte" es
268- | Greater -> fun ppf es -> pp_binary_f ppf " stan::math::logical_gt" es
269- | Geq -> fun ppf es -> pp_binary_f ppf " stan::math::logical_gte" es
244+ | EltTimes -> pp_scalar_binary ppf " *" " stan::math::elt_multiply" es
245+ | EltDivide -> pp_scalar_binary ppf " /" " stan::math::elt_divide" es
246+ | Pow -> pp_binary_f ppf " stan::math::pow" es
247+ | EltPow -> pp_binary_f ppf " stan::math::pow" es
248+ | Equals -> pp_binary_f ppf " stan::math::logical_eq" es
249+ | NEquals -> pp_binary_f ppf " stan::math::logical_neq" es
250+ | Less -> pp_binary_f ppf " stan::math::logical_lt" es
251+ | Leq -> pp_binary_f ppf " stan::math::logical_lte" es
252+ | Greater -> pp_binary_f ppf " stan::math::logical_gt" es
253+ | Geq -> pp_binary_f ppf " stan::math::logical_gte" es
270254
271255and gen_misc_special_math_app f =
272256 match f with
273257 | "lmultiply" ->
274- Some (fun ppf es -> pp_binary ppf " stan::math::multiply_log(@,%a,@ %a) " es)
258+ Some (fun ppf es -> pp_binary_f ppf " stan::math::multiply_log" es)
275259 | "lchoose" ->
276260 Some
277- (fun ppf es ->
278- pp_binary ppf " stan::math::binomial_coefficient_log(@,%a,@ %a)" es )
261+ (fun ppf es -> pp_binary_f ppf " stan::math::binomial_coefficient_log" es)
279262 | "target" -> Some (fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__)" )
280263 | "get_lp" -> Some (fun ppf _ -> pf ppf " stan::math::get_lp(lp__, lp_accum__)" )
281264 | f when Map. mem fn_renames f ->
0 commit comments