diff --git a/util/size_estimation.py b/util/size_estimation.py new file mode 100644 index 0000000..2b9bab1 --- /dev/null +++ b/util/size_estimation.py @@ -0,0 +1,61 @@ +import sys, os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from exllamav3 import Config, Model +import argparse +from exllamav3.loader.safetensors import SafetensorsCollection, VariantSafetensorsCollection +import yaml + + +def tsize(t): + return t.nelement() * t.element_size() + + +def dsize(d): + size = 0 + for _, v in d.items(): size += tsize(v) + return size + + +def main(args): + + # Config/model + config = Config.from_directory(args.in_dir) + model = Model.from_config(config) + + # Tensor collection + stc = SafetensorsCollection(args.in_dir) + + # Override tensors + if args.override: + with open(args.override, "r") as f: + comp = yaml.safe_load(f) + sources = {s["id"]: s["model_dir"] for s in comp["sources"]} + overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]} + collections = {} + for o_key, o_dir in overrides.items(): + if o_dir not in collections: + collections[o_dir] = [] + collections[o_dir].append(o_key) + if len(collections): + vstc = VariantSafetensorsCollection(config.stc) + for o_dir, o_keys in collections.items(): + print(f" -- Overriding from: {o_dir}:") + for o_key in o_keys: + print(f" {o_key}") + vstc.add_stc(o_keys, SafetensorsCollection(o_dir)) + config.stc = vstc + + # New bpw etc. + bpw_layer, bpw_head, vram_bits = model.get_storage_info() + bpw_layer = round(bpw_layer, 2) + bpw_head = round(bpw_head) + print(f" -- New estimated model bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)") + print(f" -- VRAM: {vram_bits / 8 / 1024**3:.0f} GiB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--in_dir", type = str, default = None, help = "Input model directory") + parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None) + _args = parser.parse_args() + main(_args)