|
| 1 | +#! /usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# Copyright 2025 Google LLC |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | +import argparse |
| 18 | +import os |
| 19 | +import libcst as cst |
| 20 | +import pathlib |
| 21 | +import sys |
| 22 | +from typing import (Any, Callable, Dict, List, Sequence, Tuple) |
| 23 | + |
| 24 | + |
| 25 | +def partition( |
| 26 | + predicate: Callable[[Any], bool], |
| 27 | + iterator: Sequence[Any] |
| 28 | +) -> Tuple[List[Any], List[Any]]: |
| 29 | + """A stable, out-of-place partition.""" |
| 30 | + results = ([], []) |
| 31 | + |
| 32 | + for i in iterator: |
| 33 | + results[int(predicate(i))].append(i) |
| 34 | + |
| 35 | + # Returns trueList, falseList |
| 36 | + return results[1], results[0] |
| 37 | + |
| 38 | + |
| 39 | +class bigquery_storageCallTransformer(cst.CSTTransformer): |
| 40 | + CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') |
| 41 | + METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { |
| 42 | + 'append_rows': ('write_stream', 'offset', 'proto_rows', 'trace_id', ), |
| 43 | + 'batch_commit_write_streams': ('parent', 'write_streams', ), |
| 44 | + 'create_read_session': ('parent', 'read_session', 'max_stream_count', ), |
| 45 | + 'create_write_stream': ('parent', 'write_stream', ), |
| 46 | + 'finalize_write_stream': ('name', ), |
| 47 | + 'flush_rows': ('write_stream', 'offset', ), |
| 48 | + 'get_write_stream': ('name', ), |
| 49 | + 'read_rows': ('read_stream', 'offset', ), |
| 50 | + 'split_read_stream': ('name', 'fraction', ), |
| 51 | + } |
| 52 | + |
| 53 | + def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: |
| 54 | + try: |
| 55 | + key = original.func.attr.value |
| 56 | + kword_params = self.METHOD_TO_PARAMS[key] |
| 57 | + except (AttributeError, KeyError): |
| 58 | + # Either not a method from the API or too convoluted to be sure. |
| 59 | + return updated |
| 60 | + |
| 61 | + # If the existing code is valid, keyword args come after positional args. |
| 62 | + # Therefore, all positional args must map to the first parameters. |
| 63 | + args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) |
| 64 | + if any(k.keyword.value == "request" for k in kwargs): |
| 65 | + # We've already fixed this file, don't fix it again. |
| 66 | + return updated |
| 67 | + |
| 68 | + kwargs, ctrl_kwargs = partition( |
| 69 | + lambda a: a.keyword.value not in self.CTRL_PARAMS, |
| 70 | + kwargs |
| 71 | + ) |
| 72 | + |
| 73 | + args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] |
| 74 | + ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) |
| 75 | + for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) |
| 76 | + |
| 77 | + request_arg = cst.Arg( |
| 78 | + value=cst.Dict([ |
| 79 | + cst.DictElement( |
| 80 | + cst.SimpleString("'{}'".format(name)), |
| 81 | +cst.Element(value=arg.value) |
| 82 | + ) |
| 83 | + # Note: the args + kwargs looks silly, but keep in mind that |
| 84 | + # the control parameters had to be stripped out, and that |
| 85 | + # those could have been passed positionally or by keyword. |
| 86 | + for name, arg in zip(kword_params, args + kwargs)]), |
| 87 | + keyword=cst.Name("request") |
| 88 | + ) |
| 89 | + |
| 90 | + return updated.with_changes( |
| 91 | + args=[request_arg] + ctrl_kwargs |
| 92 | + ) |
| 93 | + |
| 94 | + |
| 95 | +def fix_files( |
| 96 | + in_dir: pathlib.Path, |
| 97 | + out_dir: pathlib.Path, |
| 98 | + *, |
| 99 | + transformer=bigquery_storageCallTransformer(), |
| 100 | +): |
| 101 | + """Duplicate the input dir to the output dir, fixing file method calls. |
| 102 | +
|
| 103 | + Preconditions: |
| 104 | + * in_dir is a real directory |
| 105 | + * out_dir is a real, empty directory |
| 106 | + """ |
| 107 | + pyfile_gen = ( |
| 108 | + pathlib.Path(os.path.join(root, f)) |
| 109 | + for root, _, files in os.walk(in_dir) |
| 110 | + for f in files if os.path.splitext(f)[1] == ".py" |
| 111 | + ) |
| 112 | + |
| 113 | + for fpath in pyfile_gen: |
| 114 | + with open(fpath, 'r') as f: |
| 115 | + src = f.read() |
| 116 | + |
| 117 | + # Parse the code and insert method call fixes. |
| 118 | + tree = cst.parse_module(src) |
| 119 | + updated = tree.visit(transformer) |
| 120 | + |
| 121 | + # Create the path and directory structure for the new file. |
| 122 | + updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) |
| 123 | + updated_path.parent.mkdir(parents=True, exist_ok=True) |
| 124 | + |
| 125 | + # Generate the updated source file at the corresponding path. |
| 126 | + with open(updated_path, 'w') as f: |
| 127 | + f.write(updated.code) |
| 128 | + |
| 129 | + |
| 130 | +if __name__ == '__main__': |
| 131 | + parser = argparse.ArgumentParser( |
| 132 | + description="""Fix up source that uses the bigquery_storage client library. |
| 133 | +
|
| 134 | +The existing sources are NOT overwritten but are copied to output_dir with changes made. |
| 135 | +
|
| 136 | +Note: This tool operates at a best-effort level at converting positional |
| 137 | + parameters in client method calls to keyword based parameters. |
| 138 | + Cases where it WILL FAIL include |
| 139 | + A) * or ** expansion in a method call. |
| 140 | + B) Calls via function or method alias (includes free function calls) |
| 141 | + C) Indirect or dispatched calls (e.g. the method is looked up dynamically) |
| 142 | +
|
| 143 | + These all constitute false negatives. The tool will also detect false |
| 144 | + positives when an API method shares a name with another method. |
| 145 | +""") |
| 146 | + parser.add_argument( |
| 147 | + '-d', |
| 148 | + '--input-directory', |
| 149 | + required=True, |
| 150 | + dest='input_dir', |
| 151 | + help='the input directory to walk for python files to fix up', |
| 152 | + ) |
| 153 | + parser.add_argument( |
| 154 | + '-o', |
| 155 | + '--output-directory', |
| 156 | + required=True, |
| 157 | + dest='output_dir', |
| 158 | + help='the directory to output files fixed via un-flattening', |
| 159 | + ) |
| 160 | + args = parser.parse_args() |
| 161 | + input_dir = pathlib.Path(args.input_dir) |
| 162 | + output_dir = pathlib.Path(args.output_dir) |
| 163 | + if not input_dir.is_dir(): |
| 164 | + print( |
| 165 | + f"input directory '{input_dir}' does not exist or is not a directory", |
| 166 | + file=sys.stderr, |
| 167 | + ) |
| 168 | + sys.exit(-1) |
| 169 | + |
| 170 | + if not output_dir.is_dir(): |
| 171 | + print( |
| 172 | + f"output directory '{output_dir}' does not exist or is not a directory", |
| 173 | + file=sys.stderr, |
| 174 | + ) |
| 175 | + sys.exit(-1) |
| 176 | + |
| 177 | + if os.listdir(output_dir): |
| 178 | + print( |
| 179 | + f"output directory '{output_dir}' is not empty", |
| 180 | + file=sys.stderr, |
| 181 | + ) |
| 182 | + sys.exit(-1) |
| 183 | + |
| 184 | + fix_files(input_dir, output_dir) |
0 commit comments