Skip to content

Commit 942b705

Browse files
committed
feat: enable bulk adding users
1 parent a3028d7 commit 942b705

File tree

4 files changed

+178
-14
lines changed

4 files changed

+178
-14
lines changed

tableauserverclient/server/endpoint/users_endpoint.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import copy
2+
import csv
3+
import io
4+
import itertools
25
import logging
3-
from typing import List, Optional, Tuple
6+
from pathlib import Path
7+
import re
8+
from typing import List, Iterable, Optional, Tuple, Union
49

510
from tableauserverclient.server.query import QuerySet
611

712
from .endpoint import QuerysetEndpoint, api
813
from .exceptions import MissingRequiredFieldError, ServerResponseError
914
from tableauserverclient.server import RequestFactory, RequestOptions
10-
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem
11-
from ..pager import Pager
15+
from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem, JobItem
16+
from tableauserverclient.server.pager import Pager
1217

1318
from tableauserverclient.helpers.logging import logger
1419

@@ -97,8 +102,25 @@ def add_all(self, users: List[UserItem]):
97102

98103
# helping the user by parsing a file they could have used to add users through the UI
99104
# line format: Username [required], password, display name, license, admin, publish
105+
@api(version="3.15")
106+
def bulk_add(self, users: Iterable[UserItem]) -> JobItem:
107+
"""
108+
line format: Username [required], password, display name, license, admin, publish
109+
"""
110+
url = f"{self.baseurl}/import"
111+
# Allow for iterators to be passed into the function
112+
csv_users, xml_users = itertools.tee(users, 2)
113+
csv_content = create_users_csv(csv_users)
114+
115+
xml_request, content_type = RequestFactory.User.import_from_csv_req(csv_content, xml_users)
116+
server_response = self.post_request(url, xml_request, content_type)
117+
return JobItem.from_response(server_response.content, self.parent_srv.namespace).pop()
118+
100119
@api(version="2.0")
101120
def create_from_file(self, filepath: str) -> Tuple[List[UserItem], List[Tuple[UserItem, ServerResponseError]]]:
121+
import warnings
122+
123+
warnings.warn("This method is deprecated, use bulk_add instead", DeprecationWarning)
102124
created = []
103125
failed = []
104126
if not filepath.find("csv"):
@@ -159,16 +181,6 @@ def groups_for_user_pager():
159181

160182
user_item._set_groups(groups_for_user_pager)
161183

162-
def _get_groups_for_user(
163-
self, user_item: UserItem, req_options: Optional[RequestOptions] = None
164-
) -> Tuple[List[GroupItem], PaginationItem]:
165-
url = "{0}/{1}/groups".format(self.baseurl, user_item.id)
166-
server_response = self.get_request(url, req_options)
167-
logger.info("Populated groups for user (ID: {0})".format(user_item.id))
168-
group_item = GroupItem.from_response(server_response.content, self.parent_srv.namespace)
169-
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
170-
return group_item, pagination_item
171-
172184
def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[UserItem]:
173185
"""
174186
Queries the Tableau Server for items using the specified filters. Page
@@ -205,3 +217,53 @@ def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySe
205217
"""
206218

207219
return super().filter(*invalid, page_size=page_size, **kwargs)
220+
221+
def _get_groups_for_user(
222+
self, user_item: UserItem, req_options: Optional[RequestOptions] = None
223+
) -> Tuple[List[GroupItem], PaginationItem]:
224+
url = "{0}/{1}/groups".format(self.baseurl, user_item.id)
225+
server_response = self.get_request(url, req_options)
226+
logger.info("Populated groups for user (ID: {0})".format(user_item.id))
227+
group_item = GroupItem.from_response(server_response.content, self.parent_srv.namespace)
228+
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
229+
return group_item, pagination_item
230+
231+
def create_users_csv(users: Iterable[UserItem], identity_pool=None) -> bytes:
232+
"""
233+
Create a CSV byte string from an Iterable of UserItem objects
234+
"""
235+
if identity_pool is not None:
236+
raise NotImplementedError("Identity pool is not supported in this version")
237+
with io.StringIO() as output:
238+
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
239+
for user in users:
240+
site_role = user.site_role or "Unlicensed"
241+
if site_role == "ServerAdministrator":
242+
license = "Creator"
243+
admin_level = "System"
244+
elif site_role.startswith("SiteAdministrator"):
245+
admin_level = "Site"
246+
license = site_role.replace("SiteAdministrator", "")
247+
else:
248+
license = site_role
249+
admin_level = ""
250+
251+
if any(x in site_role for x in ("Creator", "Admin", "Publish")):
252+
publish = 1
253+
else:
254+
publish = 0
255+
256+
writer.writerow(
257+
(
258+
user.name,
259+
getattr(user, "password", ""),
260+
user.fullname,
261+
license,
262+
admin_level,
263+
publish,
264+
user.email,
265+
)
266+
)
267+
output.seek(0)
268+
result = output.read().encode("utf-8")
269+
return result

tableauserverclient/server/request_factory.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,21 @@ def add_req(self, user_item: UserItem) -> bytes:
905905
user_element.attrib["authSetting"] = user_item.auth_setting
906906
return ET.tostring(xml_request)
907907

908+
def import_from_csv_req(self, csv_content: bytes, users: Iterable[UserItem]):
909+
xml_request = ET.Element("tsRequest")
910+
for user in users:
911+
if user.name is None:
912+
raise ValueError("User name must be populated.")
913+
user_element = ET.SubElement(xml_request, "user")
914+
user_element.attrib["name"] = user.name
915+
user_element.attrib["authSetting"] = user.auth_setting or "ServerDefault"
916+
917+
parts = {
918+
"tableau_user_import": ("tsc_users_file.csv", csv_content, "file"),
919+
"request_payload": ("", ET.tostring(xml_request), "text/xml"),
920+
}
921+
return _add_multipart(parts)
922+
908923

909924
class WorkbookRequest(object):
910925
def _generate_xml(

test/assets/users_bulk_add_job.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api https://help.tableau.com/samples/en-us/rest_api/ts-api_3_20.xsd">
3+
<job id="16a3479e-0ff9-4685-a0e4-1533b3c2eb96" mode="Asynchronous" type="UserImport" progress="0" createdAt="2024-06-27T03:21:02Z" finishCode="1"/>
4+
</tsResponse>

test/test_user.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import csv
12
import io
23
import os
4+
from pathlib import Path
35
import unittest
46
from typing import List
57
from unittest.mock import MagicMock
68

9+
from defusedxml.ElementTree import fromstring
710
import requests_mock
811

912
import tableauserverclient as TSC
1013
from tableauserverclient.datetime_helpers import format_datetime
1114

12-
TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
15+
TEST_ASSET_DIR = Path(__file__).resolve().parent / "assets"
1316

17+
BULK_ADD_XML = TEST_ASSET_DIR / "users_bulk_add_job.xml"
1418
GET_XML = os.path.join(TEST_ASSET_DIR, "user_get.xml")
1519
GET_EMPTY_XML = os.path.join(TEST_ASSET_DIR, "user_get_empty.xml")
1620
GET_BY_ID_XML = os.path.join(TEST_ASSET_DIR, "user_get_by_id.xml")
@@ -236,3 +240,82 @@ def test_get_users_from_file(self):
236240
users, failures = self.server.users.create_from_file(USERS)
237241
assert users[0].name == "Cassie", users
238242
assert failures == []
243+
244+
def test_bulk_add(self):
245+
self.server.version = "3.15"
246+
users = [
247+
TSC.UserItem(
248+
"test",
249+
"Viewer",
250+
)
251+
]
252+
with requests_mock.mock() as m:
253+
m.post(f"{self.server.users.baseurl}/import", text=BULK_ADD_XML.read_text())
254+
255+
job = self.server.users.bulk_add(users)
256+
257+
assert m.last_request.method == "POST"
258+
assert m.last_request.url == f"{self.server.users.baseurl}/import"
259+
260+
body = m.last_request.body.replace(b"\r\n", b"\n")
261+
assert body.startswith(b"--") # Check if it's a multipart request
262+
boundary = body.split(b"\n")[0].strip()
263+
264+
# Body starts and ends with a boundary string. Split the body into
265+
# segments and ignore the empty sections at the start and end.
266+
segments = [seg for s in body.split(boundary) if (seg := s.strip()) not in [b"", b"--"]]
267+
assert len(segments) == 2 # Check if there are two segments
268+
269+
# Check if the first segment is the csv file and the second segment is the xml
270+
assert b'Content-Disposition: form-data; name="tableau_user_import"' in segments[0]
271+
assert b'Content-Disposition: form-data; name="request_payload"' in segments[1]
272+
assert b"Content-Type: file" in segments[0]
273+
assert b"Content-Type: text/xml" in segments[1]
274+
275+
xml_string = segments[1].split(b"\n\n")[1].strip()
276+
xml = fromstring(xml_string)
277+
xml_users = xml.findall(".//user", namespaces={})
278+
assert len(xml_users) == len(users)
279+
280+
for user, xml_user in zip(users, xml_users):
281+
assert user.name == xml_user.get("name")
282+
assert xml_user.get("authSetting") == (user.auth_setting or "ServerDefault")
283+
284+
license_map = {
285+
"Viewer": "Viewer",
286+
"Explorer": "Explorer",
287+
"ExplorerCanPublish": "Explorer",
288+
"Creator": "Creator",
289+
"SiteAdministratorExplorer": "Explorer",
290+
"SiteAdministratorCreator": "Creator",
291+
"ServerAdministrator": "Creator",
292+
"Unlicensed": "Unlicensed",
293+
}
294+
publish_map = {
295+
"Unlicensed": 0,
296+
"Viewer": 0,
297+
"Explorer": 0,
298+
"Creator": 1,
299+
"ExplorerCanPublish": 1,
300+
"SiteAdministratorExplorer": 1,
301+
"SiteAdministratorCreator": 1,
302+
"ServerAdministrator": 1,
303+
}
304+
admin_map = {
305+
"SiteAdministratorExplorer": "Site",
306+
"SiteAdministratorCreator": "Site",
307+
"ServerAdministrator": "System",
308+
}
309+
310+
csv_columns = ["name", "password", "fullname", "license", "admin", "publish", "email"]
311+
csv_file = io.StringIO(segments[0].split(b"\n\n")[1].decode("utf-8"))
312+
csv_reader = csv.reader(csv_file)
313+
for user, row in zip(users, csv_reader):
314+
site_role = user.site_role or "Unlicensed"
315+
csv_user = dict(zip(csv_columns, row))
316+
assert user.name == csv_user["name"]
317+
assert (user.fullname or "") == csv_user["fullname"]
318+
assert (user.email or "") == csv_user["email"]
319+
assert license_map[site_role] == csv_user["license"]
320+
assert admin_map.get(site_role, "") == csv_user["admin"]
321+
assert publish_map[site_role] == int(csv_user["publish"])

0 commit comments

Comments
 (0)