-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy path_doctor_check.py
More file actions
84 lines (70 loc) · 2.84 KB
/
_doctor_check.py
File metadata and controls
84 lines (70 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
"""
Smoke check for `rapids doctor` (RAPIDS CLI).
See: https://github.com/rapidsai/rapids-cli#check-plugins
"""
def cugraph_pyg_smoke_check(**kwargs):
    """
    A quick check to ensure cugraph-pyg can be imported and its core
    submodules are loadable.

    When PyTorch with CUDA is available, also spins up a single-process
    ``nccl`` process group and round-trips a tiny COO edge index through
    ``GraphStore``; otherwise only warns.

    Raises:
        ImportError: if cugraph-pyg or its dependencies cannot be imported.
        AssertionError: if ``__version__`` is missing/empty, or the
            edge-index round trip returns an unexpected shape.
    """
    try:
        import cugraph_pyg

        # Ensure core submodules load (touches pylibwholegraph, torch-geometric, etc.)
        import cugraph_pyg.data
        import cugraph_pyg.tensor
    except ImportError as e:
        raise ImportError(
            "cugraph-pyg or its dependencies could not be imported. "
            "Tip: install with `pip install cugraph-pyg` or use a RAPIDS conda environment."
        ) from e

    if not hasattr(cugraph_pyg, "__version__") or not cugraph_pyg.__version__:
        raise AssertionError(
            "cugraph-pyg smoke check failed: __version__ not found or empty"
        )

    from cugraph_pyg.utils.imports import import_optional, MissingModule

    torch = import_optional("torch")
    if isinstance(torch, MissingModule) or not torch.cuda.is_available():
        import warnings

        warnings.warn(
            "PyTorch with CUDA support is required to use cuGraph-PyG. "
            "Please install PyTorch from PyPI or Conda-Forge."
        )
        return

    import os

    from cugraph_pyg.data import GraphStore

    # Snapshot the distributed-launch env vars so we can restore them exactly.
    # BUG FIX: the previous restore read missing vars as "" and then wrote ""
    # back, leaving variables *set to empty* that were originally unset. We
    # keep None for unset vars and pop them on restore instead.
    dist_keys = (
        "MASTER_ADDR",
        "MASTER_PORT",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "LOCAL_WORLD_SIZE",
        "RANK",
    )
    saved = {key: os.environ.get(key) for key in dist_keys}
    try:
        # Describe a single-process "cluster" so NCCL init succeeds locally.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29505"
        os.environ["LOCAL_RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["RANK"] = "0"
        torch.distributed.init_process_group("nccl")

        # Round-trip a tiny 2-edge COO graph through the GraphStore.
        graph_store = GraphStore()
        graph_store.put_edge_index(
            torch.tensor([[0, 1], [1, 2]]),
            ("person", "knows", "person"),
            "coo",
            False,
            (3, 3),
        )
        edge_index = graph_store.get_edge_index(
            ("person", "knows", "person"), "coo"
        )
        # Explicit raise instead of a bare `assert` so the check still runs
        # under `python -O` (asserts are stripped with optimization on).
        if edge_index.shape != torch.Size([2, 2]):
            raise AssertionError(
                "cugraph-pyg smoke check failed: unexpected edge_index shape "
                f"{tuple(edge_index.shape)}; expected (2, 2)"
            )
    finally:
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)  # was unset before the check
            else:
                os.environ[key] = value
        # Guard the teardown: if init_process_group itself raised, calling
        # destroy here would raise again and mask the original exception.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()