Skip to content

Commit 5e16aaa

Browse files
authored
Merge pull request #54 from sintel-dev/interface_pipeline_update
Update interface of SigPro w/pipelines and primitives
2 parents 6e12f4c + 9ced51d commit 5e16aaa

File tree

15 files changed

+3645
-28
lines changed

15 files changed

+3645
-28
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
runs-on: ${{ matrix.os }}
3030
strategy:
3131
matrix:
32-
python-version: ['3.8', '3.9', '3.10', '3.11']
32+
python-version: ['3.8']
3333
os: [ubuntu-20.04]
3434
steps:
3535
- uses: actions/checkout@v1

sigpro/basic_primitives.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# -*- coding: utf-8 -*-
2+
"""Reference class implementations of existing primitives."""
3+
from sigpro import contributing, primitive
4+
5+
# Transformations
6+
7+
8+
class Identity(primitive.AmplitudeTransformation):
9+
"""Identity primitive class."""
10+
11+
def __init__(self):
12+
super().__init__('sigpro.transformations.amplitude.identity.identity')
13+
14+
15+
class PowerSpectrum(primitive.AmplitudeTransformation):
16+
"""PowerSpectrum primitive class."""
17+
18+
def __init__(self):
19+
super().__init__('sigpro.transformations.amplitude.spectrum.power_spectrum')
20+
primitive_spec = contributing._get_primitive_spec('transformation', 'frequency')
21+
self.set_primitive_inputs(primitive_spec['args'])
22+
self.set_primitive_outputs(primitive_spec['output'])
23+
24+
25+
class FFT(primitive.FrequencyTransformation):
26+
"""FFT primitive class."""
27+
28+
def __init__(self):
29+
super().__init__("sigpro.transformations.frequency.fft.fft")
30+
31+
32+
class FFTReal(primitive.FrequencyTransformation):
33+
"""FFTReal primitive class."""
34+
35+
def __init__(self):
36+
super().__init__("sigpro.transformations.frequency.fft.fft_real")
37+
38+
39+
class FrequencyBand(primitive.FrequencyTransformation):
40+
"""
41+
FrequencyBand primitive class.
42+
43+
Filter between a high and low band frequency and return the amplitude values and
44+
frequency values for those.
45+
46+
Args:
47+
low (int): Lower band frequency of filter.
48+
high (int): Higher band frequency of filter.
49+
"""
50+
51+
def __init__(self, low, high):
52+
super().__init__("sigpro.transformations.frequency.band.frequency_band",
53+
init_params={'low': low, 'high': high})
54+
self.set_primitive_inputs([{"name": "amplitude_values", "type": "numpy.ndarray"},
55+
{"name": "frequency_values", "type": "numpy.ndarray"}])
56+
self.set_primitive_outputs([{'name': 'amplitude_values', 'type': "numpy.ndarray"},
57+
{'name': 'frequency_values', 'type': "numpy.ndarray"}])
58+
self.set_fixed_hyperparameters({'low': {'type': 'int'}, 'high': {'type': 'int'}})
59+
60+
61+
class STFT(primitive.FrequencyTimeTransformation):
62+
"""STFT primitive class."""
63+
64+
def __init__(self):
65+
super().__init__('sigpro.transformations.frequency_time.stft.stft')
66+
self.set_primitive_outputs([{"name": "amplitude_values", "type": "numpy.ndarray"},
67+
{"name": "frequency_values", "type": "numpy.ndarray"},
68+
{"name": "time_values", "type": "numpy.ndarray"}])
69+
70+
71+
class STFTReal(primitive.FrequencyTimeTransformation):
72+
"""STFTReal primitive class."""
73+
74+
def __init__(self):
75+
super().__init__('sigpro.transformations.frequency_time.stft.stft_real')
76+
self.set_primitive_outputs([{"name": "real_amplitude_values", "type": "numpy.ndarray"},
77+
{"name": "frequency_values", "type": "numpy.ndarray"},
78+
{"name": "time_values", "type": "numpy.ndarray"}])
79+
80+
# Aggregations
81+
82+
83+
class CrestFactor(primitive.AmplitudeAggregation):
84+
"""CrestFactor primitive class."""
85+
86+
def __init__(self):
87+
super().__init__('sigpro.aggregations.amplitude.statistical.crest_factor')
88+
self.set_primitive_outputs([{'name': 'crest_factor_value', 'type': "float"}])
89+
90+
91+
class Kurtosis(primitive.AmplitudeAggregation):
92+
"""
93+
Kurtosis primitive class.
94+
95+
Computes the kurtosis value of the input array. If all values are equal, return
96+
`-3` for Fisher's definition and `0` for Pearson's definition.
97+
98+
Args:
99+
fisher (bool):
100+
If ``True``, Fisher’s definition is used (normal ==> 0.0). If ``False``,
101+
Pearson’s definition is used (normal ==> 3.0). Defaults to ``True``.
102+
bias (bool):
103+
If ``False``, then the calculations are corrected for statistical bias.
104+
Defaults to ``True``.
105+
"""
106+
107+
def __init__(self, fisher=True, bias=True):
108+
super().__init__('sigpro.aggregations.amplitude.statistical.kurtosis',
109+
init_params={'fisher': fisher, 'bias': bias})
110+
self.set_primitive_outputs([{'name': 'kurtosis_value', 'type': "float"}])
111+
self.set_fixed_hyperparameters({'fisher': {'type': 'bool', 'default': True},
112+
'bias': {'type': 'bool', 'default': True}})
113+
114+
115+
class Mean(primitive.AmplitudeAggregation):
116+
"""Mean primitive class."""
117+
118+
def __init__(self):
119+
super().__init__('sigpro.aggregations.amplitude.statistical.mean')
120+
self.set_primitive_outputs([{'name': 'mean_value', 'type': "float"}])
121+
122+
123+
class RMS(primitive.AmplitudeAggregation):
124+
"""RMS primitive class."""
125+
126+
def __init__(self):
127+
super().__init__('sigpro.aggregations.amplitude.statistical.rms')
128+
self.set_primitive_outputs([{'name': 'rms_value', 'type': "float"}])
129+
130+
131+
class Skew(primitive.AmplitudeAggregation):
132+
"""Skew primitive class."""
133+
134+
def __init__(self):
135+
super().__init__('sigpro.aggregations.amplitude.statistical.skew')
136+
self.set_primitive_outputs([{'name': 'skew_value', 'type': "float"}])
137+
138+
139+
class Std(primitive.AmplitudeAggregation):
140+
"""Std primitive class."""
141+
142+
def __init__(self):
143+
super().__init__('sigpro.aggregations.amplitude.statistical.std')
144+
self.set_primitive_outputs([{'name': 'std_value', 'type': "float"}])
145+
146+
147+
class Var(primitive.AmplitudeAggregation):
148+
"""Var primitive class."""
149+
150+
def __init__(self):
151+
super().__init__('sigpro.aggregations.amplitude.statistical.var')
152+
self.set_primitive_outputs([{'name': 'var_value', 'type': "float"}])
153+
154+
155+
class BandMean(primitive.FrequencyAggregation):
156+
"""
157+
BandMean primitive class.
158+
159+
Filters between a high and low band and compute the mean value for this specific band.
160+
161+
Args:
162+
min_frequency (int or float):
163+
Band minimum.
164+
max_frequency (int or float):
165+
Band maximum.
166+
"""
167+
168+
def __init__(self, min_frequency, max_frequency):
169+
super().__init__('sigpro.aggregations.frequency.band.band_mean', init_params={
170+
'min_frequency': min_frequency, 'max_frequency': max_frequency})
171+
self.set_fixed_hyperparameters({'min_frequency': {'type': 'float'},
172+
'max_frequency': {'type': 'float'}})

sigpro/contributing.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,11 @@ def _write_primitive(primitive_dict, primitive_name, primitives_path, primitives
277277
return primitive_path
278278

279279

280-
def make_primitive(primitive, primitive_type, primitive_subtype,
281-
context_arguments=None, fixed_hyperparameters=None,
282-
tunable_hyperparameters=None, primitive_outputs=None,
283-
primitives_path='sigpro/primitives', primitives_subfolders=True):
284-
"""Create a primitive JSON.
285-
286-
During the JSON creation the primitive function signature is validated to
287-
ensure that it matches the primitive type and subtype implicitly specified
288-
by the primitive name.
289-
290-
Any additional function arguments are also validated to ensure that the
291-
function does actually expect them.
280+
def _make_primitive_dict(primitive, primitive_type, primitive_subtype,
281+
context_arguments=None, fixed_hyperparameters=None,
282+
tunable_hyperparameters=None, primitive_inputs=None,
283+
primitive_outputs=None):
284+
"""Create a primitive dict.
292285
293286
Args:
294287
primitive (str):
@@ -308,30 +301,27 @@ def make_primitive(primitive, primitive_type, primitive_subtype,
308301
A dictionary containing as key the name of the hyperparameter and as
309302
value a dictionary containing the type and the default value and the
310303
range of values that it can take.
304+
primitive_inputs (list or None):
305+
A list with dictionaries containing the name and type of the input values. If
306+
``None`` default values for those will be used.
311307
primitive_outputs (list or None):
312308
A list with dictionaries containing the name and type of the output values. If
313309
``None`` default values for those will be used.
314-
primitives_path (str):
315-
Path to the root of the primitives folder, in which the primitives JSON will be stored.
316-
Defaults to `sigpro/primitives`.
317-
primitives_subfolders (bool):
318-
Whether to store the primitive JSON in a subfolder tree (``True``) or to use a flat
319-
primitive name (``False``). Defaults to ``True``.
320310
321311
Raises:
322312
ValueError:
323313
If the primitive specification arguments are not valid.
324314
325315
Returns:
326-
str:
327-
Path of the generated JSON file.
316+
dict:
317+
Generated JSON file as a Python dict.
328318
"""
329319
context_arguments = context_arguments or []
330320
fixed_hyperparameters = fixed_hyperparameters or {}
331321
tunable_hyperparameters = tunable_hyperparameters or {}
332322

333323
primitive_spec = _get_primitive_spec(primitive_type, primitive_subtype)
334-
primitive_inputs = primitive_spec['args']
324+
primitive_inputs = primitive_inputs or primitive_spec['args']
335325
primitive_outputs = primitive_outputs or primitive_spec['output']
336326

337327
primitive_function = _import_object(primitive)
@@ -366,6 +356,69 @@ def make_primitive(primitive, primitive_type, primitive_subtype,
366356
}
367357
}
368358

359+
return primitive_dict
360+
361+
# pylint: disable = too-many-arguments
362+
363+
364+
def make_primitive(primitive, primitive_type, primitive_subtype,
365+
context_arguments=None, fixed_hyperparameters=None,
366+
tunable_hyperparameters=None, primitive_inputs=None,
367+
primitive_outputs=None, primitives_path='sigpro/primitives',
368+
primitives_subfolders=True):
369+
"""Create a primitive JSON.
370+
371+
During the JSON creation the primitive function signature is validated to
372+
ensure that it matches the primitive type and subtype implicitly specified
373+
by the primitive name.
374+
375+
Any additional function arguments are also validated to ensure that the
376+
function does actually expect them.
377+
378+
Args:
379+
primitive (str):
380+
The name of the primitive, the python path including the name of the
381+
module and the name of the function.
382+
primitive_type (str):
383+
Type of primitive.
384+
primitive_subtype (str):
385+
Subtype of the primitive.
386+
context_arguments (list or None):
387+
A list with dictionaries containing the name and type of the context arguments.
388+
fixed_hyperparameters (dict or None):
389+
A dictionary containing as key the name of the hyperparameter and as
390+
value a dictionary containing the type and the default value that it
391+
should take.
392+
tunable_hyperparameters (dict or None):
393+
A dictionary containing as key the name of the hyperparameter and as
394+
value a dictionary containing the type and the default value and the
395+
range of values that it can take.
396+
primitive_inputs (list or None):
397+
A list with dictionaries containing the name and type of the input values. If
398+
``None`` default values for those will be used.
399+
primitive_outputs (list or None):
400+
A list with dictionaries containing the name and type of the output values. If
401+
``None`` default values for those will be used.
402+
primitives_path (str):
403+
Path to the root of the primitives folder, in which the primitives JSON will be stored.
404+
Defaults to `sigpro/primitives`.
405+
primitives_subfolders (bool):
406+
Whether to store the primitive JSON in a subfolder tree (``True``) or to use a flat
407+
primitive name (``False``). Defaults to ``True``.
408+
409+
Raises:
410+
ValueError:
411+
If the primitive specification arguments are not valid.
412+
413+
Returns:
414+
str:
415+
Path of the generated JSON file.
416+
"""
417+
primitive_dict = _make_primitive_dict(primitive, primitive_type, primitive_subtype,
418+
context_arguments, fixed_hyperparameters,
419+
tunable_hyperparameters, primitive_inputs,
420+
primitive_outputs)
421+
369422
return _write_primitive(primitive_dict, primitive, primitives_path, primitives_subfolders)
370423

371424

0 commit comments

Comments
 (0)