|
4 | 4 | import click |
5 | 5 | import configparser |
6 | 6 | from csv import DictWriter |
| 7 | +import datetime |
7 | 8 | import fnmatch |
| 9 | +from http.server import HTTPServer, BaseHTTPRequestHandler |
8 | 10 | import io |
9 | 11 | import itertools |
10 | 12 | import json |
|
14 | 16 | import re |
15 | 17 | import sys |
16 | 18 | import textwrap |
| 19 | +import threading |
| 20 | +import time |
17 | 21 | from . import policies |
18 | 22 |
|
19 | 23 | PUBLIC_ACCESS_BLOCK_CONFIGURATION = { |
@@ -143,6 +147,27 @@ def convert(self, value, param, ctx): |
143 | 147 | return integer |
144 | 148 |
|
145 | 149 |
|
class RefreshIntervalParam(click.ParamType):
    "Converts a refresh interval such as 30s, 5m or 1h into a number of seconds"
    name = "refresh_interval"
    pattern = re.compile(r"^(\d+)(m|h|s)?$")

    def convert(self, value, param, ctx):
        parsed = self.pattern.match(value)
        if parsed is None:
            self.fail("Refresh interval must be of form 30s or 5m or 1h")
        digits, unit = parsed.groups()
        # A missing suffix or an explicit "s" both mean the number is seconds
        multiplier = {"m": 60, "h": 3600}.get(unit, 1)
        seconds = int(digits) * multiplier
        # Zero is syntactically valid but not a usable interval
        if seconds < 1:
            self.fail("Refresh interval must be at least 1 second")
        return seconds
| 169 | + |
| 170 | + |
146 | 171 | class StatementParam(click.ParamType): |
147 | 172 | "Ensures statement is valid JSON with required fields" |
148 | 173 | name = "statement" |
@@ -1638,6 +1663,263 @@ def set_public_access_block( |
1638 | 1663 | ) |
1639 | 1664 |
|
1640 | 1665 |
|
class CredentialCache:
    """Thread-safe cache of temporary credentials for a single bucket.

    Credentials are generated on the first call to get_credentials() and
    reused until the cache lifetime has elapsed, at which point the next
    caller generates a fresh set. A single lock is held while generating, so
    concurrent callers wait for that result instead of starting their own
    request or receiving a half-initialised value.
    """

    # Bounds applied to the DurationSeconds value passed to assume_role
    MIN_SESSION_SECONDS = 15 * 60
    MAX_SESSION_SECONDS = 12 * 60 * 60
    # Extra lifetime requested beyond the refresh interval, so credentials
    # handed out just before a refresh do not expire mid-request
    EXPIRY_BUFFER_SECONDS = 60
    # STS rejects RoleSessionName values longer than this
    MAX_SESSION_NAME_LENGTH = 64

    def __init__(
        self, iam, sts, bucket, permission, prefix, refresh_interval, extra_statements
    ):
        self.iam = iam
        self.sts = sts
        self.bucket = bucket
        self.permission = permission
        self.prefix = prefix
        self.refresh_interval = refresh_interval
        self.extra_statements = extra_statements
        self._credentials = None
        self._expiry_time = None
        self._lock = threading.Lock()

    def _generate_policy(self):
        """Generate the IAM policy for bucket access."""
        statements = []
        if self.permission == "read-write":
            statements.extend(policies.read_write_statements(self.bucket, self.prefix))
        elif self.permission == "read-only":
            statements.extend(policies.read_only_statements(self.bucket, self.prefix))
        elif self.permission == "write-only":
            statements.extend(policies.write_only_statements(self.bucket, self.prefix))
        if self.extra_statements:
            statements.extend(self.extra_statements)
        return policies.wrap_policy(statements)

    def _session_duration(self):
        """Return the DurationSeconds value to request from assume_role.

        This is the refresh interval plus the expiry buffer, clamped to the
        range MIN_SESSION_SECONDS..MAX_SESSION_SECONDS.
        """
        wanted = self.refresh_interval + self.EXPIRY_BUFFER_SECONDS
        return max(self.MIN_SESSION_SECONDS, min(wanted, self.MAX_SESSION_SECONDS))

    def _cache_lifetime(self):
        """Return how many seconds freshly generated credentials are served for.

        Normally this is the refresh interval. When the refresh interval is
        longer than the clamped session duration allows, the lifetime is cut
        short so the cache never outlives the credentials it holds.
        """
        return min(
            self.refresh_interval,
            self._session_duration() - self.EXPIRY_BUFFER_SECONDS,
        )

    def _generate_credentials(self):
        """Generate new temporary credentials using STS assume_role.

        Returns the "Credentials" entry of the assume_role response.
        """
        s3_role_arn = ensure_s3_role_exists(self.iam, self.sts)
        session_name = "s3.{permission}.{bucket}".format(
            permission=self.permission,
            bucket=self.bucket,
        )
        credentials_response = self.sts.assume_role(
            RoleArn=s3_role_arn,
            # Long bucket names would otherwise push the name past the limit
            RoleSessionName=session_name[: self.MAX_SESSION_NAME_LENGTH],
            Policy=json.dumps(self._generate_policy()),
            DurationSeconds=self._session_duration(),
        )
        return credentials_response["Credentials"]

    def get_credentials(self):
        """Return cached credentials, generating a new set if they have expired.

        Any exception raised while generating propagates to the caller; the
        cache is left unchanged in that case, so the next call tries again.
        """
        # The lock is held for the whole refresh: every concurrent caller needs
        # the new credentials anyway, and by the time a waiting caller gets the
        # lock it sees the fresh, unexpired entry and returns it.
        with self._lock:
            current_time = time.time()
            if (
                self._credentials is not None
                and self._expiry_time is not None
                and current_time < self._expiry_time
            ):
                return self._credentials

            credentials = self._generate_credentials()
            self._credentials = credentials
            # Measured from before the request was made, so the cache errs on
            # the side of refreshing slightly early
            self._expiry_time = current_time + self._cache_lifetime()
            return credentials
| 1753 | + |
| 1754 | + |
def make_credential_handler(credential_cache):
    """Create an HTTP request handler class with access to the credential cache.

    The returned class answers GET / with a JSON object containing
    AccessKeyId, SecretAccessKey, SessionToken and Expiration taken from
    credential_cache.get_credentials(). Any other path gets a 404 JSON error,
    and a failure to obtain credentials gets a 500 JSON error carrying the
    exception message.
    """

    class CredentialHandler(BaseHTTPRequestHandler):
        def log_message(self, format, *args):
            # Log to stderr with timestamp
            click.echo(
                "{} - {}".format(
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    format % args,
                ),
                err=True,
            )

        def _send_json(self, status, payload, indent=None):
            # Serialise before sending anything, so the length is known and a
            # serialisation error cannot leave a half-written response
            body = json.dumps(payload, indent=indent).encode()
            self.send_response(status)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def do_GET(self):
            if self.path != "/":
                self._send_json(404, {"error": "Not found"})
                return

            # Only building the payload is guarded. A socket error while
            # writing the 200 response must not trigger a second (500)
            # response on the same connection.
            try:
                credentials = credential_cache.get_credentials()
                expiration = credentials["Expiration"]
                response_data = {
                    "AccessKeyId": credentials["AccessKeyId"],
                    "SecretAccessKey": credentials["SecretAccessKey"],
                    "SessionToken": credentials["SessionToken"],
                    "Expiration": (
                        expiration.isoformat()
                        if hasattr(expiration, "isoformat")
                        else str(expiration)
                    ),
                }
            except Exception as e:
                self._send_json(500, {"error": str(e)})
                return

            self._send_json(200, response_data, indent=2)

    return CredentialHandler
| 1800 | + |
| 1801 | + |
@cli.command()
@click.argument("bucket")
@click.option(
    "-p",
    "--port",
    type=int,
    default=8094,
    help="Port to run the server on (default: 8094)",
)
@click.option(
    "--host",
    default="localhost",
    help="Host to bind the server to (default: localhost)",
)
@click.option("--read-only", help="Only allow reading from the bucket", is_flag=True)
@click.option("--write-only", help="Only allow writing to the bucket", is_flag=True)
@click.option(
    "--prefix", help="Restrict to keys starting with this prefix", default="*"
)
@click.option(
    "extra_statements",
    "--statement",
    multiple=True,
    type=StatementParam(),
    help="JSON statement to add to the policy",
)
@click.option(
    "--refresh-interval",
    type=RefreshIntervalParam(),
    required=True,
    help="How often to refresh credentials, e.g. 30s, 5m, 1h",
)
@common_boto3_options
def localserver(
    bucket,
    port,
    host,
    read_only,
    write_only,
    prefix,
    extra_statements,
    refresh_interval,
    **boto_options,
):
    """
    Start a localhost server that serves S3 credentials.

    The server responds to GET requests on / with JSON containing temporary
    AWS credentials that allow access to the specified bucket.

    Credentials are cached and refreshed automatically based on the
    --refresh-interval setting.

    To start a server that serves read-only credentials for a bucket,
    refreshing every 5 minutes:

        s3-credentials localserver my-bucket --read-only --refresh-interval 5m

    To run on a different port:

        s3-credentials localserver my-bucket --refresh-interval 5m --port 9000
    """
    # NOTE: the docstring above is the command's --help text, so it is kept
    # as user-facing output rather than developer documentation.
    if read_only and write_only:
        raise click.ClickException(
            "Cannot use --read-only and --write-only at the same time"
        )
    extra_statements = list(extra_statements)

    permission = "read-write"
    if read_only:
        permission = "read-only"
    if write_only:
        permission = "write-only"

    # Create AWS clients
    iam = make_client("iam", **boto_options)
    sts = make_client("sts", **boto_options)
    s3 = make_client("s3", **boto_options)

    # Verify bucket exists
    if not bucket_exists(s3, bucket):
        raise click.ClickException("Bucket does not exist: {}".format(bucket))

    # Create credential cache
    credential_cache = CredentialCache(
        iam=iam,
        sts=sts,
        bucket=bucket,
        permission=permission,
        prefix=prefix,
        refresh_interval=refresh_interval,
        extra_statements=extra_statements,
    )

    # Pre-generate credentials to catch any errors early
    click.echo("Generating initial credentials...", err=True)
    try:
        credential_cache.get_credentials()
    except Exception as e:
        raise click.ClickException(
            "Failed to generate credentials: {}".format(e)
        ) from e

    # Create and start server. Binding fails with OSError when the address is
    # unavailable and OverflowError when the port is outside 0-65535; report
    # both as a normal CLI error instead of a traceback.
    handler = make_credential_handler(credential_cache)
    try:
        server = HTTPServer((host, port), handler)
    except (OSError, OverflowError) as e:
        raise click.ClickException(
            "Could not start server on {}:{}: {}".format(host, port, e)
        ) from e

    click.echo(
        "Serving {} credentials for bucket '{}' at http://{}:{}/".format(
            permission, bucket, host, port
        ),
        err=True,
    )
    click.echo("Refresh interval: {} seconds".format(refresh_interval), err=True)
    click.echo("Press Ctrl+C to stop", err=True)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        click.echo("\nShutting down server...", err=True)
    finally:
        # serve_forever() has already returned here, so there is no loop left
        # for shutdown() to stop; what remains is releasing the listening
        # socket, which only server_close() does.
        server.server_close()
| 1921 | + |
| 1922 | + |
1641 | 1923 | def output(iterator, headers, nl, csv, tsv): |
1642 | 1924 | if nl: |
1643 | 1925 | for item in iterator: |
|
0 commit comments