forked from NVIDIA/TransformerEngine
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathactivation_template.h
More file actions
72 lines (59 loc) · 2.78 KB
/
activation_template.h
File metadata and controls
72 lines (59 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file activation_template.h
* \brief Activation functions template.
*/
#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_