Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ let is_eigen_type ut =

let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false

let rec contains_int ut =
match ut with
| UFun (_, Void, _, _)
|UReal | UComplex | UVector | URowVector | UMatrix | UMathLibraryFunction ->
false
| UInt -> true
| UArray x | UFun (_, ReturnType x, _, _) -> contains_int x

let rec is_indexing_matrix = function
| UArray t, _ :: idcs -> is_indexing_matrix (t, idcs)
| UMatrix, [] -> false
Expand Down
84 changes: 49 additions & 35 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ let pp_located ppf _ =
{|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|}

(** Detect if argument requires C++ template *)
let arg_needs_template = function
let arg_needs_template arg =
match arg with
| UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t
| _, _, t when UnsizedType.is_int_type t -> false
| _ -> true
Expand All @@ -64,22 +65,23 @@ let arg_needs_template = function
@return A list of arguments with template parameter names added.
*)
let maybe_templated_arg_types (args : Program.fun_arg_decl) =
List.mapi args ~f:(fun i a ->
match arg_needs_template a with
| true -> Some (sprintf "T%d__" i)
| false -> None )
List.mapi args ~f:(fun i (adtype, _, ut) ->
match ut with
| UMatrix | UVector | URowVector -> Some [sprintf "T%d__" i]
| (UReal | UComplex) when adtype = AutoDiffable -> Some [sprintf "T%d__" i]
| UArray _ -> Some [sprintf "T%d__" i; sprintf "Alloc%d__" i]
| UInt | UReal | UComplex | UMathLibraryFunction | UFun _ -> None )

let return_arg_types (args : Program.fun_arg_decl) =
List.mapi args ~f:(fun i ((_, _, ut) as a) ->
if UnsizedType.is_eigen_type ut && arg_needs_template a then
Some (sprintf "stan::value_type_t<T%d__>" i)
if not (UnsizedType.is_scalar_type ut) then Some (sprintf "T%d__" i)
else if arg_needs_template a then Some (sprintf "T%d__" i)
else None )

let%expect_test "arg types templated correctly" =
[(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)]
|> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:","
|> print_endline ;
|> maybe_templated_arg_types |> List.filter_opt |> List.concat
|> String.concat ~sep:"," |> print_endline ;
[%expect {| T0__ |}]

(** Print the code for promoting stan real types
Expand All @@ -90,27 +92,24 @@ let pp_promoted_scalar ppf args =
match args with
| [] -> pf ppf "double"
| _ ->
let rec promote_args_chunked ppf args =
let go ppf tl =
match tl with [] -> () | _ -> pf ppf ", %a" promote_args_chunked tl
in
match args with
| [] -> pf ppf "double"
| hd :: tl ->
pf ppf "stan::promote_args_t<%a%a>" (list ~sep:comma string) hd go
tl in
promote_args_chunked ppf
List.(chunks_of ~length:5 (filter_opt (return_arg_types args)))
let blah init xx =
match xx with
| Some x when init <> "" -> String.concat ~sep:", " [init; x]
| Some x -> String.concat ~sep:", " [x]
| None -> init in
pf ppf "stan::return_type_t<%s>"
(List.fold ~init:"" ~f:blah (return_arg_types args))

(** Pretty-prints a function's return-type, taking into account templated argument
promotion.*)
let pp_returntype ppf arg_types rt =
let scalar = str "%a" pp_promoted_scalar arg_types in
match rt with
| Some ut when UnsizedType.is_int_type ut ->
pf ppf "%a@," pp_unsizedtype_custom_scalar ("int", ut)
| Some ut -> pf ppf "%a@," pp_unsizedtype_custom_scalar (scalar, ut)
| None -> pf ppf "void@,"
| Some ut when UnsizedType.contains_int ut ->
pf ppf "inline %a@," pp_unsizedtype_custom_scalar ("int", ut)
| Some ut when UnsizedType.is_scalar_type ut -> pf ppf "inline auto@,"
| Some ut -> pf ppf "inline %a@," pp_unsizedtype_custom_scalar (scalar, ut)
| None -> pf ppf "inline void@,"

let pp_eigen_arg_to_ref ppf arg_types =
let pp_ref ppf name =
Expand All @@ -133,6 +132,22 @@ let pp_located_error ppf (pp_body_block, body) =
string ppf " catch (const std::exception& e) " ;
pp_block ppf (pp_located, ())

(**
* Print the types used in the C++ function signature.
* For most types we'll simply use the template typename given
* such as `T{id}__, but for std::vector's we will specialize
* the function by wrapping the joint template parameters
* (`T{id}__, Alloc{id}__`) around `std::vector<{Templates}>.
*)
let pp_arg_types ppf (scalar, ut) =
match ut with
| UnsizedType.UInt | UReal | UComplex | UMatrix | URowVector | UVector ->
string ppf scalar
| UArray _ ->
(* Expressions are not accepted for arrays of Eigen::Matrix *)
pf ppf "std::vector<%s>" scalar
| x -> raise_s [%message (x : UnsizedType.t) "not implemented yet"]

(** Print the type of an object.
@param ppf A pretty printer
@param custom_scalar_opt A string representing a types inner scalar value.
Expand All @@ -145,8 +160,7 @@ let pp_arg ppf (custom_scalar_opt, (_, name, ut)) =
| Some scalar -> scalar
| None -> stantype_prim_str ut in
(* we add the _arg suffix for any Eigen types *)
pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut)
name
pf ppf "const %a& %s" pp_arg_types (scalar, ut) name

let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) =
let scalar =
Expand All @@ -156,8 +170,7 @@ let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) =
(* we add the _arg suffix for any Eigen types *)
let opt_arg_suffix =
if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name in
pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut)
opt_arg_suffix
pf ppf "const %a& %s" pp_arg_types (scalar, ut) opt_arg_suffix

(** [pp_located_error_b] automatically adds a Block wrapper *)
let pp_located_error_b ppf body_stmts =
Expand All @@ -171,16 +184,17 @@ let typename = ( ^ ) "typename "
@param fdargs A sexp list of strings representing C++ types.
*)
let get_templates_and_args exprs fdargs =
let argtypetemplates = maybe_templated_arg_types fdargs in
( List.filter_opt argtypetemplates
let argtype_templates = maybe_templated_arg_types fdargs in
let templates =
List.map ~f:(Option.map ~f:(String.concat ~sep:", ")) argtype_templates
in
( List.concat (List.filter_opt argtype_templates)
, if not exprs then
List.map
~f:(fun a -> str "%a" pp_arg a)
(List.zip_exn argtypetemplates fdargs)
List.map ~f:(fun a -> strf "%a" pp_arg a) (List.zip_exn templates fdargs)
else
List.map
~f:(fun a -> str "%a" pp_arg_eigen_suffix a)
(List.zip_exn argtypetemplates fdargs) )
~f:(fun a -> strf "%a" pp_arg_eigen_suffix a)
(List.zip_exn templates fdargs) )

(** Print the C++ template parameter decleration before a function.
@param ppf A pretty printer.
Expand Down
Loading