Skip to content

Commit

Permalink
fix #9 initial write on connect (AUTH and SELECT)
Browse files Browse the repository at this point in the history
  • Loading branch information
Valerian Barbot committed Sep 14, 2013
1 parent 7b29998 commit 9e883d5
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 51 deletions.
48 changes: 28 additions & 20 deletions src/main/scala/redis/Redis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@ import redis.actors.{RedisSubscriberActorWithCallback, RedisClientActor}
import redis.api.pubsub._
import java.util.concurrent.atomic.AtomicLong
import akka.event.Logging
import redis.protocol.RedisReply

trait Request {
implicit val executionContext: ExecutionContext

def redisConnection: ActorRef

def send[T](redisCommand: RedisCommand[_ <: RedisReply, T]): Future[T] = {
val promise = Promise[T]()
redisConnection ! Operation(redisCommand, promise)
promise.future
}
}

trait RedisCommands
extends Keys
Expand All @@ -35,14 +22,16 @@ trait RedisCommands
with Connection
with Server

abstract class RedisClientActorLike(system: ActorSystem) {
abstract class RedisClientActorLike(system: ActorSystem) extends ActorRequest {
var host: String
var port: Int
val name: String
val password: Option[String] = None
val db: Option[Int] = None
implicit val executionContext = system.dispatcher

val redisConnection: ActorRef = system.actorOf(
Props(classOf[RedisClientActor], new InetSocketAddress(host, port))
Props(classOf[RedisClientActor], new InetSocketAddress(host, port), getConnectOperations)
.withDispatcher(Redis.dispatcher),
name + '-' + Redis.tempName()
)
Expand All @@ -55,6 +44,20 @@ abstract class RedisClientActorLike(system: ActorSystem) {
}
}

def onConnect(redis: RedisCommands): Unit = {
password.foreach(redis.auth(_)) // TODO log on auth failure
db.foreach(redis.select(_))
}

def getConnectOperations: () => Seq[Operation[_, _]] = () => {
val self = this
val redis = new BufferedRequest with RedisCommands {
implicit val executionContext: ExecutionContext = self.executionContext
}
onConnect(redis)
redis.operations.result()
}

/**
* Disconnect from the server (stop the actor)
*/
Expand All @@ -65,14 +68,18 @@ abstract class RedisClientActorLike(system: ActorSystem) {

case class RedisClient(var host: String = "localhost",
var port: Int = 6379,
name: String = "RedisClient")
name: String = "RedisClient",
override val password: Option[String] = None,
override val db: Option[Int] = None)
(implicit _system: ActorSystem) extends RedisClientActorLike(_system) with RedisCommands with Transactions {

}

case class RedisBlockingClient(var host: String = "localhost",
var port: Int = 6379,
name: String = "RedisBlockingClient")
name: String = "RedisBlockingClient",
override val password: Option[String] = None,
override val db: Option[Int] = None)
(implicit _system: ActorSystem) extends RedisClientActorLike(_system) with BLists {
}

Expand All @@ -83,12 +90,13 @@ case class RedisPubSub(
patterns: Seq[String],
onMessage: Message => Unit = _ => {},
onPMessage: PMessage => Unit = _ => {},
authPassword: Option[String] = None,
name: String = "RedisPubSub"
)(implicit system: ActorSystem) {

val redisConnection: ActorRef = system.actorOf(
Props(classOf[RedisSubscriberActorWithCallback],
new InetSocketAddress(host, port), channels, patterns, onMessage, onPMessage)
new InetSocketAddress(host, port), channels, patterns, onMessage, onPMessage, authPassword)
.withDispatcher(Redis.dispatcher),
name + '-' + Redis.tempName()
)
Expand Down Expand Up @@ -144,7 +152,7 @@ case class SentinelClient(var host: String = "localhost",

val redisPubSubConnection: ActorRef = system.actorOf(
Props(classOf[RedisSubscriberActorWithCallback],
new InetSocketAddress(host, port), channels, Seq(), onMessage, (pmessage: PMessage) => {})
new InetSocketAddress(host, port), channels, Seq(), onMessage, (pmessage: PMessage) => {}, None)
.withDispatcher(Redis.dispatcher),
name + '-' + Redis.tempName()
)
Expand Down Expand Up @@ -185,7 +193,7 @@ abstract class SentinelMonitored(system: ActorSystem) {
}
}

abstract class SentinelMonitoredRedisClientLike(system: ActorSystem) extends SentinelMonitored(system) {
abstract class SentinelMonitoredRedisClientLike(system: ActorSystem) extends SentinelMonitored(system) with ActorRequest {
val redisClient: RedisClientActorLike
val onMasterChange = (ip: String, port: Int) => {
redisClient.reconnect(ip, port)
Expand Down
37 changes: 37 additions & 0 deletions src/main/scala/redis/Request.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package redis

import redis.protocol.RedisReply
import scala.concurrent.{ExecutionContext, Promise, Future}
import scala.collection.immutable.Queue
import akka.actor.ActorRef


trait Request {
implicit val executionContext: ExecutionContext

def send[T](redisCommand: RedisCommand[_ <: RedisReply, T]): Future[T]
}

trait ActorRequest {
implicit val executionContext: ExecutionContext

def redisConnection: ActorRef

def send[T](redisCommand: RedisCommand[_ <: RedisReply, T]): Future[T] = {
val promise = Promise[T]()
redisConnection ! Operation(redisCommand, promise)
promise.future
}
}

trait BufferedRequest extends redis.Request {
implicit val executionContext: ExecutionContext

val operations = Queue.newBuilder[Operation[_, _]]

override def send[T](redisCommand: RedisCommand[_ <: RedisReply, T]): Future[T] = {
val promise = Promise[T]()
operations += Operation(redisCommand, promise)
promise.future
}
}
15 changes: 14 additions & 1 deletion src/main/scala/redis/actors/RedisClientActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import akka.actor.{OneForOneStrategy, Terminated, PoisonPill, Props}
import scala.collection.mutable
import akka.actor.SupervisorStrategy.Stop

class RedisClientActor(override val address: InetSocketAddress) extends RedisWorkerIO(address) {
class RedisClientActor(override val address: InetSocketAddress, getConnectOperations: () => Seq[Operation[_, _]]) extends RedisWorkerIO(address) {


var repliesDecoder = initRepliesDecoder()
Expand Down Expand Up @@ -60,6 +60,19 @@ class RedisClientActor(override val address: InetSocketAddress) extends RedisWor
Stop
}
}

def onConnectWrite(): ByteString = {
val ops = getConnectOperations()
val buffer = new ByteStringBuilder

val queuePromisesConnect = mutable.Queue[Operation[_,_]]()
ops.foreach(operation => {
buffer.append(operation.redisCommand.encodedRequest)
queuePromisesConnect enqueue operation
})
queuePromises = queuePromisesConnect ++ queuePromises
buffer.result()
}
}

case object NoConnectionException extends RuntimeException("No Connection established")
12 changes: 9 additions & 3 deletions src/main/scala/redis/actors/RedisSubscriberActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import redis.protocol.{MultiBulk, RedisReply}
import redis.api.pubsub._
import java.net.InetSocketAddress
import scala.collection.mutable
import redis.api.connection.Auth

class RedisSubscriberActorWithCallback(
address: InetSocketAddress,
channels: Seq[String],
patterns: Seq[String],
messageCallback: Message => Unit,
pmessageCallback: PMessage => Unit
) extends RedisSubscriberActor(address, channels, patterns) {
pmessageCallback: PMessage => Unit,
authPassword: Option[String] = None
) extends RedisSubscriberActor(address, channels, patterns, authPassword) {
def onMessage(m: Message) = messageCallback(m)

def onPMessage(pm: PMessage) = pmessageCallback(pm)
Expand All @@ -21,8 +23,12 @@ class RedisSubscriberActorWithCallback(
abstract class RedisSubscriberActor(
address: InetSocketAddress,
channels: Seq[String],
patterns: Seq[String]
patterns: Seq[String],
authPassword: Option[String] = None
) extends RedisWorkerIO(address) with DecodeReplies {
def onConnectWrite(): ByteString = {
authPassword.map(Auth(_).encodedRequest).getOrElse(ByteString.empty)
}

def onMessage(m: Message): Unit

Expand Down
15 changes: 14 additions & 1 deletion src/main/scala/redis/actors/RedisWorkerIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ abstract class RedisWorkerIO(val address: InetSocketAddress) extends Actor with
sender ! Register(self)
tcpWorker = sender
initConnectedBuffer()
tryWrite()
tryInitialWrite() // TODO write something in head buffer
become(connected)
log.info("Connected to " + cmd.remoteAddress)
}
Expand Down Expand Up @@ -122,6 +122,19 @@ abstract class RedisWorkerIO(val address: InetSocketAddress) extends Actor with

def restartConnection() = reconnect()

def onConnectWrite() : ByteString

def tryInitialWrite() {
val data = onConnectWrite()

if(data.nonEmpty) {
writeWorker(data ++ bufferWrite.result())
bufferWrite.clear()
} else {
tryWrite()
}
}

def tryWrite() {
if (bufferWrite.length == 0) {
readyToWrite = true
Expand Down
13 changes: 3 additions & 10 deletions src/main/scala/redis/commands/Transactions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import scala.util.{Failure, Success}
import redis.api.transactions.{Watch, Exec, Multi}
import akka.util.ByteString

trait Transactions extends Request {
trait Transactions extends ActorRequest {

def multi(): TransactionBuilder = transaction()

Expand All @@ -32,9 +32,9 @@ trait Transactions extends Request {

}

case class TransactionBuilder(redisConnection: ActorRef)(implicit val executionContext: ExecutionContext) extends RedisCommands {
case class TransactionBuilder(redisConnection: ActorRef)(implicit val executionContext: ExecutionContext) extends BufferedRequest with RedisCommands {

val operations = Queue.newBuilder[Operation[_, _]]
//val operations = Queue.newBuilder[Operation[_, _]]
val watcher = Set.newBuilder[String]

def unwatch() {
Expand All @@ -60,13 +60,6 @@ case class TransactionBuilder(redisConnection: ActorRef)(implicit val executionC
t.process(p)
p.future
}


override def send[T](redisCommand: RedisCommand[_ <: RedisReply, T]): Future[T] = {
val promise = Promise[T]()
operations += Operation(redisCommand, promise)
promise.future
}
}

case class Transaction(watcher: Set[String], operations: Queue[Operation[_, _]], redisConnection: ActorRef)(implicit val executionContext: ExecutionContext) {
Expand Down
1 change: 1 addition & 0 deletions src/test/scala/redis/RedisPubSubSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import redis.actors.RedisSubscriberActor
import java.net.InetSocketAddress
import akka.actor.{Props, ActorRef}
import akka.testkit.{TestActorRef, TestProbe}
import akka.util.ByteString

class RedisPubSubSpec extends RedisSpec {

Expand Down
17 changes: 17 additions & 0 deletions src/test/scala/redis/RedisTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,21 @@ class RedisTest extends RedisSpec {
}
}

"init connection test" should {
"ok" in {
withRedisServer(port => {
val redis = RedisClient(port = port)
// TODO set password (CONFIG SET requiredpass password)
val r = for {
_ <- redis.select(2)
_ <- redis.set("keyDbSelect", "2")
} yield {
val redis = RedisClient(port = port, password = Some("password"), db = Some(2))
Await.result(redis.get[String]("keyDbSelect"), timeOut) must beSome("2")
}
Await.result(r, timeOut)
})
}
}

}
Loading

0 comments on commit 9e883d5

Please # to comment.