Skip to content

Commit 3c80c16

Browse files
committed
Edit on github button fix
1 parent 877e549 commit 3c80c16

File tree

1 file changed

+50
-92
lines changed

1 file changed

+50
-92
lines changed

conf.py

Lines changed: 50 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,25 @@
3434

3535
html_theme = "pytorch_sphinx_theme2"
3636
html_theme_path = [pytorch_sphinx_theme2.get_html_theme_path()]
37-
import torch
37+
import distutils.file_util
3838
import glob
3939
import random
40-
import shutil
41-
from pathlib import Path
42-
import shutil
43-
import distutils.file_util
4440
import re
45-
from get_sphinx_filenames import SPHINX_SHOULD_RUN
46-
import pandocfilters
47-
import pypandoc
48-
import plotly.io as pio
41+
import shutil
4942
from pathlib import Path
5043
pio.renderers.default = 'sphinx_gallery'
5144
from redirects import redirects
5245

46+
import pandocfilters
47+
import plotly.io as pio
48+
import pypandoc
49+
import torch
50+
from get_sphinx_filenames import SPHINX_SHOULD_RUN
51+
52+
pio.renderers.default = "sphinx_gallery"
53+
54+
import multiprocessing
55+
5356
import sphinx_gallery.gen_rst
5457

5558

@@ -65,34 +68,6 @@
6568
# Alt option 2: Run sphinx gallery once per file (similar to how we shard in CI
6669
# but with shard sizes of 1), but running sphinx gallery for each file has a
6770
# ~5min overhead, resulting in the entire suite taking ~2x time
68-
def call_fn(func, args, kwargs, result_queue):
69-
try:
70-
result = func(*args, **kwargs)
71-
result_queue.put((True, result))
72-
except Exception as e:
73-
result_queue.put((False, str(e)))
74-
75-
76-
def call_in_subprocess(func):
77-
def wrapper(*args, **kwargs):
78-
result_queue = multiprocessing.Queue()
79-
p = multiprocessing.Process(
80-
target=call_fn, args=(func, args, kwargs, result_queue)
81-
)
82-
p.start()
83-
p.join()
84-
success, result = result_queue.get()
85-
if success:
86-
return result
87-
else:
88-
raise RuntimeError(f"Error in subprocess: {result}")
89-
90-
return wrapper
91-
92-
93-
sphinx_gallery.gen_rst.generate_file_rst = call_in_subprocess(
94-
sphinx_gallery.gen_rst.generate_file_rst
95-
)
9671

9772
try:
9873
import torchvision
@@ -120,7 +95,7 @@ def wrapper(*args, **kwargs):
12095
}
12196

12297
html_additional_pages = {
123-
'404': '404.html',
98+
"404": "404.html",
12499
}
125100

126101
# Add any Sphinx extension module names here, as strings. They can be
@@ -245,6 +220,35 @@ def wrapper(*args, **kwargs):
245220
}
246221

247222

223+
def get_html_context():
224+
context = {
225+
"theme_variables": theme_variables,
226+
"display_github": True,
227+
"github_url": "https://github.com",
228+
"github_user": "pytorch",
229+
"github_repo": "tutorials",
230+
"feedback_url": "https://github.com/pytorch/tutorials",
231+
"github_version": "main",
232+
"doc_path": "docs/source",
233+
"library_links": theme_variables.get("library_links", []),
234+
"icon_links": theme_variables.get("icon_links", []),
235+
"community_links": theme_variables.get("community_links", []),
236+
"pytorch_project": "docs",
237+
"language_bindings_links": html_theme_options.get(
238+
"language_bindings_links", []
239+
),
240+
}
241+
242+
# Function to determine if edit button should be shown
243+
def should_show_edit_button(pagename, sourcename):
244+
return not sourcename.endswith(".py")
245+
246+
context["should_show_edit_button"] = should_show_edit_button
247+
return context
248+
249+
250+
html_context = get_html_context()
251+
248252
if os.getenv("GALLERY_PATTERN"):
249253
# GALLERY_PATTERN is to be used when you want to work on a single
250254
# tutorial. Previously this was fed into filename_pattern, but
@@ -275,58 +279,6 @@ def wrapper(*args, **kwargs):
275279
]
276280

277281

278-
def fix_gallery_edit_links(app, pagename, templatename, context, doctree):
279-
if pagename.startswith(
280-
("beginner/", "intermediate/", "advanced/", "recipes/", "prototype/")
281-
):
282-
parts = pagename.split("/")
283-
gallery_dir = parts[0]
284-
# Handle nested directories by joining all parts except the first
285-
example_path = "/".join(parts[1:])
286-
example_name = parts[-1]
287-
288-
source_dirs = {}
289-
for i in range(len(sphinx_gallery_conf["examples_dirs"])):
290-
gallery_dir = sphinx_gallery_conf["gallery_dirs"][i]
291-
source_dir = sphinx_gallery_conf["examples_dirs"][i]
292-
# Extract the base name without "_source" suffix
293-
gallery_base = gallery_dir
294-
source_dirs[gallery_base] = source_dir
295-
296-
if gallery_dir in source_dirs:
297-
source_dir = source_dirs[gallery_dir]
298-
299-
# Reconstruct the path preserving subdirectories
300-
subdir = "/".join(parts[1:-1]) if len(parts) > 2 else ""
301-
302-
# Check if .py file exists
303-
py_path = (
304-
f"{source_dir}/{subdir}/{example_name}.py"
305-
if subdir
306-
else f"{source_dir}/{example_name}.py"
307-
)
308-
rst_path = (
309-
f"{source_dir}/{subdir}/{example_name}.rst"
310-
if subdir
311-
else f"{source_dir}/{example_name}.rst"
312-
)
313-
314-
# Clean up any double slashes
315-
py_path = py_path.replace("//", "/")
316-
rst_path = rst_path.replace("//", "/")
317-
318-
# Default to .py file, fallback to .rst if needed
319-
file_path = py_path
320-
if not os.path.exists(
321-
os.path.join(os.path.dirname(__file__), py_path)
322-
) and os.path.exists(os.path.join(os.path.dirname(__file__), rst_path)):
323-
file_path = rst_path
324-
325-
context["edit_url"] = (
326-
f"{html_context['github_url']}/{html_context['github_user']}/{html_context['github_repo']}/edit/{html_context['github_version']}/{file_path}"
327-
)
328-
329-
330282
# The suffix(es) of source filenames.
331283
# You can specify multiple suffix as a list of string:
332284
#
@@ -364,8 +316,8 @@ def fix_gallery_edit_links(app, pagename, templatename, context, doctree):
364316
"Thumbs.db",
365317
".DS_Store",
366318
"src/pytorch-sphinx-theme/docs*",
367-
# "**/huggindef fix_gallery_edit_linksgface_hub/templates/**",
368319
]
320+
369321
exclude_patterns += sphinx_gallery_conf["examples_dirs"]
370322
exclude_patterns += ["*/index.rst"]
371323

@@ -478,4 +430,10 @@ def handle_jinja_templates(app, docname, source):
478430

479431
def setup(app):
480432
app.connect("source-read", handle_jinja_templates)
481-
app.connect("html-page-context", fix_gallery_edit_links)
433+
app.connect("html-page-context", update_context)
434+
435+
436+
def update_context(app, pagename, templatename, context, doctree):
437+
# Get source file name
438+
if "page_source_suffix" in context and context["page_source_suffix"] == ".py":
439+
context["display_github"] = False

0 commit comments

Comments
 (0)