Skip to content

Commit aef4b2a

Browse files
committed
drape
1 parent 2e401fc commit aef4b2a

File tree

3 files changed

+59
-11
lines changed

3 files changed

+59
-11
lines changed

drape_CHM.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,69 @@
22
from src import CHM
33
from src.data import read_config
44
from src.utils import create_glob_lists
5+
from src.neon_paths import find_sensor_path
6+
from src.start_cluster import start
7+
from src.model_list import species_model_paths
8+
59
import glob
610
import geopandas as gpd
11+
import rasterstats
12+
import os
13+
from distributed import wait, as_completed
14+
import traceback
715

816
config = read_config("config.yml")
917
rgb_pool, h5_pool, hsi_pool, CHM_pool = create_glob_lists(config)
1018

11-
files = glob.glob("/blue/ewhite/b.weinstein/DeepTreeAttention/results/predictions/**/*.shp", recursive=True)
19+
all_files = []
20+
for site in species_model_paths:
21+
model_path = species_model_paths[site]
22+
model_name = os.path.splitext(os.path.basename(model_path))[0]
23+
prediction_dir = os.path.join("/blue/ewhite/b.weinstein/DeepTreeAttention/results/predictions/{}/{}/*.shp".format(site, model_name))
24+
files = glob.glob(prediction_dir, recursive=True)
25+
all_files.append(files)
26+
27+
all_files = [item for sublist in all_files for item in sublist]
28+
print("Found {} files".format(len(all_files)))
29+
30+
client = start(cpus=120,mem_size="4GB")
1231

1332
def drape(shp, config, CHM_pool):
1433
"""Take in a predictions shapefile and extract CHM height"""
15-
shp = gpd.read_file(shp)
16-
# Create a dummy plotID
17-
shp["plotID"] = "same_plot"
18-
draped_shp = CHM.postprocess_CHM(shp, lookup_pool=CHM_pool)
19-
#draped_shp.to_file(shp)
34+
df = gpd.read_file(shp)
35+
36+
#buffer slightly, CHM model can be patchy
37+
geom = df.geometry.buffer(1)
38+
try:
39+
CHM_path = find_sensor_path(lookup_pool=CHM_pool, bounds=df.total_bounds)
40+
except ValueError:
41+
return None
42+
43+
draped_boxes = rasterstats.zonal_stats(geom,
44+
CHM_path,
45+
add_stats={'q99': CHM.non_zero_99_quantile})
46+
df["height"] = [x["q99"] for x in draped_boxes]
47+
filtered_trees = df[df["height"]>3]
2048

49+
dirname = os.path.dirname(shp)
50+
basename = os.path.basename(shp)
51+
dst = os.path.join(dirname,"draped")
52+
os.makedirs(dst, exist_ok=True)
53+
filtered_trees.to_file(os.path.join(dst, basename))
54+
55+
return dst
56+
57+
futures = []
2158
for f in files:
22-
drape(f, config=config, CHM_pool=CHM_pool)
23-
print(f)
59+
future = client.submit(drape, f, CHM_pool=CHM_pool, config=config)
60+
futures.append(future)
61+
62+
for x in as_completed(futures):
63+
try:
64+
print(x.result())
65+
except:
66+
traceback.print_exc()
67+
68+
#for f in files:
69+
# dst = drape(f, config=config, CHM_pool=CHM_pool)
70+
# print(dst)

predict.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def create_landscape_map(site, model_path, config, client, rgb_pool, hsi_pool, h
128128
config=config,
129129
dead_model_path=dead_model_path,
130130
savedir="/blue/ewhite/b.weinstein/DeepTreeAttention/results/crowns",
131+
CHM_pool=CHM_pool,
131132
overwrite=False)
132133
except:
133134
traceback.print_exc()

src/predict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def generate_prediction_crops(crown_path, config, rgb_pool, h5_pool, img_pool, c
106106

107107
#Write file alongside
108108
crown_annotations = gpd.GeoDataFrame(crown_annotations, geometry="geometry")
109-
crown_annotations = crown_annotations.merge(crowns[["individual","dead_label","dead_score", "score"]])
109+
crown_annotations = crown_annotations.merge(crowns[["individual","dead_label","dead_score", "score","CHM_height"]])
110110

111111
crown_annotations.to_file(output_name)
112112

@@ -166,8 +166,8 @@ def predict_tile(crown_annotations, m, config, savedir, site, trainer, filter_de
166166
trees["crown_score"] = trees["score"]
167167
trees = trees.drop(columns=["pred_taxa_top1","label","score","taxonID"])
168168
trees = trees.groupby("individual").apply(lambda x: x.head(1)).reset_index(drop=True)
169-
trees = gpd.GeoDataFrame(trees, geometry="geometry")
170-
169+
trees = gpd.GeoDataFrame(trees, geometry="geometry")
170+
171171
#Save .shp
172172
output_name = os.path.splitext(os.path.basename(crown_annotations))[0]
173173
trees.to_file(os.path.join(savedir, "{}.shp".format(output_name)))

0 commit comments

Comments
 (0)