Skip to content
20 changes: 15 additions & 5 deletions pytest_doctestplus/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def pytest_addoption(parser):
"Options accepted are 'txt', 'tex', and 'rst'. "
"This is no longer recommended, use --doctest-glob instead."
))

parser.addoption("--text-file-encoding", action="store",
help="Specify encoding for files.",
default="utf-8")

# Defaults to `atol` parameter from `numpy.allclose`.
parser.addoption("--doctest-plus-atol", action="store",
Expand Down Expand Up @@ -142,6 +146,9 @@ def pytest_addoption(parser):
parser.addini("text_file_format",
"Default format for docs. "
"This is no longer recommended, use --doctest-glob instead.")

parser.addini("text_file_encoding",
"Default encoding for text files.", default=None)

parser.addini("doctest_optionflags", "option flags for doctests",
type="args", default=["ELLIPSIS", "NORMALIZE_WHITESPACE"],)
Expand Down Expand Up @@ -912,13 +919,13 @@ def test_filter(test):
return tests


def write_modified_file(fname, new_fname, changes):
def write_modified_file(fname, new_fname, changes, encoding=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the default here be consistent with addoption default (utf-8)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be consistent, though I'm not sure if we should change the default to utf-8 or leave as is now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to set encoding=None to avoid changing the behavior of the write_modified_file function since I am not familiar with your codebase. I wanted to ensure that the function behaves as it did before.

From a developer's perspective, I would prefer using "utf-8" as the default encoding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any changes required?

i guess the default value of write_modified_file is uncritical.

default value for cli option defaults to "utf-8" which will be passed on.

# Sort in reversed order to edit the lines:
bad_tests = []
changes.sort(key=lambda x: (x["test_lineno"], x["example_lineno"]),
reverse=True)

with open(fname) as f:
with open(fname, encoding=encoding) as f:
text = f.readlines()

for change in changes:
Expand All @@ -939,7 +946,7 @@ def write_modified_file(fname, new_fname, changes):

text[lineno:lineno+want.count("\n")] = [got]

with open(new_fname, "w") as f:
with open(new_fname, "w", encoding=encoding) as f:
f.write("".join(text))

return bad_tests
Expand All @@ -953,6 +960,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
all_bad_tests = []
if not diff_mode:
return # we do not report or apply diffs

# get encoding to open file default ini=None or option="utf-8"
encoding = config.getini("text_file_encoding") or config.getoption("text_file_encoding")

if diff_mode != "overwrite":
# In this mode, we write a corrected file to a temporary folder in
Expand All @@ -974,7 +984,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
new_fname = fname.replace(common_path, tmpdirname)
os.makedirs(os.path.split(new_fname)[0], exist_ok=True)

bad_tests = write_modified_file(fname, new_fname, changes)
bad_tests = write_modified_file(fname, new_fname, changes, encoding)
all_bad_tests.extend(bad_tests)

# git diff returns 1 to signal changes, so just ignore the
Expand Down Expand Up @@ -1013,7 +1023,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
return
terminalreporter.write_line("Applied fix to the following files:")
for fname, changes in changesets.items():
bad_tests = write_modified_file(fname, fname, changes)
bad_tests = write_modified_file(fname, fname, changes, encoding)
all_bad_tests.extend(bad_tests)
terminalreporter.write_line(f" {fname}")

Expand Down