Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions util/size_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model
import argparse
from exllamav3.loader.safetensors import SafetensorsCollection, VariantSafetensorsCollection
import yaml


def tsize(t):
return t.nelement() * t.element_size()


def dsize(d):
size = 0
for _, v in d.items(): size += tsize(v)
return size


def main(args):

# Config/model
config = Config.from_directory(args.in_dir)
model = Model.from_config(config)

# Tensor collection
stc = SafetensorsCollection(args.in_dir)

# Override tensors
if args.override:
with open(args.override, "r") as f:
comp = yaml.safe_load(f)
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
collections = {}
for o_key, o_dir in overrides.items():
if o_dir not in collections:
collections[o_dir] = []
collections[o_dir].append(o_key)
if len(collections):
vstc = VariantSafetensorsCollection(config.stc)
for o_dir, o_keys in collections.items():
print(f" -- Overriding from: {o_dir}:")
for o_key in o_keys:
print(f" {o_key}")
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
config.stc = vstc

# New bpw etc.
bpw_layer, bpw_head, vram_bits = model.get_storage_info()
bpw_layer = round(bpw_layer, 2)
bpw_head = round(bpw_head)
print(f" -- New estimated model bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)")
print(f" -- VRAM: {vram_bits / 8 / 1024**3:.0f} GiB")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--in_dir", type = str, default = None, help = "Input model directory")
parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None)
_args = parser.parse_args()
main(_args)