|  | 
| 188 | 188 |     "text" | 
| 189 | 189 |    ] | 
| 190 | 190 |   }, | 
|  | 191 | +  { | 
|  | 192 | +   "metadata": {}, | 
|  | 193 | +   "cell_type": "markdown", | 
|  | 194 | +   "source": "## Initializing TMaRCo", | 
|  | 195 | +   "id": "1eb7719e30054304" | 
|  | 196 | +  }, | 
| 191 | 197 |   { | 
| 192 | 198 |    "cell_type": "code", | 
| 193 | 199 |    "execution_count": 5, | 
|  | 
| 198 | 204 |     "tmarco = TMaRCo()" | 
| 199 | 205 |    ] | 
| 200 | 206 |   }, | 
|  | 207 | +  { | 
|  | 208 | +   "metadata": {}, | 
|  | 209 | +   "cell_type": "markdown", | 
|  | 210 | +   "source": [ | 
|  | 211 | +    "This initializes `TMaRCo` with the default models, downloaded from HuggingFace.\n", | 
|  | 212 | +    "<div class=\"alert alert-info\">\n", | 
|  | 213 | +    "To use local models with TMaRCo, they must be stored locally where TMaRCo can access them, initialized separately, and passed to TMaRCo's constructor.\n", | 
|  | 214 | +    "</div>\n", | 
|  | 215 | +    "For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:" | 
|  | 216 | +   ], | 
|  | 217 | +   "id": "3e16ee305f4983d9" | 
|  | 218 | +  }, | 
|  | 219 | +  { | 
|  | 220 | +   "metadata": {}, | 
|  | 221 | +   "cell_type": "code", | 
|  | 222 | +   "outputs": [], | 
|  | 223 | +   "execution_count": null, | 
|  | 224 | +   "source": [ | 
|  | 225 | +    "from huggingface_hub import snapshot_download\n", | 
|  | 226 | +    "\n", | 
|  | 227 | +    "snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")" | 
|  | 228 | +   ], | 
|  | 229 | +   "id": "614c9ff6f46a0ea9" | 
|  | 230 | +  }, | 
|  | 231 | +  { | 
|  | 232 | +   "metadata": {}, | 
|  | 233 | +   "cell_type": "markdown", | 
|  | 234 | +   "source": "We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:", | 
|  | 235 | +   "id": "95bd792e757205d6" | 
|  | 236 | +  }, | 
|  | 237 | +  { | 
|  | 238 | +   "metadata": {}, | 
|  | 239 | +   "cell_type": "code", | 
|  | 240 | +   "source": [ | 
|  | 241 | +    "from transformers import BartForConditionalGeneration, BartTokenizer\n", | 
|  | 242 | +    "\n", | 
|  | 243 | +    "tokenizer = BartTokenizer.from_pretrained(\n", | 
|  | 244 | +    "    \"models/bart\", # Or directory where the local model is stored\n", | 
|  | 245 | +    "    is_split_into_words=True, add_prefix_space=True\n", | 
|  | 246 | +    ")\n", | 
|  | 247 | +    "\n", | 
|  | 248 | +    "tokenizer.pad_token_id = tokenizer.eos_token_id\n", | 
|  | 249 | +    "\n", | 
|  | 250 | +    "base = BartForConditionalGeneration.from_pretrained(\n", | 
|  | 251 | +    "    \"models/bart\", # Or directory where the local model is stored\n", | 
|  | 252 | +    "    max_length=150,\n", | 
|  | 253 | +    "    forced_bos_token_id=tokenizer.bos_token_id,\n", | 
|  | 254 | +    ")\n", | 
|  | 255 | +    "\n", | 
|  | 256 | +    "# Initialize TMaRCo with local models\n", | 
|  | 257 | +    "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)" | 
|  | 258 | +   ], | 
|  | 259 | +   "id": "f0f24485822a7c3f", | 
|  | 260 | +   "outputs": [], | 
|  | 261 | +   "execution_count": null | 
|  | 262 | +  }, | 
| 201 | 263 |   { | 
| 202 | 264 |    "cell_type": "code", | 
| 203 | 265 |    "execution_count": 7, | 
|  | 
| 223 | 285 |     "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])" | 
| 224 | 286 |    ] | 
| 225 | 287 |   }, | 
|  | 288 | +  { | 
|  | 289 | +   "metadata": {}, | 
|  | 290 | +   "cell_type": "markdown", | 
|  | 291 | +   "source": [ | 
|  | 292 | +    "<div class=\"alert alert-info\">\n", | 
|  | 293 | +    "As before, to use local expert/anti-expert models with TMaRCo, they must be stored locally where TMaRCo can access them.\n", | 
|  | 294 | +    "However, they do not need to be initialized separately; the local directories can be passed to `load_models` directly.\n", | 
|  | 295 | +    "</div>\n", | 
|  | 296 | +    "For instance, to load local copies of the default `gminus`/`gplus` expert models:\n" | 
|  | 297 | +   ], | 
|  | 298 | +   "id": "c113208c527c342e" | 
|  | 299 | +  }, | 
|  | 300 | +  { | 
|  | 301 | +   "metadata": {}, | 
|  | 302 | +   "cell_type": "code", | 
|  | 303 | +   "outputs": [], | 
|  | 304 | +   "execution_count": null, | 
|  | 305 | +   "source": [ | 
|  | 306 | +    "snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n", | 
|  | 307 | +    "snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n", | 
|  | 308 | +    "\n", | 
|  | 309 | +    "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" | 
|  | 310 | +   ], | 
|  | 311 | +   "id": "dfa288dcb60102c" | 
|  | 312 | +  }, | 
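|  |  | +  { | 
|  |  | +   "metadata": {}, | 
|  |  | +   "cell_type": "markdown", | 
|  |  | +   "source": "As a quick sanity check (a sketch only: it assumes `TMaRCo.score` accepts a list of sentences and behaves the same as with the HuggingFace-hosted experts), we can score a sample sentence with the locally loaded models:", | 
|  |  | +   "id": "a3f1c8d2b4e09a17" | 
|  |  | +  }, | 
|  |  | +  { | 
|  |  | +   "metadata": {}, | 
|  |  | +   "cell_type": "code", | 
|  |  | +   "outputs": [], | 
|  |  | +   "execution_count": null, | 
|  |  | +   "source": [ | 
|  |  | +    "# Hypothetical quick check: score a sample sentence with the locally loaded experts\n", | 
|  |  | +    "# (assumes `score` works as with the default remote models)\n", | 
|  |  | +    "tmarco.score([\"An example sentence to score.\"])" | 
|  |  | +   ], | 
|  |  | +   "id": "c7d94b1e2f8a3056" | 
|  |  | +  }, | 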
| 226 | 313 |   { | 
| 227 | 314 |    "cell_type": "code", | 
| 228 | 315 |    "execution_count": 13, | 
|  | 
| 362 | 449 |     "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])" | 
| 363 | 450 |    ] | 
| 364 | 451 |   }, | 
|  | 452 | +  { | 
|  | 453 | +   "metadata": {}, | 
|  | 454 | +   "cell_type": "markdown", | 
|  | 455 | +   "source": "As noted previously, for a fully local setup, pass the initialized tokenizer and base model to the `TMaRCo` constructor, and the local paths to `load_models` as the expert/anti-expert models:", | 
|  | 456 | +   "id": "b0738c324227f57" | 
|  | 457 | +  }, | 
|  | 458 | +  { | 
|  | 459 | +   "metadata": {}, | 
|  | 460 | +   "cell_type": "code", | 
|  | 461 | +   "outputs": [], | 
|  | 462 | +   "execution_count": null, | 
|  | 463 | +   "source": [ | 
|  | 464 | +    "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n", | 
|  | 465 | +    "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" | 
|  | 466 | +   ], | 
|  | 467 | +   "id": "b929e21a97ea914e" | 
|  | 468 | +  }, | 
| 365 | 469 |   { | 
| 366 | 470 |    "cell_type": "markdown", | 
| 367 | 471 |    "id": "5303f56b-85ff-40da-99bf-6962cf2f3395", | 
|  | 