Dear vadimkantorov,
thank you for publishing this nice repo; it is very well written.
I'm running
"python train.py --dataset cub2011 --model margin --base resnet50"
with PyTorch 1.0.1 and Python 3.6, but it crashes with the following error:
Traceback (most recent call last):
  File "train.py", line 71, in <module>
    loader_train = torch.utils.data.DataLoader(dataset_train, sampler = adapt_sampler(opts.batch, dataset_train, opts.sampler), num_workers = opts.threads, batch_size = opts.batch, drop_last = True, pin_memory = True)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 805, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
  File "/export/home/bbrattol/anaconda2/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 146, in __init__
    .format(sampler))
ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler=<__main__. object at 0x7f199b08a9b0>
I suspect it is related to the newer PyTorch version. Could you help me get it running correctly?
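For reference, the failing check is in `BatchSampler.__init__`: in PyTorch 1.0.1 it rejects any sampler that is not an instance of `torch.utils.data.Sampler`. A minimal sketch of one possible workaround, assuming `adapt_sampler` returns a plain object with `__iter__` and `__len__` (I have not verified this against the repo), is to wrap that object in a `Sampler` subclass. `SamplerWrapper` and `PlainIndexSource` are hypothetical names used only for illustration:

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SamplerWrapper(Sampler):
    """Hypothetical wrapper: makes any object with __iter__ and __len__
    pass the isinstance(sampler, Sampler) check in BatchSampler."""

    def __init__(self, inner):
        # Sampler.__init__ does no work and its signature differs across
        # PyTorch versions, so it is deliberately not called here.
        self.inner = inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)


class PlainIndexSource:
    """Hypothetical stand-in for the object adapt_sampler returns:
    not a Sampler subclass, but it yields dataset indices."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


dataset = TensorDataset(torch.arange(10).float())
sampler = SamplerWrapper(PlainIndexSource(len(dataset)))
loader = DataLoader(dataset, sampler=sampler, batch_size=4, drop_last=True)
batches = [b[0].tolist() for b in loader]
print(batches)  # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
```

In train.py this would presumably amount to passing `sampler = SamplerWrapper(adapt_sampler(opts.batch, dataset_train, opts.sampler))` to the `DataLoader`; changing `adapt_sampler` itself to build a `Sampler` subclass would be the cleaner fix.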
Thanks