Skip to content

Commit 98ae65d

Browse files
committed
MNT: Add more global state I missed to the thread_unsafe_state struct
1 parent c237038 commit 98ae65d

File tree

6 files changed

+48
-99
lines changed

6 files changed

+48
-99
lines changed

numpy/_core/src/multiarray/arrayobject.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ maintainer email: [email protected]
6262

6363
#include "binop_override.h"
6464
#include "array_coercion.h"
65-
66-
67-
NPY_NO_EXPORT npy_bool numpy_warn_if_no_mem_policy = 0;
65+
#include "multiarraymodule.h"
6866

6967
/*NUMPY_API
7068
Compute the size of an array (in number of items)
@@ -429,7 +427,7 @@ array_dealloc(PyArrayObject *self)
429427
}
430428
}
431429
if (fa->mem_handler == NULL) {
432-
if (numpy_warn_if_no_mem_policy) {
430+
if (npy_thread_unsafe_state.warn_if_no_mem_policy) {
433431
char const *msg = "Trying to dealloc data, but a memory policy "
434432
"is not set. If you take ownership of the data, you must "
435433
"set a base owning the data (e.g. a PyCapsule).";

numpy/_core/src/multiarray/arrayobject.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
extern "C" {
1010
#endif
1111

12-
extern NPY_NO_EXPORT npy_bool numpy_warn_if_no_mem_policy;
13-
1412
NPY_NO_EXPORT PyObject *
1513
_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
1614
int rstrip);

numpy/_core/src/multiarray/descriptor.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
2424
#include "descriptor.h"
2525
#include "npy_static_data.h"
26-
#include "multiarraymodule.h"
26+
#include "multiarraymodule.h" // for thread unsafe state access
2727
#include "alloc.h"
2828
#include "assert.h"
2929
#include "npy_buffer.h"

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 12 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,15 @@ NPY_NO_EXPORT PyObject *
9898
_umath_strings_richcompare(
9999
PyArrayObject *self, PyArrayObject *other, int cmp_op, int rstrip);
100100

101-
/*
102-
* global variable to determine if legacy printing is enabled, accessible from
103-
* C. For simplicity the mode is encoded as an integer where INT_MAX means no
104-
* legacy mode, and '113'/'121' means 1.13/1.21 legacy mode; and 0 maps to
105-
* INT_MAX. We can upgrade this if we have more complex requirements in the
106-
* future.
107-
*/
108-
int npy_legacy_print_mode = INT_MAX;
109-
110101

111102
static PyObject *
112103
set_legacy_print_mode(PyObject *NPY_UNUSED(self), PyObject *args)
113104
{
114-
if (!PyArg_ParseTuple(args, "i", &npy_legacy_print_mode)) {
105+
if (!PyArg_ParseTuple(args, "i", &npy_thread_unsafe_state.legacy_print_mode)) {
115106
return NULL;
116107
}
117-
if (!npy_legacy_print_mode) {
118-
npy_legacy_print_mode = INT_MAX;
108+
if (!npy_thread_unsafe_state.legacy_print_mode) {
109+
npy_thread_unsafe_state.legacy_print_mode = INT_MAX;
119110
}
120111
Py_RETURN_NONE;
121112
}
@@ -4333,8 +4324,8 @@ _set_numpy_warn_if_no_mem_policy(PyObject *NPY_UNUSED(self), PyObject *arg)
43334324
if (res < 0) {
43344325
return NULL;
43354326
}
4336-
int old_value = numpy_warn_if_no_mem_policy;
4337-
numpy_warn_if_no_mem_policy = res;
4327+
int old_value = npy_thread_unsafe_state.warn_if_no_mem_policy;
4328+
npy_thread_unsafe_state.warn_if_no_mem_policy = res;
43384329
if (old_value) {
43394330
Py_RETURN_TRUE;
43404331
}
@@ -4774,69 +4765,13 @@ static int
47744765
initialize_thread_unsafe_state(void) {
47754766
char *env = getenv("NUMPY_WARN_IF_NO_MEM_POLICY");
47764767
if ((env != NULL) && (strncmp(env, "1", 1) == 0)) {
4777-
numpy_warn_if_no_mem_policy = 1;
4768+
npy_thread_unsafe_state.warn_if_no_mem_policy = 1;
47784769
}
47794770
else {
4780-
numpy_warn_if_no_mem_policy = 0;
4781-
}
4782-
4783-
// default_truediv_type_tup
4784-
PyArray_Descr *tmp = PyArray_DescrFromType(NPY_DOUBLE);
4785-
if (tmp == NULL) {
4786-
return -1;
4787-
}
4788-
4789-
npy_static_pydata.default_truediv_type_tup =
4790-
PyTuple_Pack(3, tmp, tmp, tmp);
4791-
if (npy_static_pydata.default_truediv_type_tup == NULL) {
4792-
Py_DECREF(tmp);
4793-
return -1;
4794-
}
4795-
Py_DECREF(tmp);
4796-
4797-
PyObject *flags = PySys_GetObject("flags"); /* borrowed object */
4798-
if (flags == NULL) {
4799-
PyErr_SetString(PyExc_AttributeError, "cannot get sys.flags");
4800-
return -1;
4801-
}
4802-
PyObject *level = PyObject_GetAttrString(flags, "optimize");
4803-
if (level == NULL) {
4804-
return -1;
4805-
}
4806-
npy_static_cdata.optimize = PyLong_AsLong(level);
4807-
Py_DECREF(level);
4808-
4809-
/*
4810-
* see unpack_bits for how this table is used.
4811-
*
4812-
* LUT for bigendian bitorder, littleendian is handled via
4813-
* byteswapping in the loop.
4814-
*
4815-
* 256 8 byte blocks representing 8 bits expanded to 1 or 0 bytes
4816-
*/
4817-
npy_intp j;
4818-
for (j=0; j < 256; j++) {
4819-
npy_intp k;
4820-
for (k=0; k < 8; k++) {
4821-
npy_uint8 v = (j & (1 << k)) == (1 << k);
4822-
npy_static_cdata.unpack_lookup_big[j].bytes[7 - k] = v;
4823-
}
4824-
}
4825-
4826-
npy_static_pydata.kwnames_is_copy = Py_BuildValue("(s)", "copy");
4827-
if (npy_static_pydata.kwnames_is_copy == NULL) {
4828-
return -1;
4771+
npy_thread_unsafe_state.warn_if_no_mem_policy = 0;
48294772
}
48304773

4831-
npy_static_pydata.one_obj = PyLong_FromLong((long) 1);
4832-
if (npy_static_pydata.one_obj == NULL) {
4833-
return -1;
4834-
}
4835-
4836-
npy_static_pydata.zero_obj = PyLong_FromLong((long) 0);
4837-
if (npy_static_pydata.zero_obj == NULL) {
4838-
return -1;
4839-
}
4774+
npy_thread_unsafe_state.legacy_print_mode = INT_MAX;
48404775

48414776
return 0;
48424777
}
@@ -4903,6 +4838,10 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
49034838
goto err;
49044839
}
49054840

4841+
if (initialize_thread_unsafe_state() < 0) {
4842+
goto err;
4843+
}
4844+
49064845
if (init_extobj() < 0) {
49074846
goto err;
49084847
}

numpy/_core/src/multiarray/multiarraymodule.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ typedef struct npy_thread_unsafe_state_struct {
7070
* used to detect module reloading in the reload guard
7171
*/
7272
int reload_guard_initialized;
73+
74+
/*
75+
* global variable to determine if legacy printing is enabled,
76+
* accessible from C. For simplicity the mode is encoded as an
77+
* integer where INT_MAX means no legacy mode, and '113'/'121'
78+
* means 1.13/1.21 legacy mode; and 0 maps to INT_MAX. We can
79+
* upgrade this if we have more complex requirements in the future.
80+
*/
81+
int legacy_print_mode;
82+
83+
/*
84+
* Holds the user-defined setting for whether or not to warn
85+
* if there is no memory policy set
86+
*/
87+
int warn_if_no_mem_policy;
88+
7389
} npy_thread_unsafe_state_struct;
7490

7591

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@
3333
#include "dragon4.h"
3434
#include "npy_longdouble.h"
3535
#include "npy_buffer.h"
36+
#include "npy_static_data.h"
3637
#include "multiarraymodule.h"
3738

3839
#include <stdlib.h>
3940

4041
#include "binop_override.h"
4142

42-
/* determines if legacy mode is enabled, global set in multiarraymodule.c */
43-
extern int npy_legacy_print_mode;
44-
4543
/*
4644
* used for allocating a single scalar, so use the default numpy
4745
* memory allocators instead of the (maybe) user overrides
@@ -338,7 +336,7 @@ genint_type_repr(PyObject *self)
338336
if (value_string == NULL) {
339337
return NULL;
340338
}
341-
if (npy_legacy_print_mode <= 125) {
339+
if (npy_thread_unsafe_state.legacy_print_mode <= 125) {
342340
return value_string;
343341
}
344342

@@ -375,7 +373,7 @@ genbool_type_str(PyObject *self)
375373
static PyObject *
376374
genbool_type_repr(PyObject *self)
377375
{
378-
if (npy_legacy_print_mode <= 125) {
376+
if (npy_thread_unsafe_state.legacy_print_mode <= 125) {
379377
return genbool_type_str(self);
380378
}
381379
return PyUnicode_FromString(
@@ -501,7 +499,7 @@ stringtype_@form@(PyObject *self)
501499
if (ret == NULL) {
502500
return NULL;
503501
}
504-
if (npy_legacy_print_mode > 125) {
502+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
505503
Py_SETREF(ret, PyUnicode_FromFormat("np.bytes_(%S)", ret));
506504
}
507505
#endif /* IS_repr */
@@ -548,7 +546,7 @@ unicodetype_@form@(PyObject *self)
548546
if (ret == NULL) {
549547
return NULL;
550548
}
551-
if (npy_legacy_print_mode > 125) {
549+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
552550
Py_SETREF(ret, PyUnicode_FromFormat("np.str_(%S)", ret));
553551
}
554552
#endif /* IS_repr */
@@ -629,7 +627,7 @@ voidtype_repr(PyObject *self)
629627
/* Python helper checks for the legacy mode printing */
630628
return _void_scalar_to_string(self, 1);
631629
}
632-
if (npy_legacy_print_mode > 125) {
630+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
633631
return _void_to_hex(s->obval, s->descr->elsize, "np.void(b'", "\\x", "')");
634632
}
635633
else {
@@ -681,7 +679,7 @@ datetimetype_repr(PyObject *self)
681679
*/
682680
if ((scal->obmeta.num == 1 && scal->obmeta.base != NPY_FR_h) ||
683681
scal->obmeta.base == NPY_FR_GENERIC) {
684-
if (npy_legacy_print_mode > 125) {
682+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
685683
ret = PyUnicode_FromFormat("np.datetime64('%s')", iso);
686684
}
687685
else {
@@ -693,7 +691,7 @@ datetimetype_repr(PyObject *self)
693691
if (meta == NULL) {
694692
return NULL;
695693
}
696-
if (npy_legacy_print_mode > 125) {
694+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
697695
ret = PyUnicode_FromFormat("np.datetime64('%s','%S')", iso, meta);
698696
}
699697
else {
@@ -737,7 +735,7 @@ timedeltatype_repr(PyObject *self)
737735

738736
/* The metadata unit */
739737
if (scal->obmeta.base == NPY_FR_GENERIC) {
740-
if (npy_legacy_print_mode > 125) {
738+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
741739
ret = PyUnicode_FromFormat("np.timedelta64(%S)", val);
742740
}
743741
else {
@@ -750,7 +748,7 @@ timedeltatype_repr(PyObject *self)
750748
Py_DECREF(val);
751749
return NULL;
752750
}
753-
if (npy_legacy_print_mode > 125) {
751+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
754752
ret = PyUnicode_FromFormat("np.timedelta64(%S,'%S')", val, meta);
755753
}
756754
else {
@@ -1052,7 +1050,7 @@ static PyObject *
10521050
npy_bool sign)
10531051
{
10541052

1055-
if (npy_legacy_print_mode <= 113) {
1053+
if (npy_thread_unsafe_state.legacy_print_mode <= 113) {
10561054
return legacy_@name@_format@kind@(val);
10571055
}
10581056

@@ -1083,7 +1081,7 @@ static PyObject *
10831081
if (string == NULL) {
10841082
return NULL;
10851083
}
1086-
if (npy_legacy_print_mode > 125) {
1084+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
10871085
Py_SETREF(string, PyUnicode_FromFormat("@repr_format@", string));
10881086
}
10891087
#endif /* IS_repr */
@@ -1098,7 +1096,7 @@ c@name@type_@kind@(PyObject *self)
10981096
npy_c@name@ val = PyArrayScalar_VAL(self, C@Name@);
10991097
TrimMode trim = TrimMode_DptZeros;
11001098

1101-
if (npy_legacy_print_mode <= 113) {
1099+
if (npy_thread_unsafe_state.legacy_print_mode <= 113) {
11021100
return legacy_c@name@_format@kind@(val);
11031101
}
11041102

@@ -1111,7 +1109,7 @@ c@name@type_@kind@(PyObject *self)
11111109
#ifdef IS_str
11121110
ret = PyUnicode_FromFormat("%Sj", istr);
11131111
#else /* IS_repr */
1114-
if (npy_legacy_print_mode <= 125) {
1112+
if (npy_thread_unsafe_state.legacy_print_mode <= 125) {
11151113
ret = PyUnicode_FromFormat("%Sj", istr);
11161114
}
11171115
else {
@@ -1159,7 +1157,7 @@ c@name@type_@kind@(PyObject *self)
11591157
#ifdef IS_str
11601158
string = PyUnicode_FromFormat("(%S%Sj)", rstr, istr);
11611159
#else /* IS_repr */
1162-
if (npy_legacy_print_mode > 125) {
1160+
if (npy_thread_unsafe_state.legacy_print_mode > 125) {
11631161
string = PyUnicode_FromFormat("@crepr_format@", rstr, istr);
11641162
}
11651163
else {
@@ -1184,7 +1182,7 @@ halftype_@kind@(PyObject *self)
11841182
float floatval = npy_half_to_float(val);
11851183
float absval;
11861184

1187-
if (npy_legacy_print_mode <= 113) {
1185+
if (npy_thread_unsafe_state.legacy_print_mode <= 113) {
11881186
return legacy_float_format@kind@(floatval);
11891187
}
11901188

@@ -1200,7 +1198,7 @@ halftype_@kind@(PyObject *self)
12001198
#ifdef IS_str
12011199
return string;
12021200
#else
1203-
if (string == NULL || npy_legacy_print_mode <= 125) {
1201+
if (string == NULL || npy_thread_unsafe_state.legacy_print_mode <= 125) {
12041202
return string;
12051203
}
12061204
PyObject *res = PyUnicode_FromFormat("np.float16(%S)", string);

0 commit comments

Comments
 (0)