Skip to content

Commit a2015cc

Browse files
committed
refactor upload_fileobj in fsspec and add tests
1 parent 24bafe5 commit a2015cc

File tree

4 files changed

+99
-22
lines changed

4 files changed

+99
-22
lines changed

src/webdav4/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import threading
55
from contextlib import contextmanager, suppress
66
from http import HTTPStatus
7-
from io import TextIOWrapper, UnsupportedOperation
7+
from io import TextIOWrapper
88
from typing import (
99
TYPE_CHECKING,
1010
Any,
@@ -662,8 +662,7 @@ def upload_fileobj(
662662
# if we are not successfull in that, we gracefully fallback
663663
# to the chunked encoding.
664664
if size is None:
665-
with suppress(TypeError, AttributeError, UnsupportedOperation):
666-
size = peek_filelike_length(file_obj)
665+
size = peek_filelike_length(file_obj)
667666

668667
headers = {"Content-Length": str(size)} if size is not None else None
669668
if not overwrite and self.exists(to_path):

src/webdav4/fsspec.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
cast,
2323
)
2424

25+
from fsspec import Callback
2526
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
2627

2728
from .client import (
@@ -32,6 +33,7 @@
3233
ResourceConflict,
3334
ResourceNotFound,
3435
)
36+
from .fs_utils import peek_filelike_length
3537
from .stream import read_into
3638

3739
if TYPE_CHECKING:
@@ -41,8 +43,7 @@
4143
from os import PathLike
4244
from typing import AnyStr
4345

44-
from fsspec import Callback
45-
46+
from .callback import CallbackFn
4647
from .types import AuthTypes, URLTypes
4748

4849

@@ -319,12 +320,41 @@ def sign(self, path: str, expiration: int = 100, **kwargs: Any) -> None:
319320

320321
def pipe_file(self, path: str, value: bytes, **kwargs: Any) -> None:
321322
"""Upload the contents to given file in the remote webdav server."""
322-
path = self._strip_protocol(path)
323323
buff = io.BytesIO(value)
324324
kwargs.setdefault("overwrite", True)
325-
# maybe it's not a bad idea to make a `self.open` for `mode="rb"`
326-
# on top of `io.BytesIO`?
327-
self.client.upload_fileobj(buff, path, **kwargs)
325+
return self.upload_fileobj(buff, path, **kwargs)
326+
327+
def upload_fileobj(
328+
self,
329+
fobj: BinaryIO,
330+
rpath: str,
331+
callback: "Callback" = None,
332+
overwrite: bool = True,
333+
size: int = None,
334+
**kwargs: Any,
335+
) -> None:
336+
"""Upload contents from the fileobj to the remote path."""
337+
rpath = self._strip_protocol(rpath)
338+
self.mkdirs(os.path.dirname(rpath), exist_ok=True)
339+
340+
if size is None:
341+
size = peek_filelike_length(fobj)
342+
343+
callback = cast("Callback", Callback.as_callback(callback))
344+
if size is not None: # pragma: no cover
345+
callback.set_size(size)
346+
progress_callback = cast("CallbackFn", callback.relative_update)
347+
348+
return self.client.upload_fileobj(
349+
fobj,
350+
rpath,
351+
overwrite=overwrite,
352+
callback=progress_callback,
353+
size=size,
354+
**kwargs,
355+
)
356+
357+
put_fileobj = upload_fileobj
328358

329359
def put_file(
330360
self,
@@ -334,16 +364,19 @@ def put_file(
334364
**kwargs: Any,
335365
) -> None:
336366
"""Copy file to remote webdav server."""
337-
rpath = self._strip_protocol(rpath)
338367
if os.path.isdir(lpath):
339-
self.makedirs(rpath, exist_ok=True)
340-
else:
341-
if callback is not None:
342-
callback.set_size(os.path.getsize(lpath))
343-
kwargs.setdefault("callback", callback.relative_update)
344-
self.mkdirs(os.path.dirname(rpath), exist_ok=True)
368+
rpath = self._strip_protocol(rpath)
369+
return self.makedirs(rpath, exist_ok=True)
370+
371+
with open(lpath, mode="rb") as fobj:
345372
kwargs.setdefault("overwrite", True)
346-
self.client.upload_file(lpath, rpath, **kwargs)
373+
kwargs.setdefault("size", None)
374+
return self.upload_fileobj(
375+
fobj,
376+
rpath,
377+
callback=callback,
378+
**kwargs,
379+
)
347380

348381

349382
class WebdavFile(AbstractBufferedFile):

tests/test_fs_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Test fs utils."""
2+
from io import BytesIO
3+
4+
from webdav4.fs_utils import peek_filelike_length
5+
6+
from .test_callback import ReadWrapper
7+
8+
9+
def test_peek_filelike_length():
10+
"""Test peek_filelike length for the fileobj."""
11+
fobj = BytesIO(b"Hello, World!")
12+
13+
assert peek_filelike_length(fobj) == 13
14+
assert peek_filelike_length(ReadWrapper(fobj)) is None # type: ignore
15+
assert peek_filelike_length(object()) is None

tests/test_fsspec.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from webdav4.fsspec import WebdavFileSystem
1010
from webdav4.urls import URL, join_url
1111

12+
from .test_callback import ReadWrapper
1213
from .utils import TmpDir, approx_datetime
1314

1415

@@ -828,15 +829,44 @@ def test_touch_not_truncate(storage_dir: TmpDir, fs: WebdavFileSystem):
828829
fs.touch("foo", truncate=False)
829830

830831

832+
@pytest.mark.parametrize("with_size", [True, False])
833+
@pytest.mark.parametrize("wrap_fobj", [True, False])
834+
def test_upload_fileobj(
835+
storage_dir: TmpDir,
836+
fs: WebdavFileSystem,
837+
with_size: bool,
838+
wrap_fobj: bool,
839+
):
840+
"""Test upload_fileobj.
841+
842+
If with_size, size hints are provided to the API.
843+
If wrap_fobj, the file object is wrapped with the ReadWrapper
844+
that makes the peek_fileobj_length util to not be able to
845+
figure out the size of the given fileobj. For the most part,
846+
it should still work in this case.
847+
"""
848+
if wrap_fobj and not with_size:
849+
pytest.skip("the test server does not work without content-length :(")
850+
851+
foo = storage_dir / "foo"
852+
length = foo.write_bytes(b"foo")
853+
854+
with foo.open(mode="rb") as fobj:
855+
if wrap_fobj:
856+
fobj = ReadWrapper(fobj) # type: ignore
857+
size = length if with_size else None
858+
fs.upload_fileobj(fobj, "data/foobar", size=size)
859+
860+
assert (storage_dir / "data").cat() == {"foobar": "foo"}
861+
862+
831863
def test_callbacks(storage_dir: TmpDir, fs: WebdavFileSystem):
832864
"""Test fsspec callbacks."""
833865
src_file = storage_dir / "source"
834866
dest_file = "data/get_put_file/dest"
835867

836868
data = b"test" * 4
837-
838-
with open(src_file, "wb") as stream:
839-
stream.write(data)
869+
size = src_file.write_bytes(data)
840870

841871
class EventLogger(fsspec.Callback):
842872
"""Log callback values."""
@@ -855,9 +885,9 @@ def relative_update(self, value: int) -> None:
855885

856886
event_logger = EventLogger()
857887
fs.put_file(src_file, dest_file, chunk_size=4, callback=event_logger)
858-
assert fs.exists(dest_file)
888+
assert (storage_dir / dest_file).cat() == data.decode()
859889

860-
assert event_logger.events[0] == ("set_size", len(data))
890+
assert event_logger.events[0] == ("set_size", size)
861891
assert event_logger.events[1:] == [
862892
("relative_update", 4),
863893
("relative_update", 4),

0 commit comments

Comments
 (0)