|
34 | 34 |
|
35 | 35 | html_theme = "pytorch_sphinx_theme2"
|
36 | 36 | html_theme_path = [pytorch_sphinx_theme2.get_html_theme_path()]
|
37 |
| -import torch |
| 37 | +import distutils.file_util |
38 | 38 | import glob
|
39 | 39 | import random
|
40 |
| -import shutil |
41 |
| -from pathlib import Path |
42 |
| -import shutil |
43 |
| -import distutils.file_util |
44 | 40 | import re
|
45 |
| -from get_sphinx_filenames import SPHINX_SHOULD_RUN |
46 |
| -import pandocfilters |
47 |
| -import pypandoc |
48 |
| -import plotly.io as pio |
| 41 | +import shutil |
49 | 42 | from pathlib import Path
|
50 | 43 | pio.renderers.default = 'sphinx_gallery'
|
51 | 44 | from redirects import redirects
|
52 | 45 |
|
| 46 | +import pandocfilters |
| 47 | +import plotly.io as pio |
| 48 | +import pypandoc |
| 49 | +import torch |
| 50 | +from get_sphinx_filenames import SPHINX_SHOULD_RUN |
| 51 | + |
| 52 | +pio.renderers.default = "sphinx_gallery" |
| 53 | + |
| 54 | +import multiprocessing |
| 55 | + |
53 | 56 | import sphinx_gallery.gen_rst
|
54 | 57 |
|
55 | 58 |
|
|
65 | 68 | # Alt option 2: Run sphinx gallery once per file (similar to how we shard in CI
|
66 | 69 | # but with shard sizes of 1), but running sphinx gallery for each file has a
|
67 | 70 | # ~5min overhead, resulting in the entire suite taking ~2x time
|
68 |
| -def call_fn(func, args, kwargs, result_queue): |
69 |
| - try: |
70 |
| - result = func(*args, **kwargs) |
71 |
| - result_queue.put((True, result)) |
72 |
| - except Exception as e: |
73 |
| - result_queue.put((False, str(e))) |
74 |
| - |
75 |
| - |
76 |
| -def call_in_subprocess(func): |
77 |
| - def wrapper(*args, **kwargs): |
78 |
| - result_queue = multiprocessing.Queue() |
79 |
| - p = multiprocessing.Process( |
80 |
| - target=call_fn, args=(func, args, kwargs, result_queue) |
81 |
| - ) |
82 |
| - p.start() |
83 |
| - p.join() |
84 |
| - success, result = result_queue.get() |
85 |
| - if success: |
86 |
| - return result |
87 |
| - else: |
88 |
| - raise RuntimeError(f"Error in subprocess: {result}") |
89 |
| - |
90 |
| - return wrapper |
91 |
| - |
92 |
| - |
93 |
| -sphinx_gallery.gen_rst.generate_file_rst = call_in_subprocess( |
94 |
| - sphinx_gallery.gen_rst.generate_file_rst |
95 |
| -) |
96 | 71 |
|
97 | 72 | try:
|
98 | 73 | import torchvision
|
@@ -120,7 +95,7 @@ def wrapper(*args, **kwargs):
|
120 | 95 | }
|
121 | 96 |
|
122 | 97 | html_additional_pages = {
|
123 |
| - '404': '404.html', |
| 98 | + "404": "404.html", |
124 | 99 | }
|
125 | 100 |
|
126 | 101 | # Add any Sphinx extension module names here, as strings. They can be
|
@@ -245,6 +220,35 @@ def wrapper(*args, **kwargs):
|
245 | 220 | }
|
246 | 221 |
|
247 | 222 |
|
| 223 | +def get_html_context(): |
| 224 | + context = { |
| 225 | + "theme_variables": theme_variables, |
| 226 | + "display_github": True, |
| 227 | + "github_url": "https://github.com", |
| 228 | + "github_user": "pytorch", |
| 229 | + "github_repo": "tutorials", |
| 230 | + "feedback_url": "https://github.com/pytorch/tutorials", |
| 231 | + "github_version": "main", |
| 232 | + "doc_path": "docs/source", |
| 233 | + "library_links": theme_variables.get("library_links", []), |
| 234 | + "icon_links": theme_variables.get("icon_links", []), |
| 235 | + "community_links": theme_variables.get("community_links", []), |
| 236 | + "pytorch_project": "docs", |
| 237 | + "language_bindings_links": html_theme_options.get( |
| 238 | + "language_bindings_links", [] |
| 239 | + ), |
| 240 | + } |
| 241 | + |
| 242 | + # Function to determine if edit button should be shown |
| 243 | + def should_show_edit_button(pagename, sourcename): |
| 244 | + return not sourcename.endswith(".py") |
| 245 | + |
| 246 | + context["should_show_edit_button"] = should_show_edit_button |
| 247 | + return context |
| 248 | + |
| 249 | + |
| 250 | +html_context = get_html_context() |
| 251 | + |
248 | 252 | if os.getenv("GALLERY_PATTERN"):
|
249 | 253 | # GALLERY_PATTERN is to be used when you want to work on a single
|
250 | 254 | # tutorial. Previously this was fed into filename_pattern, but
|
@@ -275,58 +279,6 @@ def wrapper(*args, **kwargs):
|
275 | 279 | ]
|
276 | 280 |
|
277 | 281 |
|
278 |
| -def fix_gallery_edit_links(app, pagename, templatename, context, doctree): |
279 |
| - if pagename.startswith( |
280 |
| - ("beginner/", "intermediate/", "advanced/", "recipes/", "prototype/") |
281 |
| - ): |
282 |
| - parts = pagename.split("/") |
283 |
| - gallery_dir = parts[0] |
284 |
| - # Handle nested directories by joining all parts except the first |
285 |
| - example_path = "/".join(parts[1:]) |
286 |
| - example_name = parts[-1] |
287 |
| - |
288 |
| - source_dirs = {} |
289 |
| - for i in range(len(sphinx_gallery_conf["examples_dirs"])): |
290 |
| - gallery_dir = sphinx_gallery_conf["gallery_dirs"][i] |
291 |
| - source_dir = sphinx_gallery_conf["examples_dirs"][i] |
292 |
| - # Extract the base name without "_source" suffix |
293 |
| - gallery_base = gallery_dir |
294 |
| - source_dirs[gallery_base] = source_dir |
295 |
| - |
296 |
| - if gallery_dir in source_dirs: |
297 |
| - source_dir = source_dirs[gallery_dir] |
298 |
| - |
299 |
| - # Reconstruct the path preserving subdirectories |
300 |
| - subdir = "/".join(parts[1:-1]) if len(parts) > 2 else "" |
301 |
| - |
302 |
| - # Check if .py file exists |
303 |
| - py_path = ( |
304 |
| - f"{source_dir}/{subdir}/{example_name}.py" |
305 |
| - if subdir |
306 |
| - else f"{source_dir}/{example_name}.py" |
307 |
| - ) |
308 |
| - rst_path = ( |
309 |
| - f"{source_dir}/{subdir}/{example_name}.rst" |
310 |
| - if subdir |
311 |
| - else f"{source_dir}/{example_name}.rst" |
312 |
| - ) |
313 |
| - |
314 |
| - # Clean up any double slashes |
315 |
| - py_path = py_path.replace("//", "/") |
316 |
| - rst_path = rst_path.replace("//", "/") |
317 |
| - |
318 |
| - # Default to .py file, fallback to .rst if needed |
319 |
| - file_path = py_path |
320 |
| - if not os.path.exists( |
321 |
| - os.path.join(os.path.dirname(__file__), py_path) |
322 |
| - ) and os.path.exists(os.path.join(os.path.dirname(__file__), rst_path)): |
323 |
| - file_path = rst_path |
324 |
| - |
325 |
| - context["edit_url"] = ( |
326 |
| - f"{html_context['github_url']}/{html_context['github_user']}/{html_context['github_repo']}/edit/{html_context['github_version']}/{file_path}" |
327 |
| - ) |
328 |
| - |
329 |
| - |
330 | 282 | # The suffix(es) of source filenames.
|
331 | 283 | # You can specify multiple suffix as a list of string:
|
332 | 284 | #
|
@@ -364,8 +316,8 @@ def fix_gallery_edit_links(app, pagename, templatename, context, doctree):
|
364 | 316 | "Thumbs.db",
|
365 | 317 | ".DS_Store",
|
366 | 318 | "src/pytorch-sphinx-theme/docs*",
|
367 |
| - # "**/huggindef fix_gallery_edit_linksgface_hub/templates/**", |
368 | 319 | ]
|
| 320 | + |
369 | 321 | exclude_patterns += sphinx_gallery_conf["examples_dirs"]
|
370 | 322 | exclude_patterns += ["*/index.rst"]
|
371 | 323 |
|
@@ -478,4 +430,10 @@ def handle_jinja_templates(app, docname, source):
|
478 | 430 |
|
479 | 431 | def setup(app):
|
480 | 432 | app.connect("source-read", handle_jinja_templates)
|
481 |
| - app.connect("html-page-context", fix_gallery_edit_links) |
| 433 | + app.connect("html-page-context", update_context) |
| 434 | + |
| 435 | + |
| 436 | +def update_context(app, pagename, templatename, context, doctree): |
| 437 | + # Get source file name |
| 438 | + if "page_source_suffix" in context and context["page_source_suffix"] == ".py": |
| 439 | + context["display_github"] = False |
0 commit comments