Skip to content

Commit 2dd203f

Browse files
Fix KFTO-SDK MNIST test for disconnected
1 parent 2fc5004 commit 2dd203f

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

tests/kfto/resources/kfto_sdk_mnist.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def train_func():
99
from minio import Minio
1010
import shutil
1111
import gzip
12-
12+
from urllib.parse import urlparse
1313

1414
# [1] Setup PyTorch DDP. Distributed environment will be set automatically by Training Operator.
1515
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
@@ -57,13 +57,20 @@ def forward(self, x):
5757
secret_key = "{{.StorageBucketSecretKey}}"
5858
bucket_name = "{{.StorageBucketName}}"
5959
prefix = "{{.StorageBucketMnistDir}}"
60-
if with_aws != "true":
60+
61+
# Sanitize endpoint to remove any scheme or path.
62+
parsed = urlparse(endpoint)
63+
# If the endpoint URL contains a scheme, netloc contains the host and optional port.
64+
endpoint = parsed.netloc if parsed.netloc else parsed.path
65+
secure = parsed.scheme == "https"
66+
67+
if with_aws == "true":
6168
client = Minio(
6269
endpoint,
6370
access_key=access_key,
6471
secret_key=secret_key,
6572
cert_check=False,
66-
secure=False, #TODO
73+
secure=secure,
6774
)
6875

6976
if not os.path.exists(dataset_dir):

tests/kfto/resources/mnist_kfto.ipynb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@
6161
" resources_per_worker={\"gpu\": num_gpus},\n",
6262
" base_image=training_image,\n",
6363
" packages_to_install=[\"torchvision==0.19.0\",\"minio==7.2.13\"],\n",
64+
" pip_index_url= os.environ.get(\"PIP_INDEX_URL\"),\n",
6465
" env_vars={\n",
6566
" \"NCCL_DEBUG\": \"INFO\", \n",
6667
" \"TORCH_DISTRIBUTED_DEBUG\": \"DETAIL\", \n",
67-
" \"PIP_INDEX_URL\": os.environ.get(\"PIP_INDEX_URL\"),\n",
68+
" \"DEFAULT_PIP_INDEX_URL\": os.environ.get(\"PIP_INDEX_URL\"),\n",
6869
" \"PIP_TRUSTED_HOST\": os.environ.get(\"PIP_TRUSTED_HOST\")\n",
6970
" }\n",
7071
")"
@@ -78,7 +79,8 @@
7879
"outputs": [],
7980
"source": [
8081
"while not tc.is_job_succeeded(name=\"pytorch-ddp\", namespace=namespace): \n",
81-
" time.sleep(1)"
82+
" time.sleep(1)\n",
83+
"print(\"PytorchJob Succeeded!\")"
8284
]
8385
},
8486
{

0 commit comments

Comments
 (0)