
Commit 7caa7cb

atheo89 authored and jiridanek committed

RHAIENG-332: (feat): add basic llm compressor test (opendatahub-io#1972)

1 parent cb6c6a9 commit 7caa7cb

File tree: 1 file changed, +80 −0 lines changed

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc39f0e9",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import unittest\n",
+    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+    "from llmcompressor.modifiers.quantization import QuantizationModifier\n",
+    "from llmcompressor.transformers import oneshot\n",
+    "\n",
+    "class TestQuantizationProcess(unittest.TestCase):\n",
+    "\n",
+    "    def test_quantization_and_save(self):\n",
+    "        # Load model\n",
+    "        model_stub = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
+    "        model_name = model_stub.split(\"/\")[-1]\n",
+    "\n",
+    "        model = AutoModelForCausalLM.from_pretrained(\n",
+    "            model_stub,\n",
+    "            torch_dtype=\"auto\",\n",
+    "        )\n",
+    "\n",
+    "        tokenizer = AutoTokenizer.from_pretrained(model_stub)\n",
+    "\n",
+    "        # Configure the quantization algorithm and scheme\n",
+    "        recipe = QuantizationModifier(\n",
+    "            targets=\"Linear\",\n",
+    "            scheme=\"FP8_DYNAMIC\",\n",
+    "            ignore=[\"lm_head\"],\n",
+    "        )\n",
+    "\n",
+    "        # Apply quantization\n",
+    "        oneshot(\n",
+    "            model=model,\n",
+    "            recipe=recipe,\n",
+    "        )\n",
+    "\n",
+    "        # Save to disk in compressed-tensors format\n",
+    "        save_path = model_name + \"-FP8-dynamic\"\n",
+    "        model.save_pretrained(save_path)\n",
+    "        tokenizer.save_pretrained(save_path)\n",
+    "\n",
+    "        # Assertions to verify save\n",
+    "        self.assertTrue(os.path.exists(save_path), f\"Save path does not exist: {save_path}\")\n",
+    "        self.assertTrue(os.path.exists(os.path.join(save_path, \"config.json\")), \"Model config not found\")\n",
+    "        self.assertTrue(os.path.exists(os.path.join(save_path, \"tokenizer_config.json\")), \"Tokenizer config not found\")\n",
+    "\n",
+    "unittest.main(argv=[''], verbosity=2, exit=False)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f7bc63bb",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
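For reference, the directory this test writes (model_name + "-FP8-dynamic") can be reloaded for a quick smoke check outside the notebook. The following is a minimal sketch, not part of this commit, assuming the test already ran in the current working directory and that torch plus a transformers version with compressed-tensors support are installed:

# Hypothetical follow-up smoke check, not part of the committed notebook:
# reload the FP8-dynamic checkpoint saved by the test and run a short generation.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory produced by the test's save_pretrained() calls (assumed to exist locally).
save_path = "DeepSeek-R1-Distill-Llama-8B-FP8-dynamic"

model = AutoModelForCausalLM.from_pretrained(save_path, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(save_path)

inputs = tokenizer("Quantization check:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))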

0 commit comments