|
1 | 1 | import copy |
| 2 | +import csv |
| 3 | +import io |
| 4 | +import itertools |
2 | 5 | import logging |
3 | | -from typing import List, Optional, Tuple |
| 6 | +from pathlib import Path |
| 7 | +import re |
| 8 | +from typing import List, Iterable, Optional, Tuple, Union |
4 | 9 |
|
5 | 10 | from tableauserverclient.server.query import QuerySet |
6 | 11 |
|
7 | 12 | from .endpoint import QuerysetEndpoint, api |
8 | 13 | from .exceptions import MissingRequiredFieldError, ServerResponseError |
9 | 14 | from tableauserverclient.server import RequestFactory, RequestOptions |
10 | | -from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem |
11 | | -from ..pager import Pager |
| 15 | +from tableauserverclient.models import UserItem, WorkbookItem, PaginationItem, GroupItem, JobItem |
| 16 | +from tableauserverclient.server.pager import Pager |
12 | 17 |
|
13 | 18 | from tableauserverclient.helpers.logging import logger |
14 | 19 |
|
@@ -97,8 +102,25 @@ def add_all(self, users: List[UserItem]): |
97 | 102 |
|
98 | 103 | # helping the user by parsing a file they could have used to add users through the UI |
99 | 104 | # line format: Username [required], password, display name, license, admin, publish |
| 105 | + @api(version="3.15") |
| 106 | + def bulk_add(self, users: Iterable[UserItem]) -> JobItem: |
| 107 | + """ |
| 108 | + line format: Username [required], password, display name, license, admin, publish |
| 109 | + """ |
| 110 | + url = f"{self.baseurl}/import" |
| 111 | + # Allow for iterators to be passed into the function |
| 112 | + csv_users, xml_users = itertools.tee(users, 2) |
| 113 | + csv_content = create_users_csv(csv_users) |
| 114 | + |
| 115 | + xml_request, content_type = RequestFactory.User.import_from_csv_req(csv_content, xml_users) |
| 116 | + server_response = self.post_request(url, xml_request, content_type) |
| 117 | + return JobItem.from_response(server_response.content, self.parent_srv.namespace).pop() |
| 118 | + |
100 | 119 | @api(version="2.0") |
101 | 120 | def create_from_file(self, filepath: str) -> Tuple[List[UserItem], List[Tuple[UserItem, ServerResponseError]]]: |
| 121 | + import warnings |
| 122 | + |
| 123 | + warnings.warn("This method is deprecated, use bulk_add instead", DeprecationWarning) |
102 | 124 | created = [] |
103 | 125 | failed = [] |
104 | 126 | if not filepath.find("csv"): |
@@ -159,16 +181,6 @@ def groups_for_user_pager(): |
159 | 181 |
|
160 | 182 | user_item._set_groups(groups_for_user_pager) |
161 | 183 |
|
162 | | - def _get_groups_for_user( |
163 | | - self, user_item: UserItem, req_options: Optional[RequestOptions] = None |
164 | | - ) -> Tuple[List[GroupItem], PaginationItem]: |
165 | | - url = "{0}/{1}/groups".format(self.baseurl, user_item.id) |
166 | | - server_response = self.get_request(url, req_options) |
167 | | - logger.info("Populated groups for user (ID: {0})".format(user_item.id)) |
168 | | - group_item = GroupItem.from_response(server_response.content, self.parent_srv.namespace) |
169 | | - pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) |
170 | | - return group_item, pagination_item |
171 | | - |
172 | 184 | def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[UserItem]: |
173 | 185 | """ |
174 | 186 | Queries the Tableau Server for items using the specified filters. Page |
@@ -205,3 +217,53 @@ def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySe |
205 | 217 | """ |
206 | 218 |
|
207 | 219 | return super().filter(*invalid, page_size=page_size, **kwargs) |
| 220 | + |
| 221 | + def _get_groups_for_user( |
| 222 | + self, user_item: UserItem, req_options: Optional[RequestOptions] = None |
| 223 | + ) -> Tuple[List[GroupItem], PaginationItem]: |
| 224 | + url = "{0}/{1}/groups".format(self.baseurl, user_item.id) |
| 225 | + server_response = self.get_request(url, req_options) |
| 226 | + logger.info("Populated groups for user (ID: {0})".format(user_item.id)) |
| 227 | + group_item = GroupItem.from_response(server_response.content, self.parent_srv.namespace) |
| 228 | + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) |
| 229 | + return group_item, pagination_item |
| 230 | + |
| 231 | +def create_users_csv(users: Iterable[UserItem], identity_pool=None) -> bytes: |
| 232 | + """ |
| 233 | + Create a CSV byte string from an Iterable of UserItem objects |
| 234 | + """ |
| 235 | + if identity_pool is not None: |
| 236 | + raise NotImplementedError("Identity pool is not supported in this version") |
| 237 | + with io.StringIO() as output: |
| 238 | + writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL) |
| 239 | + for user in users: |
| 240 | + site_role = user.site_role or "Unlicensed" |
| 241 | + if site_role == "ServerAdministrator": |
| 242 | + license = "Creator" |
| 243 | + admin_level = "System" |
| 244 | + elif site_role.startswith("SiteAdministrator"): |
| 245 | + admin_level = "Site" |
| 246 | + license = site_role.replace("SiteAdministrator", "") |
| 247 | + else: |
| 248 | + license = site_role |
| 249 | + admin_level = "" |
| 250 | + |
| 251 | + if any(x in site_role for x in ("Creator", "Admin", "Publish")): |
| 252 | + publish = 1 |
| 253 | + else: |
| 254 | + publish = 0 |
| 255 | + |
| 256 | + writer.writerow( |
| 257 | + ( |
| 258 | + user.name, |
| 259 | + getattr(user, "password", ""), |
| 260 | + user.fullname, |
| 261 | + license, |
| 262 | + admin_level, |
| 263 | + publish, |
| 264 | + user.email, |
| 265 | + ) |
| 266 | + ) |
| 267 | + output.seek(0) |
| 268 | + result = output.read().encode("utf-8") |
| 269 | + return result |
0 commit comments