Skip to content

Commit e5f79c2

Browse files
committed
added detection for Transactional methods
1 parent 8cf01aa commit e5f79c2

File tree

11 files changed

+154
-51
lines changed

11 files changed

+154
-51
lines changed

core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>org.sterl.spring</groupId>
88
<artifactId>spring-persistent-tasks-root</artifactId>
9-
<version>1.3.2-SNAPSHOT</version>
9+
<version>1.4.0-SNAPSHOT</version>
1010
<relativePath>../pom.xml</relativePath>
1111
</parent>
1212

core/src/main/java/org/sterl/spring/persistent_tasks/task/TaskService.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import org.springframework.lang.NonNull;
1010
import org.springframework.stereotype.Service;
1111
import org.springframework.transaction.annotation.Transactional;
12+
import org.springframework.transaction.support.TransactionTemplate;
1213
import org.sterl.spring.persistent_tasks.api.PersistentTask;
1314
import org.sterl.spring.persistent_tasks.api.TaskId;
15+
import org.sterl.spring.persistent_tasks.task.component.TaskTransactionComponent;
1416
import org.sterl.spring.persistent_tasks.task.repository.TaskRepository;
1517

1618
import lombok.RequiredArgsConstructor;
@@ -20,6 +22,7 @@
2022
@RequiredArgsConstructor
2123
public class TaskService {
2224

25+
private final TaskTransactionComponent taskTransactionComponent;
2326
private final TaskRepository taskRepository;
2427

2528
@Transactional(readOnly = true)
@@ -30,6 +33,11 @@ public Set<TaskId<? extends Serializable>> findAllTaskIds() {
3033
public <T extends Serializable> Optional<PersistentTask<T>> get(TaskId<T> id) {
3134
return taskRepository.get(id);
3235
}
36+
37+
public <T extends Serializable> Optional<TransactionTemplate> getTransactionTemplate(
38+
PersistentTask<T> task) {
39+
return taskTransactionComponent.getTransactionTemplate(task);
40+
}
3341

3442
/**
3543
* Check if the {@link PersistentTask} is known or not.
@@ -66,6 +74,14 @@ public void accept(Serializable state) {
6674
@SuppressWarnings("unchecked")
6775
public <T extends Serializable> TaskId<T> register(String name, PersistentTask<T> task) {
6876
var id = (TaskId<T>)TaskId.of(name);
77+
return register(id, task);
78+
}
79+
/**
80+
* A way to manually register a persistentTask, usually not needed as spring beans will be added automatically.
81+
*/
82+
public <T extends Serializable> TaskId<T> register(TaskId<T> id, PersistentTask<T> task) {
83+
// init any transaction as needed
84+
taskTransactionComponent.getTransactionTemplate(task);
6985
return taskRepository.addTask(id, task);
7086
}
7187
/**
@@ -75,6 +91,6 @@ public <T extends Serializable> TaskId<T> register(String name, PersistentTask<T
7591
public <T extends Serializable> TaskId<T> replace(String name, PersistentTask<T> task) {
7692
var id = (TaskId<T>)TaskId.of(name);
7793
taskRepository.remove(id);
78-
return taskRepository.addTask(id, task);
94+
return register(id, task);
7995
}
8096
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package org.sterl.spring.persistent_tasks.task.component;
2+
3+
import java.io.Serializable;
4+
import java.util.EnumSet;
5+
import java.util.Map;
6+
import java.util.Optional;
7+
import java.util.Set;
8+
import java.util.concurrent.ConcurrentHashMap;
9+
10+
import org.springframework.stereotype.Component;
11+
import org.springframework.transaction.PlatformTransactionManager;
12+
import org.springframework.transaction.annotation.Propagation;
13+
import org.springframework.transaction.annotation.Transactional;
14+
import org.springframework.transaction.support.DefaultTransactionDefinition;
15+
import org.springframework.transaction.support.TransactionTemplate;
16+
import org.sterl.spring.persistent_tasks.api.PersistentTask;
17+
import org.sterl.spring.persistent_tasks.task.util.ReflectionUtil;
18+
19+
import lombok.RequiredArgsConstructor;
20+
import lombok.extern.slf4j.Slf4j;
21+
22+
@Component
23+
@Slf4j
24+
@RequiredArgsConstructor
25+
public class TaskTransactionComponent {
26+
27+
private final PlatformTransactionManager transactionManager;
28+
private final TransactionTemplate template;
29+
private final Set<Propagation> joinTransaction = EnumSet.of(
30+
Propagation.MANDATORY, Propagation.REQUIRED, Propagation.SUPPORTS);
31+
private final Map<PersistentTask<? extends Serializable>, Optional<TransactionTemplate>> cache = new ConcurrentHashMap<>();
32+
33+
public Optional<TransactionTemplate> getTransactionTemplate(PersistentTask<? extends Serializable> task) {
34+
if (cache.containsKey(task)) return cache.get(task);
35+
36+
Optional<TransactionTemplate> result;
37+
// first we apply a default
38+
if (task.isTransactional()) result = Optional.of(template);
39+
else result = Optional.empty();
40+
41+
var annotation = ReflectionUtil.getAnnotation(task, Transactional.class);
42+
if (annotation != null) {
43+
log.debug("found {} on task={}, creating custom ", annotation, task.getClass().getName());
44+
result = Optional.ofNullable(builTransactionTemplate(task, annotation));
45+
}
46+
cache.put(task, result);
47+
return result;
48+
}
49+
50+
private TransactionTemplate builTransactionTemplate(PersistentTask<? extends Serializable> task, Transactional annotation) {
51+
TransactionTemplate result;
52+
if (joinTransaction.contains(annotation.propagation())) {
53+
// No direct mapping for 'rollbackFor' or 'noRollbackFor'
54+
if (annotation.noRollbackFor().length > 0 || annotation.rollbackFor().length > 0) {
55+
throw new IllegalArgumentException("noRollbackFor or rollbackFor not supported. Please remove the settings on "
56+
+ task.getClass());
57+
} else {
58+
var dev = convertTransactionalToDefinition(annotation);
59+
dev.setName(task.getClass().getSimpleName());
60+
result = new TransactionTemplate(transactionManager, dev);
61+
}
62+
} else {
63+
log.info("Propagation={} disables join of transaction for {}",
64+
annotation.propagation(), task.getClass().getName());
65+
result = null;
66+
}
67+
return result;
68+
}
69+
70+
static DefaultTransactionDefinition convertTransactionalToDefinition(Transactional transactional) {
71+
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
72+
73+
// Map Transactional attributes to DefaultTransactionDefinition
74+
def.setIsolationLevel(transactional.isolation().value());
75+
def.setPropagationBehavior(Propagation.REQUIRED.value());
76+
def.setTimeout(transactional.timeout());
77+
def.setReadOnly(false);
78+
79+
return def;
80+
}
81+
}

core/src/main/java/org/sterl/spring/persistent_tasks/task/config/TaskConfig.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import org.springframework.context.support.GenericApplicationContext;
99
import org.sterl.spring.persistent_tasks.api.PersistentTask;
1010
import org.sterl.spring.persistent_tasks.api.TaskId;
11-
import org.sterl.spring.persistent_tasks.task.repository.TaskRepository;
11+
import org.sterl.spring.persistent_tasks.task.TaskService;
1212

1313
import lombok.extern.slf4j.Slf4j;
1414

@@ -18,11 +18,10 @@ public class TaskConfig {
1818
@SuppressWarnings({ "rawtypes", "unchecked" })
1919
@Autowired
2020
void configureSimpleTasks(GenericApplicationContext context,
21-
TaskRepository taskRepository) {
21+
TaskService taskService) {
2222
final var simpleTasks = context.getBeansOfType(PersistentTask.class);
2323
for(Entry<String, PersistentTask> t : simpleTasks.entrySet()) {
24-
var id = taskRepository.addTask(
25-
(TaskId<Serializable>)TaskId.of(t.getKey()), t.getValue());
24+
var id = taskService.register(t.getKey(), t.getValue());
2625

2726
addTaskIdIfMissing(context, id, t.getValue());
2827
}

core/src/main/java/org/sterl/spring/persistent_tasks/trigger/component/RunTriggerComponent.java

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import java.io.Serializable;
44
import java.time.OffsetDateTime;
55
import java.util.Optional;
6-
import java.util.concurrent.Callable;
76

87
import org.springframework.context.ApplicationEventPublisher;
98
import org.springframework.lang.Nullable;
@@ -25,7 +24,6 @@ public class RunTriggerComponent {
2524
private final TaskService taskService;
2625
private final EditTriggerComponent editTrigger;
2726
private final ApplicationEventPublisher eventPublisher;
28-
private final TransactionTemplate trx;
2927
private final StateSerializer serializer = new StateSerializer();
3028

3129
/**
@@ -40,40 +38,41 @@ public Optional<TriggerEntity> execute(TriggerEntity trigger) {
4038
if (taskAndState == null) return Optional.of(trigger);
4139

4240
try {
43-
Optional<TriggerEntity> result;
44-
if (taskAndState.isTransactional()) {
45-
result = trx.execute(t -> taskAndState.call());
46-
} else {
47-
result = taskAndState.call();
48-
}
49-
50-
return result;
41+
return taskAndState.call();
5142
} catch (Exception e) {
5243
return handleTaskException(taskAndState, e);
5344
}
5445
}
46+
5547
@Nullable
5648
private TaskAndState getTastAndState(TriggerEntity trigger) {
5749
try {
5850
var task = taskService.assertIsKnown(trigger.newTaskId());
51+
var trx = taskService.getTransactionTemplate(task);
5952
var state = serializer.deserialize(trigger.getData().getState());
60-
return new TaskAndState(task, state, trigger);
53+
return new TaskAndState(task, trx, state, trigger);
6154
} catch (Exception e) {
6255
// this trigger is somehow crap, no retry and done.
63-
handleTaskException(new TaskAndState(null, null, trigger), e);
56+
handleTaskException(new TaskAndState(null, Optional.empty(), null, trigger), e);
6457
return null;
6558
}
6659
}
6760
@RequiredArgsConstructor
68-
private class TaskAndState implements Callable<Optional<TriggerEntity>> {
61+
private class TaskAndState {
6962
final PersistentTask<Serializable> persistentTask;
63+
final Optional<TransactionTemplate> trx;
7064
final Serializable state;
7165
final TriggerEntity trigger;
7266

73-
boolean isTransactional() {
74-
return persistentTask.isTransactional();
67+
Optional<TriggerEntity> call() {
68+
if (trx.isPresent()) {
69+
return trx.get().execute(t -> runTask());
70+
} else {
71+
return runTask();
72+
}
7573
}
76-
public Optional<TriggerEntity> call() {
74+
75+
private Optional<TriggerEntity> runTask() {
7776
eventPublisher.publishEvent(new TriggerRunningEvent(trigger));
7877

7978
persistentTask.accept(state);
@@ -82,7 +81,6 @@ public Optional<TriggerEntity> call() {
8281
editTrigger.deleteTrigger(trigger);
8382

8483
return result;
85-
8684
}
8785
}
8886

@@ -93,17 +91,18 @@ private Optional<TriggerEntity> handleTaskException(TaskAndState taskAndState,
9391
var task = taskAndState.persistentTask;
9492
var result = editTrigger.completeTaskWithStatus(trigger.getKey(), e);
9593

96-
if (task != null &&
97-
task.retryStrategy().shouldRetry(trigger.getData().getExecutionCount(), e)) {
94+
if (task != null
95+
&& task.retryStrategy().shouldRetry(trigger.getData().getExecutionCount(), e)) {
9896

9997
final OffsetDateTime retryAt = task.retryStrategy().retryAt(trigger.getData().getExecutionCount(), e);
10098

10199
result = editTrigger.retryTrigger(trigger.getKey(), retryAt);
102100
if (result.isPresent()) {
101+
var data = result.get().getData();
103102
log.warn("{} failed, retry will be done at={} status={}!",
104103
trigger.getKey(),
105-
result.get().getData().getRunAt(),
106-
result.get().getData().getStatus(),
104+
data.getRunAt(),
105+
data.getStatus(),
107106
e);
108107
} else {
109108
log.error("Trigger with key={} not found and may be at a wrong state!",
@@ -117,5 +116,4 @@ private Optional<TriggerEntity> handleTaskException(TaskAndState taskAndState,
117116
}
118117
return result;
119118
}
120-
121119
}

core/src/test/java/org/sterl/spring/persistent_tasks/task/TaskServiceTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44

55
import org.junit.jupiter.api.Test;
66
import org.sterl.spring.persistent_tasks.api.TaskId;
7+
import org.sterl.spring.persistent_tasks.task.component.TaskTransactionComponent;
78
import org.sterl.spring.persistent_tasks.task.repository.TaskRepository;
89

910
class TaskServiceTest {
1011

11-
private final TaskService subject = new TaskService(new TaskRepository());
12+
private final TaskService subject = new TaskService(
13+
new TaskTransactionComponent(null, null),
14+
new TaskRepository());
1215

1316
@Test
1417
void testAssertIsKnown() {

core/src/test/java/org/sterl/spring/persistent_tasks/task/TaskTransactionTest.java

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import org.springframework.context.annotation.Bean;
1111
import org.springframework.context.annotation.Configuration;
1212
import org.springframework.stereotype.Component;
13+
import org.springframework.transaction.annotation.Isolation;
1314
import org.springframework.transaction.annotation.Propagation;
1415
import org.springframework.transaction.annotation.Transactional;
15-
import org.springframework.transaction.support.DefaultTransactionDefinition;
1616
import org.sterl.spring.persistent_tasks.AbstractSpringTest;
1717
import org.sterl.spring.persistent_tasks.api.PersistentTask;
1818
import org.sterl.spring.persistent_tasks.api.TaskId.TaskTriggerBuilder;
@@ -33,16 +33,18 @@ static class TransactionalClass implements PersistentTask<String> {
3333
@Override
3434
public void accept(String name) {
3535
personRepository.save(new PersonBE(name));
36+
personRepository.save(new PersonBE(name));
3637
}
3738
}
3839
@Component("transactionalMethod")
3940
@RequiredArgsConstructor
4041
static class TransactionalMethod implements PersistentTask<String> {
4142
private final PersonRepository personRepository;
42-
@Transactional(timeout = 6, propagation = Propagation.MANDATORY)
43+
@Transactional(timeout = 6, propagation = Propagation.MANDATORY, isolation = Isolation.REPEATABLE_READ)
4344
@Override
4445
public void accept(String name) {
4546
personRepository.save(new PersonBE(name));
47+
personRepository.save(new PersonBE(name));
4648
}
4749
}
4850

@@ -70,6 +72,7 @@ TransactionalTask<String> transactionalClosure(PersonRepository personRepository
7072
}
7173
}
7274

75+
@Autowired TaskService subject;
7376
@Autowired PersonRepository personRepository;
7477

7578
@Autowired @Qualifier("transactionalClass")
@@ -93,6 +96,25 @@ void testFindTransactionAnnotation() {
9396
assertThat(a).isNotNull();
9497
assertThat(a.timeout()).isEqualTo(7);
9598
}
99+
100+
@Test
101+
void testGetTransactionTemplate() {
102+
var a = subject.getTransactionTemplate(transactionalClass);
103+
assertThat(a).isPresent();
104+
assertThat(a.get().getTimeout()).isEqualTo(5);
105+
assertThat(a.get().getPropagationBehavior()).isEqualTo(Propagation.REQUIRED.value());
106+
107+
a = subject.getTransactionTemplate(transactionalMethod);
108+
assertThat(a).isPresent();
109+
assertThat(a.get().getTimeout()).isEqualTo(6);
110+
assertThat(a.get().getPropagationBehavior()).isEqualTo(Propagation.REQUIRED.value());
111+
assertThat(a.get().getIsolationLevel()).isEqualTo(Isolation.REPEATABLE_READ.value());
112+
113+
a = subject.getTransactionTemplate(transactionalAnonymous);
114+
assertThat(a).isPresent();
115+
assertThat(a.get().getTimeout()).isEqualTo(7);
116+
assertThat(a.get().getPropagationBehavior()).isEqualTo(Propagation.REQUIRED.value());
117+
}
96118

97119
@ParameterizedTest
98120
@ValueSource(strings = {"transactionalClass", "transactionalMethod", "transactionalClosure"})
@@ -110,20 +132,4 @@ void testTransactionalTask(String task) {
110132
hibernateAsserts.assertTrxCount(1);
111133
assertThat(personRepository.count()).isEqualTo(2);
112134
}
113-
114-
public static DefaultTransactionDefinition convertTransactionalToDefinition(Transactional transactional) {
115-
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
116-
117-
// Map Transactional attributes to DefaultTransactionDefinition
118-
def.setIsolationLevel(transactional.isolation().value());
119-
def.setPropagationBehavior(transactional.propagation().value());
120-
def.setTimeout(transactional.timeout());
121-
def.setReadOnly(transactional.readOnly());
122-
// No direct mapping for 'rollbackFor' or 'noRollbackFor'
123-
// Set a name if desired (e.g., based on transactional class/method)
124-
def.setName("TransactionalDefinition");
125-
126-
return def;
127-
}
128-
129135
}

db/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>org.sterl.spring</groupId>
88
<artifactId>spring-persistent-tasks-root</artifactId>
9-
<version>1.3.2-SNAPSHOT</version>
9+
<version>1.4.0-SNAPSHOT</version>
1010
<relativePath>../pom.xml</relativePath>
1111
</parent>
1212

example/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>org.sterl.spring</groupId>
88
<artifactId>spring-persistent-tasks-root</artifactId>
9-
<version>1.3.2-SNAPSHOT</version>
9+
<version>1.4.0-SNAPSHOT</version>
1010
<relativePath>../pom.xml</relativePath>
1111
</parent>
1212

0 commit comments

Comments
 (0)