|
26 | 26 | import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
27 | 27 | import org.springframework.context.annotation.Bean;
|
28 | 28 | import org.springframework.context.annotation.Configuration;
|
| 29 | +import org.springframework.dao.DataAccessException; |
| 30 | +import org.springframework.dao.support.PersistenceExceptionTranslator; |
29 | 31 | import org.springframework.data.mongodb.MongoDbFactory;
|
30 | 32 | import org.springframework.data.mongodb.core.MongoTemplate;
|
31 | 33 | import org.springframework.data.mongodb.core.SimpleMongoDbFactory;
|
32 | 34 | import org.springframework.data.mongodb.gridfs.GridFsTemplate;
|
| 35 | +import org.springframework.util.Assert; |
33 | 36 | import org.springframework.util.StringUtils;
|
34 | 37 |
|
| 38 | +import com.mongodb.DB; |
35 | 39 | import com.mongodb.Mongo;
|
36 | 40 |
|
37 | 41 | /**
|
@@ -73,11 +77,49 @@ public MongoTemplate mongoTemplate(MongoDbFactory mongoDbFactory)
|
73 | 77 |
|
74 | 78 | @Bean
|
75 | 79 | @ConditionalOnMissingBean
|
76 |
| - public GridFsTemplate gridFsTemplate(Mongo mongo, MongoTemplate mongoTemplate) { |
77 |
| - String db = StringUtils.hasText(this.properties.getGridFsDatabase()) ? this.properties |
78 |
| - .getGridFsDatabase() : this.properties.getMongoClientDatabase(); |
79 |
| - return new GridFsTemplate(new SimpleMongoDbFactory(mongo, db), |
80 |
| - mongoTemplate.getConverter()); |
| 80 | + public GridFsTemplate gridFsTemplate(MongoDbFactory mongoDbFactory, |
| 81 | + MongoTemplate mongoTemplate) { |
| 82 | + return new GridFsTemplate(new GridFsMongoDbFactory(mongoDbFactory, |
| 83 | + this.properties), mongoTemplate.getConverter()); |
| 84 | + } |
| 85 | + |
| 86 | + /** |
| 87 | + * {@link MongoDbFactory} decorator to respect |
| 88 | + * {@link MongoProperties#getGridFsDatabase()} if set. |
| 89 | + */ |
| 90 | + private static class GridFsMongoDbFactory implements MongoDbFactory { |
| 91 | + |
| 92 | + private final MongoDbFactory mongoDbFactory; |
| 93 | + |
| 94 | + private final MongoProperties properties; |
| 95 | + |
| 96 | + public GridFsMongoDbFactory(MongoDbFactory mongoDbFactory, |
| 97 | + MongoProperties properties) { |
| 98 | + Assert.notNull(mongoDbFactory, "MongoDbFactory must not be null"); |
| 99 | + Assert.notNull(properties, "Properties must not be null"); |
| 100 | + this.mongoDbFactory = mongoDbFactory; |
| 101 | + this.properties = properties; |
| 102 | + } |
| 103 | + |
| 104 | + @Override |
| 105 | + public DB getDb() throws DataAccessException { |
| 106 | + String gridFsDatabase = this.properties.getGridFsDatabase(); |
| 107 | + if (StringUtils.hasText(gridFsDatabase)) { |
| 108 | + return this.mongoDbFactory.getDb(gridFsDatabase); |
| 109 | + } |
| 110 | + return this.mongoDbFactory.getDb(); |
| 111 | + } |
| 112 | + |
| 113 | + @Override |
| 114 | + public DB getDb(String dbName) throws DataAccessException { |
| 115 | + return this.mongoDbFactory.getDb(dbName); |
| 116 | + } |
| 117 | + |
| 118 | + @Override |
| 119 | + public PersistenceExceptionTranslator getExceptionTranslator() { |
| 120 | + return this.mongoDbFactory.getExceptionTranslator(); |
| 121 | + } |
| 122 | + |
81 | 123 | }
|
82 | 124 |
|
83 | 125 | }
|
0 commit comments