Skip to content

Commit 2eb34f3

Browse files
committed
Move fixtures to conftest, apply to more tests
1 parent 38517da commit 2eb34f3

File tree

7 files changed

+49
-59
lines changed

7 files changed

+49
-59
lines changed

tests/data/conftest.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
import json
3+
import pathlib
4+
from unittest import mock
5+
6+
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
7+
GUITAR_SET_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "guitarset" / "dummy_index.json"))
8+
IKALA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "ikala" / "dummy_index.json"))
9+
MAESTRO_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "dummy_index.json"))
10+
METADATA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "maestro-v2.0.0.json"))
11+
MEDLEYDB_PITCH_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "medleydb_pitch" / "dummy_index.json"))
12+
13+
14+
@pytest.fixture # type: ignore[misc]
15+
def mock_medleydb_pitch_index() -> None: # type: ignore[misc]
16+
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset.download"):
17+
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset._index", new=MEDLEYDB_PITCH_TEST_INDEX):
18+
yield
19+
20+
21+
@pytest.fixture # type: ignore[misc]
22+
def mock_maestro_index() -> None: # type: ignore[misc]
23+
index_with_metadata = MAESTRO_TEST_INDEX
24+
index_with_metadata["metadata"] = METADATA_TEST_INDEX
25+
with mock.patch("mirdata.datasets.maestro.Dataset.download"):
26+
with mock.patch("mirdata.datasets.maestro.Dataset._index", new=index_with_metadata):
27+
yield
28+
29+
30+
@pytest.fixture # type: ignore[misc]
31+
def mock_guitarset_index() -> None: # type: ignore[misc]
32+
with mock.patch("mirdata.datasets.guitarset.Dataset.download"):
33+
with mock.patch("mirdata.datasets.guitarset.Dataset._index", new=GUITAR_SET_TEST_INDEX):
34+
yield
35+
36+
37+
@pytest.fixture # type: ignore[misc]
38+
def mock_ikala_index() -> None: # type: ignore[misc]
39+
with mock.patch("mirdata.datasets.ikala.Dataset.download"):
40+
with mock.patch("mirdata.datasets.ikala.Dataset._index", new=IKALA_TEST_INDEX):
41+
yield

tests/data/test_guitarset.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
from unittest import mock
1817
import apache_beam as beam
1918
import itertools
2019
import os
2120
import pathlib
2221
import shutil
23-
import pytest
24-
import json
2522
from apache_beam.testing.test_pipeline import TestPipeline
2623
from typing import List
2724

@@ -37,15 +34,6 @@
3734
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
3835
TRACK_ID = "00_BN1-129-Eb_comp"
3936

40-
GUITAR_SET_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "guitarset" / "dummy_index.json"))
41-
42-
43-
@pytest.fixture # type: ignore[misc]
44-
def mock_guitarset_index() -> None: # type: ignore[misc]
45-
with mock.patch("mirdata.datasets.guitarset.Dataset.download"):
46-
with mock.patch("mirdata.datasets.guitarset.Dataset._index", new=GUITAR_SET_TEST_INDEX):
47-
yield
48-
4937

5038
def test_guitarset_to_tf_example(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
5139
mock_guitarset_home = tmp_path / "guitarset"

tests/data/test_ikala.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
import pytest
1817
import pathlib
19-
from unittest import mock
2018
import apache_beam as beam
2119
import itertools
2220
import os
23-
import json
2421
from apache_beam.testing.test_pipeline import TestPipeline
2522

2623
from basic_pitch.data.datasets.ikala import (
@@ -31,14 +28,6 @@
3128
# TODO: Create test_ikala_to_tf_example
3229

3330
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
34-
IKALA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "ikala" / "dummy_index.json"))
35-
36-
37-
@pytest.fixture # type: ignore[misc]
38-
def mock_ikala_index() -> None: # type: ignore[misc]
39-
with mock.patch("mirdata.datasets.ikala.Dataset.download"):
40-
with mock.patch("mirdata.datasets.ikala.Dataset._index", new=IKALA_TEST_INDEX):
41-
yield
4231

4332

4433
def test_ikala_invalid_tracks(tmpdir: str) -> None:

tests/data/test_maestro.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
# limitations under the License.
1717
import os
1818
import pathlib
19-
import json
20-
import pytest
21-
from unittest import mock
2219
from typing import List
2320

2421
import apache_beam as beam
@@ -41,18 +38,6 @@
4138
TEST_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_08_Track08_wav"
4239
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"
4340

44-
MAESTRO_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "dummy_index.json"))
45-
METADATA_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "maestro" / "maestro-v2.0.0.json"))
46-
47-
48-
@pytest.fixture # type: ignore[misc]
49-
def mock_maestro_index() -> None: # type: ignore[misc]
50-
index_with_metadata = MAESTRO_TEST_INDEX
51-
index_with_metadata["metadata"] = METADATA_TEST_INDEX
52-
with mock.patch("mirdata.datasets.maestro.Dataset.download"):
53-
with mock.patch("mirdata.datasets.maestro.Dataset._index", new=index_with_metadata):
54-
yield
55-
5641

5742
def test_maestro_to_tf_example(tmp_path: pathlib.Path, mock_maestro_index: None) -> None:
5843
mock_maestro_home = tmp_path / "maestro"

tests/data/test_medleydb_pitch.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
import apache_beam as beam
1818
import itertools
1919
import os
20-
import json
21-
import pytest
22-
import pathlib
23-
from unittest import mock
2420

2521
from apache_beam.testing.test_pipeline import TestPipeline
2622

@@ -32,16 +28,6 @@
3228

3329
# TODO: Create test_medleydb_pitch_to_tf_example
3430

35-
RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
36-
MEDLEYDB_PITCH_TEST_INDEX = json.load(open(RESOURCES_PATH / "data" / "medleydb_pitch" / "dummy_index.json"))
37-
38-
39-
@pytest.fixture # type: ignore[misc]
40-
def mock_medleydb_pitch_index() -> None: # type: ignore[misc]
41-
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset.download"):
42-
with mock.patch("mirdata.datasets.medleydb_pitch.Dataset._index", new=MEDLEYDB_PITCH_TEST_INDEX):
43-
yield
44-
4531

4632
def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None:
4733
split_labels = ["train", "validation"]

tests/data/test_slakh.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def create_mock_input_data(data_home: pathlib.Path, input_data: List[Tuple[str,
7979
shutil.copy(SLAKH_PATH / split / track_num / "metadata.yaml", track_dir / "metadata.yaml")
8080

8181

82-
def test_slakh_to_tf_example(tmp_path: pathlib.Path) -> None:
82+
def test_slakh_to_tf_example(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
8383
mock_slakh_home = tmp_path / "slakh"
8484
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"
8585

@@ -105,7 +105,7 @@ def test_slakh_to_tf_example(tmp_path: pathlib.Path) -> None:
105105
assert len(data) != 0
106106

107107

108-
def test_slakh_invalid_tracks(tmp_path: pathlib.Path) -> None:
108+
def test_slakh_invalid_tracks(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
109109
mock_slakh_home = tmp_path / "slakh"
110110
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"
111111

@@ -132,7 +132,7 @@ def test_slakh_invalid_tracks(tmp_path: pathlib.Path) -> None:
132132
assert fp.read().strip() == track_id
133133

134134

135-
def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path) -> None:
135+
def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
136136
mock_slakh_home = tmp_path / "slakh"
137137
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"
138138

@@ -161,7 +161,7 @@ def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path) -> None:
161161
assert fp.read().strip() == ""
162162

163163

164-
def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path) -> None:
164+
def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path, mock_slakh_index: None) -> None:
165165
mock_slakh_home = tmp_path / "slakh"
166166
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"
167167

tests/data/test_tf_example_deserialization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def mock_and_process(split: str, track_id: str) -> None:
8686
return output_home
8787

8888

89-
def test_prepare_datasets(tmp_path: pathlib.Path) -> None:
89+
def test_prepare_datasets(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
9090
datasets_home = setup_test_resources(tmp_path)
9191

9292
ds_train, ds_valid = prepare_datasets(
@@ -102,7 +102,7 @@ def test_prepare_datasets(tmp_path: pathlib.Path) -> None:
102102
assert ds_valid is not None and isinstance(ds_valid, tf.data.Dataset)
103103

104104

105-
def test_prepare_visualization_dataset(tmp_path: pathlib.Path) -> None:
105+
def test_prepare_visualization_dataset(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
106106
datasets_home = setup_test_resources(tmp_path)
107107

108108
ds_train, ds_valid = prepare_visualization_datasets(
@@ -117,7 +117,7 @@ def test_prepare_visualization_dataset(tmp_path: pathlib.Path) -> None:
117117
assert ds_valid is not None and isinstance(ds_train, tf.data.Dataset)
118118

119119

120-
def test_sample_datasets(tmp_path: pathlib.Path) -> None:
120+
def test_sample_datasets(tmp_path: pathlib.Path, mock_guitarset_index: None) -> None:
121121
"""touches the following methods:
122122
- transcription_dataset
123123
- parse_transcription_tfexample
@@ -126,6 +126,7 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None:
126126
- reduce_transcription_inputs
127127
- get_sample_weights
128128
- _infer_time_size
129+
- _infer_time_size
129130
- get_transcription_chunks
130131
- extract_random_window
131132
- extract_window

0 commit comments

Comments
 (0)