Skip to content

Commit b6b0bbf

Browse files
Johannes Ballécopybara-github
authored andcommitted
Fixes bug in Y4MDataset reading chroma planes incorrectly.
PiperOrigin-RevId: 373819685 Change-Id: I93d6238297056260ad087b02287f7002e5d06033
1 parent 66228f0 commit b6b0bbf

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tensorflow_compression/cc/kernels/y4m_dataset_kernels.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class Y4MDatasetOp : public DatasetOpKernel {
132132
cbcr_width /= 2;
133133
cbcr_height /= 2;
134134
}
135+
const size_t cbcr_size = cbcr_width * cbcr_height;
135136

136137
// This is a no-op for the second and subsequent frames.
137138
buffer_.resize(frame_header.size() + frame_size);
@@ -159,8 +160,10 @@ class Y4MDatasetOp : public DatasetOpKernel {
159160
auto flat_cbcr = cbcr_tensor.flat<uint8>();
160161
std::memcpy(flat_y.data(), frame_buffer.data(), flat_y.size());
161162
frame_buffer.remove_prefix(flat_y.size());
162-
std::memcpy(flat_cbcr.data(), frame_buffer.data(),
163-
flat_cbcr.size());
163+
for (int i = 0; i < cbcr_size; i++) {
164+
flat_cbcr.data()[2*i] = frame_buffer[i];
165+
flat_cbcr.data()[2*i+1] = frame_buffer[cbcr_size+i];
166+
}
164167
out_tensors->push_back(std::move(y_tensor));
165168
out_tensors->push_back(std::move(cbcr_tensor));
166169

tensorflow_compression/python/datasets/y4m_dataset_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Y4MDatasetTest(tf.test.TestCase):
2828
def setUp(self):
2929
super().setUp()
3030
self.tempfile_1 = self.create_tempfile(
31-
content=b"YUV4MPEG2 W2 H2 F30:1 Ip A0:0 C420jpeg\nFRAME\nABCDEF")
31+
content=b"YUV4MPEG2 W4 H2 F30:1 Ip A0:0 C420jpeg\nFRAME\nABCDEFGHIJKL")
3232
self.tempfile_2 = self.create_tempfile(
3333
content=b"YUV4MPEG2 C444 W1 H1\nFRAME\nabcFRAME\ndef")
3434

@@ -40,8 +40,10 @@ def test_dataset_yields_correct_sequence(self):
4040
y, cbcr = next(it)
4141
self.assertEqual(tf.uint8, y.dtype)
4242
self.assertEqual(tf.uint8, cbcr.dtype)
43-
self.assertAllEqual(shaped_uint8(b"ABCD", (2, 2, 1)), y)
44-
self.assertAllEqual(shaped_uint8(b"EF", (1, 1, 2)), cbcr)
43+
cb, cr = tf.unstack(cbcr, axis=-1)
44+
self.assertAllEqual(shaped_uint8(b"ABCDEFGH", (2, 4, 1)), y)
45+
self.assertAllEqual(shaped_uint8(b"IJ", (1, 2)), cb)
46+
self.assertAllEqual(shaped_uint8(b"KL", (1, 2)), cr)
4547

4648
y, cbcr = next(it)
4749
self.assertEqual(tf.uint8, y.dtype)

0 commit comments

Comments
 (0)