-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdropblock-gpu-inl.h
More file actions
70 lines (51 loc) · 2.02 KB
/
dropblock-gpu-inl.h
File metadata and controls
70 lines (51 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//
// Created by yijie.yu on 2019/2/25.
//
#ifndef DROPBLOCK_CPP_DROPBLOCK_GPU_H
#define DROPBLOCK_CPP_DROPBLOCK_GPU_H
#include <vector>
#include <utility>
#include "../mshadow_op.h"
#include "../tensor/init_op.h"
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <string>
#include <algorithm>
#include "../mxnet_op.h"
#include "../random/sampler.h"
#include "../tensor/elemwise_binary_broadcast_op.h"
namespace mxnet {
namespace op {
// Named indices and mode values for the GPU DropBlock operator, so operator code can
// refer to inputs, outputs, resources and modes by name instead of bare integers.
// NOTE: these are plain enums in namespace mxnet::op::gpudropblock; they are visible
// to every translation unit that includes this header, not just this file.
namespace gpudropblock {
// Input index of the data tensor.
enum GPUDropblockOpInputs {kData};
// Output indices: the result tensor and (presumably) the generated mask — confirm
// against the operator's FListOutputNames.
enum GPUDropblockOpOutputs {kOut, kMask};
// Forward resource index; presumably the random-number resource requested by the
// operator — TODO confirm against the resource request.
enum GPUDropblockOpForwardResource {kRandom};
// Values accepted by GPUDropblockParam::mode (registered via add_enum below).
enum GPUDropblockOpMode {kTraining, kAlways};
}  // namespace gpudropblock
/*!
 * \brief Hyper-parameters of the GPU DropBlock operator, declared through
 *        dmlc::Parameter so they are parsed and validated from keyword arguments.
 */
struct GPUDropblockParam : public dmlc::Parameter<GPUDropblockParam> {
  // Fraction of the input that gets dropped; validated to lie in [0, 1].
  real_t p;
  // Either gpudropblock::kTraining ("training") or gpudropblock::kAlways ("always").
  int mode;
  // Size of each dropped block (presumably the side length of a square block —
  // confirm against the kernel); must be at least 1.
  int block_size;
  // Axes argument; how the kernel consumes it is not visible in this header —
  // TODO confirm.
  TShape axes;
  DMLC_DECLARE_PARAMETER(GPUDropblockParam) {
    DMLC_DECLARE_FIELD(p).set_default(0.5)
      .set_range(0, 1)
      .describe("Fraction of the input that gets dropped out during training time.");
    // A non-positive block size describes no block at all; reject it when the
    // parameters are parsed instead of letting it reach the kernel.
    DMLC_DECLARE_FIELD(block_size).set_default(3)
      .set_lower_bound(1)
      .describe("the block size");
    DMLC_DECLARE_FIELD(mode)
      .add_enum("training", gpudropblock::kTraining)
      .add_enum("always", gpudropblock::kAlways)
      .set_default(gpudropblock::kTraining)
      .describe("Whether to only turn on dropblock during training or to also turn on for inference.");
    DMLC_DECLARE_FIELD(axes).set_default(TShape())
      .describe("Axes for variational dropblock kernel.");
  }
};
} // namespace op
} // namespace mxnet
#endif //DROPBLOCK_CPP_DROPBLOCK_GPU_H