diff --git a/beginner_source/basics/data_tutorial.py b/beginner_source/basics/data_tutorial.py index 561e9723fde..fe649e9e009 100644 --- a/beginner_source/basics/data_tutorial.py +++ b/beginner_source/basics/data_tutorial.py @@ -120,7 +120,7 @@ import os import pandas as pd -from torchvision.io import read_image +from torchvision.io import decode_image class CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): @@ -134,7 +134,7 @@ def __len__(self): def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) - image = read_image(img_path) + image = decode_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) @@ -184,7 +184,7 @@ def __len__(self): # ^^^^^^^^^^^^^^^^^^^^ # # The __getitem__ function loads and returns a sample from the dataset at the given index ``idx``. -# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``read_image``, retrieves the +# Based on the index, it identifies the image's location on disk, converts that to a tensor using ``decode_image``, retrieves the # corresponding label from the csv data in ``self.img_labels``, calls the transform functions on them (if applicable), and returns the # tensor image and corresponding label in a tuple.