|
332 | 332 | "import os\n",
|
333 | 333 | "from tempfile import TemporaryDirectory\n",
|
334 | 334 | "\n",
|
335 |
| - "import pandas as pd\n", |
336 | 335 | "import requests\n",
|
337 | 336 | "\n",
|
338 | 337 | "url = \"https://s3.embl.de/spatialdata/spatialdata-sandbox/generated_data/xenium_visium_integration/xenium_rep1_celltype_major.csv\"\n",
|
|
363 | 362 | ],
364 | 363 | "source": [
365 | 364 | "xenium_sdata[\"table\"].obs = pd.merge(xenium_sdata[\"table\"].obs, df, on=\"cell_id\")\n",
366 |     | - "xenium_sdata[\"table\"].obs[\"celltype_major\"] = (\n",
367 |     | - "    xenium_sdata[\"table\"].obs[\"celltype_major\"].astype(\"category\")\n",
368 |     | - ")"
    | 365 | + "xenium_sdata[\"table\"].obs[\"celltype_major\"] = xenium_sdata[\"table\"].obs[\"celltype_major\"].astype(\"category\")"
369 | 366 | ]
370 | 367 | },
371 | 368 | {
|
936 | 933 | "outputs": [],
|
937 | 934 | "source": [
|
938 | 935 | "class TilesDataModule(LightningDataModule):\n",
|
939 |
| - " def __init__(\n", |
940 |
| - " self, batch_size: int, num_workers: int, dataset: torch.utils.data.Dataset\n", |
941 |
| - " ):\n", |
| 936 | + " def __init__(self, batch_size: int, num_workers: int, dataset: torch.utils.data.Dataset):\n", |
942 | 937 | " super().__init__()\n",
|
943 | 938 | "\n",
|
944 | 939 | " self.batch_size = batch_size\n",
|
|
1015 | 1010 | " self.loss_function = CrossEntropyLoss()\n",
|
1016 | 1011 | "\n",
|
1017 | 1012 | " # make the model\n",
|
1018 |
| - " self.model = DenseNet121(\n", |
1019 |
| - " spatial_dims=2, in_channels=in_channels, out_channels=num_classes\n", |
1020 |
| - " )\n", |
| 1013 | + " self.model = DenseNet121(spatial_dims=2, in_channels=in_channels, out_channels=num_classes)\n", |
1021 | 1014 | "\n",
|
1022 | 1015 | " def forward(self, x) -> torch.Tensor:\n",
|
1023 | 1016 | " return self.model(x)\n",
|
1024 | 1017 | "\n",
|
1025 |
| - " def _compute_loss_from_batch(\n", |
1026 |
| - " self, batch: Dict[str, torch.Tensor], batch_idx: int\n", |
1027 |
| - " ) -> float:\n", |
| 1018 | + " def _compute_loss_from_batch(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> float:\n", |
1028 | 1019 | " inputs = batch[0]\n",
|
1029 | 1020 | " labels = batch[1]\n",
|
1030 | 1021 | "\n",
|
1031 | 1022 | " outputs = self.model(inputs)\n",
|
1032 | 1023 | " return self.loss_function(outputs, labels)\n",
|
1033 | 1024 | "\n",
|
1034 |
| - " def training_step(\n", |
1035 |
| - " self, batch: Dict[str, torch.Tensor], batch_idx: int\n", |
1036 |
| - " ) -> Dict[str, float]:\n", |
| 1025 | + " def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, float]:\n", |
1037 | 1026 | " # compute the loss\n",
|
1038 | 1027 | " loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)\n",
|
1039 | 1028 | "\n",
|
|
1124 | 1113 | "print(f\"Using {BATCH_SIZE} batch size.\")\n",
|
1125 | 1114 | "print(f\"Using {NUM_WORKERS} workers.\")\n",
|
1126 | 1115 | "\n",
|
1127 |
| - "tiles_data_module = TilesDataModule(\n", |
1128 |
| - " batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, dataset=dataset\n", |
1129 |
| - ")\n", |
| 1116 | + "tiles_data_module = TilesDataModule(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, dataset=dataset)\n", |
1130 | 1117 | "\n",
|
1131 | 1118 | "tiles_data_module.setup()\n",
|
1132 | 1119 | "train_dl = tiles_data_module.train_dataloader()\n",
|
|
1870 | 1857 | "source": [
|
1871 | 1858 | "small_dataset = ImageTilesDataset(\n",
|
1872 | 1859 | " sdata=small_sdata,\n",
|
1873 |
| - " regions_to_images={\n", |
1874 |
| - " \"cell_boundaries\": \"CytAssist_FFPE_Human_Breast_Cancer_full_image\"\n", |
1875 |
| - " },\n", |
| 1860 | + " regions_to_images={\"cell_boundaries\": \"CytAssist_FFPE_Human_Breast_Cancer_full_image\"},\n", |
1876 | 1861 | " regions_to_coordinate_systems={\"cell_boundaries\": \"aligned\"},\n",
|
1877 | 1862 | " tile_dim_in_units=100,\n",
|
1878 | 1863 | " rasterize=True,\n",
|
|
1925 | 1910 | " region, instance_id = small_dataset.dataset_index.iloc[i][[\"region\", \"instance_id\"]]\n",
|
1926 | 1911 | " shapes = small_sdata[region]\n",
|
1927 | 1912 | " transformations = get_transformation(shapes, get_all=True)\n",
|
1928 |
| - " tile = ShapesModel.parse(\n", |
1929 |
| - " GeoDataFrame(geometry=shapes.loc[instance_id]), transformations=transformations\n", |
1930 |
| - " )\n", |
| 1913 | + " tile = ShapesModel.parse(GeoDataFrame(geometry=shapes.loc[instance_id]), transformations=transformations)\n", |
1931 | 1914 | " # BUG: we need to explicitly remove the coordinate system global if we want to combine\n",
|
1932 | 1915 | " # images and shapes plots into a single subplot\n",
|
1933 | 1916 | " # https://github.com/scverse/spatialdata-plot/issues/176\n",
|
1934 | 1917 | " sdata_tile[\"cell_boundaries\"] = tile\n",
|
1935 | 1918 | " if \"global\" in get_transformation(sdata_tile[\"cell_boundaries\"], get_all=True):\n",
|
1936 |
| - " sd.transformations.remove_transformation(\n", |
1937 |
| - " sdata_tile[\"cell_boundaries\"], \"global\"\n", |
1938 |
| - " )\n", |
| 1919 | + " sd.transformations.remove_transformation(sdata_tile[\"cell_boundaries\"], \"global\")\n", |
1939 | 1920 | " sdata_tile.pl.render_images().pl.render_shapes(\n",
|
1940 | 1921 | " # outline_color='predicted_celltype_major', # not yet supported: https://github.com/scverse/spatialdata-plot/issues/137\n",
|
1941 | 1922 | " outline_width=3.0,\n",
|
|