Skip to content

Commit fab97d4

Browse files
committed
Add option to disable cluster trimming in ClusterRefiner
1 parent: e956dcb; commit: fab97d4

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

gecco/refine.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
"""Algorithm to smooth contiguous gene cluster predictions into single regions.
22
"""
33

4+
import copy
45
import collections.abc
56
import itertools
67
import functools
@@ -70,12 +71,14 @@ class ClusterRefiner:
7071

7172
def __init__(
7273
self,
74+
*,
7375
threshold: float = 0.8,
7476
criterion: str = "gecco",
7577
n_cds: int = 5,
7678
n_biopfams: int = 5,
7779
average_threshold: float = 0.6,
7880
edge_distance: int = 0,
81+
trim: bool = True,
7982
) -> None:
8083
"""Create a new `ClusterRefiner` instance.
8184
@@ -97,6 +100,11 @@ def __init__(
97100
gene cluster must be located (it may start at an edge, but must
98101
span for longer than ``edge_distance``), in number of annotated
99102
genes (*only when the criterion is* ``gecco``).
103+
trim (`bool`): If set to `True` (the default), raw segments
104+
predicted by a `~gecco.crf.ClusterCRF` will be post-processed
105+
to exclude genes on cluster edges which have no domain
106+
annotation. Set to `False` to retain all genes as predicted
107+
by the CRF.
100108
101109
"""
102110
self.threshold = threshold
@@ -105,6 +113,7 @@ def __init__(
105113
self.n_biopfams = n_biopfams
106114
self.average_threshold = average_threshold
107115
self.edge_distance = edge_distance
116+
self.trim = trim
108117

109118
def iter_clusters(self, genes: List[Gene]) -> Iterator[Cluster]:
110119
"""Find all clusters in a table of CRF predictions.
@@ -119,7 +128,8 @@ def iter_clusters(self, genes: List[Gene]) -> Iterator[Cluster]:
119128
120129
"""
121130
for seq, cluster in self._iter_clusters(genes):
122-
trimmed = self._trim_cluster(cluster)
131+
if self.trim:
132+
cluster = self._trim_cluster(cluster)
123133
if self._validate_cluster(seq, cluster):
124134
yield cluster
125135
#
@@ -157,11 +167,17 @@ def _validate_cluster(self, seq: List[Gene], cluster: Cluster) -> bool:
157167
def _trim_cluster(self, cluster: Cluster) -> Cluster:
158168
"""Remove unannotated proteins from the cluster edges.
159169
"""
160-
while cluster.genes and not cluster.genes[0].protein.domains:
161-
cluster.genes.pop(0)
162-
while cluster.genes and not cluster.genes[-1].protein.domains:
163-
cluster.genes.pop()
164-
return cluster
170+
genes = cluster.genes.copy()
171+
while genes and not genes[0].protein.domains:
172+
genes.pop(0)
173+
while genes and not genes[-1].protein.domains:
174+
genes.pop()
175+
return Cluster(
176+
cluster.id,
177+
genes,
178+
cluster.type,
179+
cluster.type_probabilities
180+
)
165181

166182
def _iter_clusters(
167183
self,

0 commit comments

Comments
 (0)