From 23aa3c93f64fde365649b422c46b93f37b1b57ad Mon Sep 17 00:00:00 2001 From: Mikhail2048 Date: Thu, 22 Jun 2023 18:02:40 +0300 Subject: [PATCH] Implemented GH-921 keyspace customization feature --- .../cassandra/core/CassandraTemplate.java | 48 ++++-- .../data/cassandra/core/EntityOperations.java | 14 ++ .../ExecutableInsertOperationSupport.java | 2 +- .../data/cassandra/core/StatementFactory.java | 151 ++++++++++++++---- .../core/legacy/AsyncCassandraTemplate.java | 1 - .../BasicCassandraPersistentEntity.java | 26 ++- .../core/mapping/CassandraMappingContext.java | 3 + .../mapping/CassandraPersistentEntity.java | 7 + ...andraPersistentEntityMetadataVerifier.java | 2 +- .../mapping/EmbeddedEntityOperations.java | 5 +- .../data/cassandra/core/mapping/Table.java | 5 + ...sicCassandraPersistentEntityUnitTests.java | 5 +- .../cassandra/domain/EntityWithKeyspace.java | 11 ++ .../PartTreeCassandraQueryUnitTests.java | 28 +++- ...leCassandraRepositoryIntegrationTests.java | 113 ++++++++----- 15 files changed, 309 insertions(+), 112 deletions(-) create mode 100644 spring-data-cassandra/src/test/java/org/springframework/data/cassandra/domain/EntityWithKeyspace.java diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java index 6299d2189..0669b0993 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/CassandraTemplate.java @@ -414,19 +414,22 @@ public List select(Query query, Class entityClass) throws DataAccessEx Assert.notNull(query, "Query must not be null"); Assert.notNull(entityClass, "Entity type must not be null"); - return doSelect(query, entityClass, getTableName(entityClass), entityClass); + return doSelect(query, entityClass, getTableName(entityClass), entityClass, getEntityOperations().getCustomKeyspaceName(entityClass)); } List doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType) { + return this.doSelect(query, entityClass, tableName, returnType, null); + } + List doSelect(Query query, Class entityClass, CqlIdentifier tableName, Class returnType, @Nullable CqlIdentifier keyspace) { CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); EntityProjection projection = entityOperations.introspectProjection(returnType, entityClass); Columns columns = getStatementFactory().computeColumnsForProjection(projection, query.getColumns(), entity, - returnType); + returnType); Query queryToUse = query.columns(columns); - StatementBuilder select = getStatementFactory().select(queryToUse, entity, tableName, keyspace); Function mapper = getMapper(projection, tableName); return doQuery(select.build(), (row, rowNum) -> mapper.apply(row)); @@ -463,7 +466,7 @@ public Stream stream(Query query, Class entityClass) throws DataAccess Stream doStream(Query query, Class entityClass, CqlIdentifier tableName, Class returnType) { StatementBuilder countStatement = getStatementFactory().count(query, - getRequiredPersistentEntity(entityClass), tableName); + getRequiredPersistentEntity(entityClass), tableName, keyspace); return doQueryForObject(countStatement.build(), Long.class); } @@ -557,7 +569,7 @@ public boolean exists(Object id, Class entityClass) { Assert.notNull(entityClass, "Entity type must not be null"); CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); - StatementBuilder select = getStatementFactory().selectOneById(id, entity, entity.getTableName(), entity.getCustomKeyspace()); return doQueryForResultSet(select.build()).one() != null; } @@ -587,7 +599,7 @@ public T selectOneById(Object id, Class entityClass) { CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); CqlIdentifier tableName = entity.getTableName(); - StatementBuilder select = getStatementFactory().selectOneById(id, entity, tableName, entity.getCustomKeyspace()); Function mapper = getMapper(EntityProjection.nonProjecting(entityClass), tableName); List result = doQuery(select.build(), (row, rowNum) -> mapper.apply(row)); @@ -605,10 +617,13 @@ public EntityWriteResult insert(T entity, InsertOptions options) { Assert.notNull(entity, "Entity must not be null"); Assert.notNull(options, "InsertOptions must not be null"); - return doInsert(entity, options, getTableName(entity.getClass())); + return doInsert(entity, options, getTableName(entity.getClass()), getEntityOperations().getCustomKeyspaceName(entity.getClass())); } EntityWriteResult doInsert(T entity, WriteOptions options, CqlIdentifier tableName) { + return this.doInsert(entity, options, tableName, null); + } + EntityWriteResult doInsert(T entity, WriteOptions options, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) { AdaptibleEntity source = getEntityOperations().forEntity(maybeCallBeforeConvert(entity, tableName), getConverter().getConversionService()); @@ -616,7 +631,7 @@ EntityWriteResult doInsert(T entity, WriteOptions options, CqlIdentifier T entityToUse = source.isVersionedEntity() ? source.initializeVersionProperty() : source.getBean(); StatementBuilder builder = getStatementFactory().insert(entityToUse, options, - source.getPersistentEntity(), tableName); + source.getPersistentEntity(), tableName, keyspace); if (source.isVersionedEntity()) { @@ -709,7 +724,7 @@ public WriteResult delete(Object entity, QueryOptions options) { CassandraPersistentEntity persistentEntity = getRequiredPersistentEntity(entity.getClass()); CqlIdentifier tableName = persistentEntity.getTableName(); - StatementBuilder builder = getStatementFactory().delete(entity, options, getConverter(), tableName); + StatementBuilder builder = getStatementFactory().delete(entity, options, getConverter(), tableName, persistentEntity.getCustomKeyspace()); return source.isVersionedEntity() ? doDeleteVersioned(source.appendVersionCondition(builder).build(), entity, source, tableName) @@ -743,7 +758,7 @@ public boolean deleteById(Object id, Class entityClass) { CassandraPersistentEntity entity = getRequiredPersistentEntity(entityClass); CqlIdentifier tableName = entity.getTableName(); - StatementBuilder delete = getStatementFactory().deleteById(id, entity, tableName); + StatementBuilder delete = getStatementFactory().deleteById(id, entity, tableName, getEntityOperations().getCustomKeyspaceName(entityClass)); SimpleStatement statement = delete.build(); maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName)); @@ -761,7 +776,7 @@ public void truncate(Class entityClass) { Assert.notNull(entityClass, "Entity type must not be null"); CqlIdentifier tableName = getTableName(entityClass); - Truncate truncate = QueryBuilder.truncate(tableName); + Truncate truncate = QueryBuilder.truncate(getEntityOperations().getCustomKeyspaceName(entityClass), tableName); SimpleStatement statement = truncate.build(); maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName)); @@ -924,7 +939,6 @@ public String getCql() { return getCqlOperations().execute(new GetConfiguredPageSize()); } - @SuppressWarnings("unchecked") private Function getMapper(EntityProjection projection, CqlIdentifier tableName) { Class targetType = projection.getMappedType().getType(); diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityOperations.java index d2dc44897..d51f2f1ef 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/EntityOperations.java @@ -113,6 +113,20 @@ CqlIdentifier getTableName(Class entityClass) { return getRequiredPersistentEntity(entityClass).getTableName(); } + /** + * Returns custom keyspace defined (if any) where the table for entity {@code entityClass} should be persisted. + * If the keyspace is not overridden in {@link org.springframework.data.cassandra.core.mapping.Table} annotation, + * then {@code null} is returned, signaling that default keyspace of {@link com.datastax.oss.driver.api.core.CqlSession} + * should be used + * + * @param entityClass entity class, must not be {@literal null}. + * @return custom keyspace defined (if any) + */ + @Nullable + CqlIdentifier getCustomKeyspaceName(Class entityClass) { + return getRequiredPersistentEntity(entityClass).getCustomKeyspace(); + } + /** * Introspect the given {@link Class result type} in the context of the {@link Class entity type} whether the returned * type is a projection and what property paths are participating in the projection. diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableInsertOperationSupport.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableInsertOperationSupport.java index 798712bc2..b2c39eae7 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableInsertOperationSupport.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/ExecutableInsertOperationSupport.java @@ -55,7 +55,7 @@ static class ExecutableInsertSupport implements ExecutableInsert { @Nullable private final CqlIdentifier tableName; public ExecutableInsertSupport(CassandraTemplate template, Class domainType, InsertOptions insertOptions, - CqlIdentifier tableName) { + @Nullable CqlIdentifier tableName) { this.template = template; this.domainType = domainType; this.insertOptions = insertOptions; diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java index d009a1c96..287825324 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/StatementFactory.java @@ -194,12 +194,26 @@ public StatementBuilder count(Query query, CassandraPersistentEntity entity, CqlIdentifier tableName) { + return this.count(query, entity, tableName, null); + } + + /** + * Create a {@literal COUNT} statement by mapping {@link Query} to {@link Select}. + * + * @param query user-defined count {@link Query} to execute; must not be {@literal null}. + * @param entity {@link CassandraPersistentEntity entity} to count; must not be {@literal null}. + * @param tableName must not be {@literal null}. + * @param keyspace - keyspace in which table is located. Might be null, in which case the default keyspace is assumed + * @return the select builder. + * @since 2.1 + */ + public StatementBuilder count(Query query, CassandraPersistentEntity */ public StatementBuilder selectOneById(Object id, CassandraPersistentEntity persistentEntity, + CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) { Where where = new Where(); cassandraConverter.write(id, where, persistentEntity); - return StatementBuilder.of(QueryBuilder.selectFrom(tableName).all().limit(1)) - .bind((statement, factory) -> statement.where(toRelations(where, factory))); + return StatementBuilder.of(QueryBuilder.selectFrom(keyspace, tableName).all().limit(1)) + .bind((statement, factory) -> statement.where(toRelations(where, factory))); } /** @@ -234,7 +262,7 @@ public StatementBuilder select(Query query, CassandraPersistentEntity */ public StatementBuilder select(Query query, CassandraPersistentEntity persistentEntity, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) { Assert.notNull(query, "Query must not be null"); Assert.notNull(persistentEntity, "CassandraPersistentEntity must not be null"); @@ -257,7 +299,7 @@ public StatementBuilder createSelect(Query query, CassandraPersistentEntity entity, Filter filter, - List selectors, CqlIdentifier tableName) { + List selectors, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) { - Sort sort = Optional.of(query.getSort()).map(querySort -> getQueryMapper().getMappedSort(querySort, entity)) + Sort sort = Optional.of(query.getSort()) + .map(querySort -> getQueryMapper().getMappedSort(querySort, entity)) .orElse(Sort.unsorted()); - StatementBuilder select = this.createSelectAndOrder(selectors, tableName, filter, sort, keyspace); if (query.getLimit() > 0) { select.apply(it -> it.limit(Math.toIntExact(query.getLimit()))); @@ -631,24 +713,28 @@ private StatementBuilder createSelectAndOrder(List selectors, CqlIdentifier from, - Filter filter, Sort sort) { + private StatementBuilder createSelectAndOrder(List selectors, CqlIdentifier from, Filter filter, Sort sort, @Nullable CqlIdentifier keyspace) { Select select; if (selectors.isEmpty()) { - select = QueryBuilder.selectFrom(from).all(); + select = QueryBuilder.selectFrom(keyspace, from).all(); } else { - - List mappedSelectors = new ArrayList<>( - selectors.size()); + List mappedSelectors = new ArrayList<>(selectors.size()); for (Selector selector : selectors) { com.datastax.oss.driver.api.querybuilder.select.Selector orElseGet = selector.getAlias() - .map(it -> getSelection(selector).as(it)).orElseGet(() -> getSelection(selector)); + .map(it -> getSelection(selector).as(it)) + .orElseGet(() -> getSelection(selector)); mappedSelectors.add(orElseGet); } - select = QueryBuilder.selectFrom(from).selectors(mappedSelectors); + select = QueryBuilder.selectFrom(from) + .selectors(mappedSelectors); } StatementBuilder createSelectAndOrder(List sele Select statementToUse = statement; for (Sort.Order order : sort) { - statementToUse = statementToUse.orderBy(order.getProperty(), - order.isAscending() ? ClusteringOrder.ASC : ClusteringOrder.DESC); + statementToUse = statementToUse.orderBy(order.getProperty(), order.isAscending() ? ClusteringOrder.ASC : ClusteringOrder.DESC); } return statementToUse; diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/legacy/AsyncCassandraTemplate.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/legacy/AsyncCassandraTemplate.java index 8ac7a29b3..c9e9baff8 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/legacy/AsyncCassandraTemplate.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/legacy/AsyncCassandraTemplate.java @@ -904,7 +904,6 @@ public String getCql() { return getAsyncCqlOperations().execute(new GetConfiguredPageSize()).completable().join(); } - @SuppressWarnings("unchecked") private Function getMapper(Class entityType, Class targetType, CqlIdentifier tableName) { EntityProjection projection = entityOperations.introspectProjection(targetType, entityType); diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntity.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntity.java index 9a91cc0c2..d39133f3a 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntity.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntity.java @@ -34,6 +34,7 @@ import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import com.datastax.oss.driver.api.core.CqlIdentifier; @@ -50,14 +51,16 @@ public class BasicCassandraPersistentEntity extends BasicPersistentEntity (String) AnnotationUtils.getValue(it, "keyspace")) + .orElse(null); + return (keyspaceName = (StringUtils.hasText(keyspace) ? CqlIdentifierGenerator.createIdentifier(keyspace, this.forceQuote) : null)); + } + @Override public void addAssociation(Association association) { throw new UnsupportedCassandraOperationException("Cassandra does not support associations"); @@ -181,7 +192,7 @@ public void setTableName(CqlIdentifier tableName) { * @since 3.0 */ public void setNamingStrategy(NamingStrategy namingStrategy) { - this.namingAccessor.setNamingStrategy(namingStrategy); + this.cqlIdentifierGenerator.setNamingStrategy(namingStrategy); } @Override @@ -189,6 +200,11 @@ public CqlIdentifier getTableName() { return Optional.ofNullable(this.tableName).orElseGet(this::determineTableName); } + @Override + public CqlIdentifier getCustomKeyspace() { + return Optional.ofNullable(this.keyspaceName).orElseGet(this::determineKeyspaceName); + } + /** * @param verifier The verifier to set. */ diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraMappingContext.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraMappingContext.java index 441e66d28..d6af5b296 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraMappingContext.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraMappingContext.java @@ -355,6 +355,8 @@ protected boolean shouldCreatePersistentEntityFor(TypeInformation typeInfo) { @Override protected BasicCassandraPersistentEntity createPersistentEntity(TypeInformation typeInformation) { + + BasicCassandraPersistentEntity entity = isUserDefinedType(typeInformation) ? new CassandraUserTypePersistentEntity<>(typeInformation, getVerifier()) : isTuple(typeInformation) ? new BasicCassandraPersistentTupleEntity<>(typeInformation) @@ -363,6 +365,7 @@ protected BasicCassandraPersistentEntity createPersistentEntity(TypeInfor if (this.namingStrategy != null) { entity.setNamingStrategy(this.namingStrategy); } + Optional.ofNullable(this.applicationContext).ifPresent(entity::setApplicationContext); return entity; diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraPersistentEntity.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraPersistentEntity.java index 7bd6b280d..613aecdb7 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraPersistentEntity.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/CassandraPersistentEntity.java @@ -16,6 +16,7 @@ package org.springframework.data.cassandra.core.mapping; import org.springframework.data.mapping.PersistentEntity; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import com.datastax.oss.driver.api.core.CqlIdentifier; @@ -50,6 +51,12 @@ public interface CassandraPersistentEntity extends PersistentEntity verifiers; + private final Collection verifiers; /** * Create a new {@link CompositeCassandraPersistentEntityMetadataVerifier} using default entity and primary key diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/EmbeddedEntityOperations.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/EmbeddedEntityOperations.java index d21c6a4b6..ac0d6fc5b 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/EmbeddedEntityOperations.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/EmbeddedEntityOperations.java @@ -114,9 +114,8 @@ public CqlIdentifier getTableName() { } @Override - @Deprecated - public void setTableName(org.springframework.data.cassandra.core.cql.CqlIdentifier tableName) { - delegate.setTableName(tableName); + public CqlIdentifier getCustomKeyspace() { + return this.delegate.getCustomKeyspace(); } @Override diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/Table.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/Table.java index d1283e821..9421ce629 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/Table.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/core/mapping/Table.java @@ -51,4 +51,9 @@ */ @Deprecated boolean forceQuote() default false; + + /** + * Keyspace where this table is located + */ + String keyspace() default ""; } diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntityUnitTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntityUnitTests.java index 12ac991c8..26114da27 100755 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntityUnitTests.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/core/mapping/BasicCassandraPersistentEntityUnitTests.java @@ -89,14 +89,13 @@ void tableAllowsReferencingSpringBean() { @Test void setForceQuoteCallsSetTableName() { - BasicCassandraPersistentEntity entitySpy = spy( - new BasicCassandraPersistentEntity<>(TypeInformation.of(Message.class))); + BasicCassandraPersistentEntity entitySpy = spy(new BasicCassandraPersistentEntity<>(TypeInformation.of(Message.class))); DirectFieldAccessor directFieldAccessor = new DirectFieldAccessor(entitySpy); entitySpy.setTableName(CqlIdentifier.fromCql("Messages")); - assertThat(directFieldAccessor.getPropertyValue("forceQuote")).isNull(); + assertThat((Boolean) directFieldAccessor.getPropertyValue("forceQuote")).isFalse(); entitySpy.setForceQuote(true); diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/domain/EntityWithKeyspace.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/domain/EntityWithKeyspace.java new file mode 100644 index 000000000..b569ff7f3 --- /dev/null +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/domain/EntityWithKeyspace.java @@ -0,0 +1,11 @@ +package org.springframework.data.cassandra.domain; + +import org.springframework.data.annotation.Id; +import org.springframework.data.cassandra.core.mapping.Table; + +@Table(value = "entity_with_keyspace", keyspace = "custom") +public record EntityWithKeyspace( + @Id String id, + String name, + String type +){} \ No newline at end of file diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/PartTreeCassandraQueryUnitTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/PartTreeCassandraQueryUnitTests.java index ad4c03be4..773cde93b 100644 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/PartTreeCassandraQueryUnitTests.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/query/PartTreeCassandraQueryUnitTests.java @@ -23,6 +23,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import org.junit.jupiter.api.BeforeEach; @@ -39,9 +40,11 @@ import org.springframework.data.cassandra.core.mapping.CassandraMappingContext; import org.springframework.data.cassandra.core.mapping.UserTypeResolver; import org.springframework.data.cassandra.domain.AddressType; +import org.springframework.data.cassandra.domain.EntityWithKeyspace; import org.springframework.data.cassandra.domain.Group; import org.springframework.data.cassandra.domain.Person; import org.springframework.data.cassandra.repository.AllowFiltering; +import org.springframework.data.cassandra.repository.CassandraRepository; import org.springframework.data.cassandra.repository.Consistency; import org.springframework.data.cassandra.repository.MapIdCassandraRepository; import org.springframework.data.cassandra.repository.Query; @@ -225,6 +228,13 @@ void shouldApplyConsistencyLevel() { assertThat(statement.getConsistencyLevel()).isEqualTo(DefaultConsistencyLevel.LOCAL_ONE); } + @Test + void shouldCreateSelectQueryWithKeyspaceApplied() { + String query = deriveQueryFromMethod("findByName", EntityWithKeyspaceRepository.class, "Mark"); + + assertThat(query).isEqualTo("SELECT * FROM custom.entity_with_keyspace WHERE name='Mark'"); + } + @Test // DATACASS-512 void shouldCreateCountQuery() { @@ -251,6 +261,10 @@ void shouldCreateExistsQuery() { } private String deriveQueryFromMethod(String method, Object... args) { + return this.deriveQueryFromMethod(method, Repo.class, args); + } + + private String deriveQueryFromMethod(String method, Class> cassandraRepositoryClass, Object... args) { Class[] types = new Class[args.length]; @@ -258,21 +272,22 @@ private String deriveQueryFromMethod(String method, Object... args) { types[i] = ClassUtils.getUserClass(args[i].getClass()); } - SimpleStatement statement = deriveQueryFromMethod(Repo.class, method, types, args); + SimpleStatement statement = deriveQueryFromMethod(cassandraRepositoryClass, method, types, args); String query = statement.getQuery(); List positionalValues = statement.getPositionalValues(); for (Object positionalValue : positionalValues) { query = query.replaceFirst("\\?", - positionalValue != null - ? CodecRegistry.DEFAULT.codecFor((Class) positionalValue.getClass()).format(positionalValue) - : "null"); + positionalValue != null + ? CodecRegistry.DEFAULT.codecFor((Class) positionalValue.getClass()).format(positionalValue) + : "null"); } return query; } + private SimpleStatement deriveQueryFromMethod(Class repositoryInterface, String method, Class[] types, Object... args) { @@ -308,6 +323,11 @@ interface GroupRepository extends MapIdCassandraRepository { Group findByIdHashPrefix(String hashPrefix); } + interface EntityWithKeyspaceRepository extends CassandraRepository { + + Optional findByName(String name); + } + @SuppressWarnings("unused") interface Repo extends MapIdCassandraRepository { diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/support/SimpleCassandraRepositoryIntegrationTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/support/SimpleCassandraRepositoryIntegrationTests.java index 9126f29de..3a1f79c42 100644 --- a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/support/SimpleCassandraRepositoryIntegrationTests.java +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/repository/support/SimpleCassandraRepositoryIntegrationTests.java @@ -28,6 +28,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.BeansException; @@ -37,12 +38,14 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.data.cassandra.config.CqlSessionFactoryBean; import org.springframework.data.cassandra.core.CassandraOperations; import org.springframework.data.cassandra.core.mapping.event.AbstractCassandraEventListener; import org.springframework.data.cassandra.core.mapping.event.AfterSaveEvent; import org.springframework.data.cassandra.core.mapping.event.BeforeSaveEvent; import org.springframework.data.cassandra.core.mapping.event.CassandraMappingEvent; import org.springframework.data.cassandra.core.query.CassandraPageRequest; +import org.springframework.data.cassandra.domain.EntityWithKeyspace; import org.springframework.data.cassandra.domain.User; import org.springframework.data.cassandra.domain.UserToken; import org.springframework.data.cassandra.repository.CassandraRepository; @@ -56,7 +59,14 @@ import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.util.ClassUtils; +import com.datastax.driver.core.querybuilder.QueryBuilder; +import com.datastax.driver.core.querybuilder.Select; +import com.datastax.driver.core.schemabuilder.Create; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.Statement; +import com.datastax.oss.driver.api.core.cql.StatementBuilder; +import com.datastax.oss.protocol.internal.request.Query; /** * Integration tests for {@link SimpleCassandraRepository}. @@ -92,8 +102,9 @@ CaptureEventListener eventListener() { private BeanFactory beanFactory; private CassandraRepositoryFactory factory; private ClassLoader classLoader; - private UserRepostitory repository; + private UserRepository userRepository; + private EntityWithKeyspaceRepostitory entityWithKeyspaceRepostitory; private User dave, oliver, carter, boyd; @Override @@ -115,26 +126,38 @@ void setUp() { factory.setBeanFactory(beanFactory); factory.setEvaluationContextProvider(ExtensionAwareQueryMethodEvaluationContextProvider.DEFAULT); - repository = factory.getRepository(UserRepostitory.class); + userRepository = factory.getRepository(UserRepository.class); + entityWithKeyspaceRepostitory = factory.getRepository(EntityWithKeyspaceRepostitory.class); cassandraVersion = CassandraVersion.get(session); - repository.deleteAll(); + userRepository.deleteAll(); dave = new User("42", "Dave", "Matthews"); oliver = new User("4", "Oliver August", "Matthews"); carter = new User("49", "Carter", "Beauford"); boyd = new User("45", "Boyd", "Tinsley"); - repository.saveAll(Arrays.asList(oliver, dave, carter, boyd)); + userRepository.saveAll(Arrays.asList(oliver, dave, carter, boyd)); eventListener.clear(); } + @Test + void whenInsertingEntityWithKeyspaceSpecified_thenAppliedQueryWithKeyspace() { + session.execute("CREATE KEYSPACE custom WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }"); + session.execute("CREATE TABLE custom.entity_with_keyspace (id TEXT, name TEXT, type TEXT, PRIMARY KEY ((id)) )"); + EntityWithKeyspace entityWithKeyspace = new EntityWithKeyspace("12", "Artur", "COMMON"); + entityWithKeyspaceRepostitory.save(entityWithKeyspace); + Optional foundEntity = entityWithKeyspaceRepostitory.findById("12"); + Assertions.assertThat(foundEntity).isPresent(); + Assertions.assertThat(foundEntity.get().type()).isEqualTo("COMMON"); + } + @Test // DATACASS-396 void existsByIdShouldReturnTrueForExistingObject() { - Boolean exists = repository.existsById(dave.getId()); + Boolean exists = userRepository.existsById(dave.getId()); assertThat(exists).isTrue(); } @@ -142,7 +165,7 @@ void existsByIdShouldReturnTrueForExistingObject() { @Test // DATACASS-396 void existsByIdShouldReturnFalseForAbsentObject() { - boolean exists = repository.existsById("unknown"); + boolean exists = userRepository.existsById("unknown"); assertThat(exists).isFalse(); } @@ -150,7 +173,7 @@ void existsByIdShouldReturnFalseForAbsentObject() { @Test // DATACASS-396 void existsByMonoOfIdShouldReturnTrueForExistingObject() { - boolean exists = repository.existsById(dave.getId()); + boolean exists = userRepository.existsById(dave.getId()); assertThat(exists).isTrue(); } @@ -158,7 +181,7 @@ void existsByMonoOfIdShouldReturnTrueForExistingObject() { @Test // DATACASS-396 void findByIdShouldReturnObject() { - Optional User = repository.findById(dave.getId()); + Optional User = userRepository.findById(dave.getId()); assertThat(User).contains(dave); } @@ -166,7 +189,7 @@ void findByIdShouldReturnObject() { @Test // DATACASS-396 void findByIdShouldCompleteWithoutValueForAbsentObject() { - Optional User = repository.findById("unknown"); + Optional User = userRepository.findById("unknown"); assertThat(User).isEmpty(); } @@ -174,7 +197,7 @@ void findByIdShouldCompleteWithoutValueForAbsentObject() { @Test // DATACASS-396, DATACASS-416 void findAllShouldReturnAllResults() { - List Users = repository.findAll(); + List Users = userRepository.findAll(); assertThat(Users).hasSize(4); } @@ -182,7 +205,7 @@ void findAllShouldReturnAllResults() { @Test // DATACASS-396, DATACASS-416 void findAllByIterableOfIdShouldReturnResults() { - List Users = repository.findAllById(Arrays.asList(dave.getId(), boyd.getId())); + List Users = userRepository.findAllById(Arrays.asList(dave.getId(), boyd.getId())); assertThat(Users).hasSize(2); } @@ -190,11 +213,11 @@ void findAllByIterableOfIdShouldReturnResults() { @Test // DATACASS-56 void findAllWithPaging() { - Slice slice = repository.findAll(CassandraPageRequest.first(2)); + Slice slice = userRepository.findAll(CassandraPageRequest.first(2)); assertThat(slice).hasSize(2); - assertThat(repository.findAll(slice.nextPageable())).hasSize(2); + assertThat(userRepository.findAll(slice.nextPageable())).hasSize(2); } @Test // DATACASS-700 @@ -202,7 +225,7 @@ void findAllWithPagingAndSorting() { assumeTrue(cassandraVersion.isGreaterThan(CASSANDRA_3)); - UserTokenRepostitory repository = factory.getRepository(UserTokenRepostitory.class); + UserTokenRepository repository = factory.getRepository(UserTokenRepository.class); repository.deleteAll(); UUID id = UUID.randomUUID(); @@ -232,7 +255,7 @@ void findAllWithPagingAndSorting() { @Test // DATACASS-396 void countShouldReturnNumberOfRecords() { - long count = repository.count(); + long count = userRepository.count(); assertThat(count).isEqualTo(4); } @@ -240,23 +263,23 @@ void countShouldReturnNumberOfRecords() { @Test // DATACASS-415 void insertEntityShouldInsertEntity() { - repository.deleteAll(); + userRepository.deleteAll(); User User = new User("36", "Homer", "Simpson"); - repository.insert(User); + userRepository.insert(User); - assertThat(repository.count()).isEqualTo(1); + assertThat(userRepository.count()).isEqualTo(1); } @Test // DATACASS-415 void insertIterableOfEntitiesShouldInsertEntity() { - repository.deleteAll(); + userRepository.deleteAll(); - repository.insert(Arrays.asList(dave, oliver, boyd)); + userRepository.insert(Arrays.asList(dave, oliver, boyd)); - assertThat(repository.count()).isEqualTo(3); + assertThat(userRepository.count()).isEqualTo(3); } @Test // DATACASS-396, DATACASS-573 @@ -265,11 +288,11 @@ void saveEntityShouldUpdateExistingEntity() { dave.setFirstname("Hello, Dave"); dave.setLastname("Bowman"); - User saved = repository.save(dave); + User saved = userRepository.save(dave); assertThat(saved).isSameAs(saved); - Optional loaded = repository.findById(dave.getId()); + Optional loaded = userRepository.findById(dave.getId()); assertThat(loaded).isPresent(); @@ -285,7 +308,7 @@ void saveShouldEmitEvents() { dave.setFirstname("Hello, Dave"); dave.setLastname("Bowman"); - repository.save(dave); + userRepository.save(dave); assertThat(eventListener.getBeforeSave()).hasSize(1); assertThat(eventListener.getAfterSave()).hasSize(1); @@ -296,11 +319,11 @@ void saveEntityShouldInsertNewEntity() { User User = new User("36", "Homer", "Simpson"); - User saved = repository.save(User); + User saved = userRepository.save(User); assertThat(saved).isEqualTo(User); - Optional loaded = repository.findById(User.getId()); + Optional loaded = userRepository.findById(User.getId()); assertThat(loaded).contains(User); } @@ -308,13 +331,13 @@ void saveEntityShouldInsertNewEntity() { @Test // DATACASS-396, DATACASS-416, DATACASS-573 void saveIterableOfNewEntitiesShouldInsertEntity() { - repository.deleteAll(); + userRepository.deleteAll(); - List saved = repository.saveAll(Arrays.asList(dave, oliver, boyd)); + List saved = userRepository.saveAll(Arrays.asList(dave, oliver, boyd)); assertThat(saved).hasSize(3).contains(dave, oliver, boyd); - assertThat(repository.count()).isEqualTo(3); + assertThat(userRepository.count()).isEqualTo(3); } @Test // DATACASS-396, DATACASS-416 @@ -325,23 +348,23 @@ void saveIterableOfMixedEntitiesShouldInsertEntity() { dave.setFirstname("Hello, Dave"); dave.setLastname("Bowman"); - List saved = repository.saveAll(Arrays.asList(User, dave)); + List saved = userRepository.saveAll(Arrays.asList(User, dave)); assertThat(saved).hasSize(2); - Optional persistentDave = repository.findById(dave.getId()); + Optional persistentDave = userRepository.findById(dave.getId()); assertThat(persistentDave).contains(dave); - Optional persistentHomer = repository.findById(User.getId()); + Optional persistentHomer = userRepository.findById(User.getId()); assertThat(persistentHomer).contains(User); } @Test // DATACASS-396, DATACASS-416 void deleteAllShouldRemoveEntities() { - repository.deleteAll(); + userRepository.deleteAll(); - List result = repository.findAll(); + List result = userRepository.findAll(); assertThat(result).isEmpty(); } @@ -349,9 +372,9 @@ void deleteAllShouldRemoveEntities() { @Test // DATACASS-396 void deleteByIdShouldRemoveEntity() { - repository.deleteById(dave.getId()); + userRepository.deleteById(dave.getId()); - Optional loaded = repository.findById(dave.getId()); + Optional loaded = userRepository.findById(dave.getId()); assertThat(loaded).isEmpty(); } @@ -359,9 +382,9 @@ void deleteByIdShouldRemoveEntity() { @Test // DATACASS-825 void deleteAllByIdShouldRemoveEntity() { - repository.deleteAllById(Collections.singletonList(dave.getId())); + userRepository.deleteAllById(Collections.singletonList(dave.getId())); - Optional loaded = repository.findById(dave.getId()); + Optional loaded = userRepository.findById(dave.getId()); assertThat(loaded).isEmpty(); } @@ -369,9 +392,9 @@ void deleteAllByIdShouldRemoveEntity() { @Test // DATACASS-396 void deleteShouldRemoveEntity() { - repository.delete(dave); + userRepository.delete(dave); - Optional loaded = repository.findById(dave.getId()); + Optional loaded = userRepository.findById(dave.getId()); assertThat(loaded).isEmpty(); } @@ -379,16 +402,18 @@ void deleteShouldRemoveEntity() { @Test // DATACASS-396 void deleteIterableOfEntitiesShouldRemoveEntities() { - repository.deleteAll(Arrays.asList(dave, boyd)); + userRepository.deleteAll(Arrays.asList(dave, boyd)); - Optional loaded = repository.findById(boyd.getId()); + Optional loaded = userRepository.findById(boyd.getId()); assertThat(loaded).isEmpty(); } - interface UserRepostitory extends CassandraRepository {} + interface UserRepository extends CassandraRepository {} + + interface EntityWithKeyspaceRepostitory extends CassandraRepository {} - interface UserTokenRepostitory extends CassandraRepository { + interface UserTokenRepository extends CassandraRepository { Slice findAllByUserId(UUID id, Pageable pageRequest); }