// RedisLock
package test.utils;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.async.RedisScriptingAsyncCommands;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import lombok.extern.log4j.Log4j2;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
/**
 * Redis-backed lock built on {@link StringRedisTemplate} and the Lettuce native connection.
 *
 * <p>Acquisition issues {@code SET key value NX EX expireTime} with a random UUID as the value;
 * release runs {@link #UNLOCK_LUA}, which deletes the key only if it still holds that UUID.
 *
 * <p>The lock is not re-entrant: acquiring again while the key exists fails, in which case the
 * previously stored token and {@code locked} flag are left untouched so that {@link #unlock()}
 * can still release the key.
 *
 * <p>Apart from {@code locked} being {@code volatile}, no synchronization is performed around
 * state updates.
 */
@Log4j2
public class RedisLock {

    public static final String NX = "NX";
    public static final String EX = "EX";
    public static final String OK = "OK";
    /** Compare-and-delete script: removes KEYS[1] only when its value equals ARGV[1]. */
    public static final String UNLOCK_LUA = "if redis.call(\"get\",KEYS[1]) == ARGV[1] then return redis.call(\"del\",KEYS[1]) else return 0 end";

    private static final String REDIS_LIB_MISMATCH = "Failed to convert nativeConnection. " +
            "Is your SpringBoot main version > 2.0 ? Only lib:lettuce is supported.";

    /** Default key expiry, in seconds. */
    private static final int DEFAULT_EXPIRE_SECONDS = 120;
    /** Default {@link #tryLock()} wait budget, in milliseconds. */
    private static final long DEFAULT_TIMEOUT_MILLIS = 500;
    /** Pause between two acquisition attempts, in milliseconds. */
    private static final long RETRY_INTERVAL_MILLIS = 100;

    private final StringRedisTemplate redisTemplate;
    private final String lockKey;
    /** Key expiry passed to {@code SET ... EX}, in seconds. */
    private final int expireTime;
    /** Maximum time {@link #tryLock()} keeps retrying, in milliseconds. */
    private final long timeOut;

    /** Token stored under {@link #lockKey} by the last successful acquisition. */
    private String lockValue;
    private volatile boolean locked = false;

    /**
     * Creates a lock with the default expiry (120 s) and the default wait budget (500 ms).
     *
     * @param redisTemplate template used to reach Redis
     * @param lockKey       Redis key representing the lock
     */
    public RedisLock(StringRedisTemplate redisTemplate, String lockKey) {
        this(redisTemplate, lockKey, DEFAULT_EXPIRE_SECONDS, DEFAULT_TIMEOUT_MILLIS);
    }

    /**
     * Creates a lock with an explicit expiry and wait budget.
     *
     * @param redisTemplate template used to reach Redis
     * @param lockKey       Redis key representing the lock
     * @param expireTime    key expiry in seconds
     * @param timeOut       maximum time {@link #tryLock()} keeps retrying, in milliseconds
     */
    public RedisLock(StringRedisTemplate redisTemplate, String lockKey, int expireTime, long timeOut) {
        this.redisTemplate = redisTemplate;
        this.lockKey = lockKey;
        this.expireTime = expireTime;
        this.timeOut = timeOut;
    }

    /**
     * Repeatedly tries to acquire the lock until it succeeds or the configured wait budget
     * elapses, pausing 100 ms between attempts. No attempt is made when the budget is not
     * positive.
     *
     * @return {@code true} if this call acquired the lock; {@code false} on timeout or when the
     *         waiting thread is interrupted (the interrupt flag is restored)
     * @throws IllegalArgumentException if the lock key is null or empty
     */
    public boolean tryLock() {
        final String candidate = UUID.randomUUID().toString();
        // toNanos saturates instead of overflowing for very large budgets.
        final long budgetNanos = TimeUnit.MILLISECONDS.toNanos(timeOut);
        final long startedAt = System.nanoTime();
        while (System.nanoTime() - startedAt < budgetNanos) {
            if (acquire(candidate)) {
                return true;
            }
            if (!pauseBeforeRetry()) {
                return false;
            }
        }
        // This attempt did not obtain the key; do not report a possibly stale flag.
        return false;
    }

    /**
     * Makes a single attempt to acquire the lock.
     *
     * @return {@code true} if the key was set by this call, {@code false} otherwise
     * @throws IllegalArgumentException if the lock key is null or empty
     */
    public boolean lock() {
        return acquire(UUID.randomUUID().toString());
    }

    /**
     * Retries acquisition every 100 ms with no time limit.
     *
     * @return {@code true} once the lock is acquired; {@code false} if the waiting thread is
     *         interrupted (the interrupt flag is restored)
     * @throws IllegalArgumentException if the lock key is null or empty
     */
    public boolean lockBlock() {
        final String candidate = UUID.randomUUID().toString();
        while (true) {
            if (acquire(candidate)) {
                return true;
            }
            if (!pauseBeforeRetry()) {
                return false;
            }
        }
    }

    /**
     * Releases the lock if this instance believes it holds it.
     *
     * <p>The primary path evaluates {@link #UNLOCK_LUA}. If that path throws, a non-atomic
     * fallback reads the key and deletes it when the stored value matches this instance's token.
     *
     * @return {@code true} if the instance was not marked as locked or the key was deleted;
     *         {@code false} if the key was not deleted
     */
    public Boolean unlock() {
        if (!locked) {
            return true;
        }
        try {
            return redisTemplate.execute((RedisCallback<Boolean>) connection -> {
                Object nativeConnection = connection.getNativeConnection();
                Long result = 0L;
                byte[][] keys = {lockKey.getBytes(StandardCharsets.UTF_8)};
                byte[] token = lockValue.getBytes(StandardCharsets.UTF_8);
                if (nativeConnection instanceof RedisScriptingAsyncCommands) {
                    @SuppressWarnings("unchecked") // byte[] keys/values, same as in set()/get()
                    RedisScriptingAsyncCommands<byte[], byte[]> scripting =
                            (RedisScriptingAsyncCommands<byte[], byte[]>) nativeConnection;
                    RedisFuture<Long> future = scripting.eval(UNLOCK_LUA, ScriptOutputType.INTEGER, keys, token);
                    result = getEvalResult(future, connection);
                } else {
                    log.warn(REDIS_LIB_MISMATCH);
                }
                if (result == 0L && StringUtils.hasLength(lockKey)) {
                    log.debug("Unlock failed! key={}, time={}", lockKey, System.currentTimeMillis());
                }
                // The instance stays marked as locked when the script deleted nothing.
                locked = result == 0L;
                return result == 1L;
            });
        } catch (Throwable e) {
            log.debug(
                    "The redis you are using dose NOT support EVAL. Use downgrade method to unlock. {}",
                    e.getMessage(), e);
            String value = this.get(lockKey, String.class);
            if (lockValue.equals(value)) {
                redisTemplate.delete(lockKey);
                // The key is gone, so this instance no longer holds the lock.
                locked = false;
                return true;
            }
            return false;
        }
    }

    /**
     * Reports whether this instance is currently marked as holding the lock.
     *
     * @return the value of the internal {@code locked} flag
     */
    public boolean isLock() {
        return locked;
    }

    /**
     * Performs one {@code SET NX EX} attempt with the given token and records the token and the
     * locked flag only when the attempt succeeds, so a failed attempt never discards the token
     * of a lock that is already held.
     */
    private boolean acquire(String candidate) {
        if (!OK.equalsIgnoreCase(set(lockKey, candidate, expireTime))) {
            return false;
        }
        // Write the token before the volatile flag so readers of the flag see the token.
        lockValue = candidate;
        locked = true;
        return true;
    }

    /**
     * Sleeps for the retry interval.
     *
     * @return {@code false} if the sleep was interrupted (interrupt flag restored), else {@code true}
     */
    private boolean pauseBeforeRetry() {
        try {
            TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MILLIS);
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.info("Sleep is interrupted", e);
            return false;
        }
    }

    /**
     * Waits for the script result. On failure the error is logged, the connection is closed and
     * {@code 0} is returned, so the caller treats it as "nothing deleted".
     */
    private Long getEvalResult(RedisFuture<Long> future, RedisConnection connection) {
        try {
            return future.get();
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("Future get failed, trying to close connection.", e);
            closeConnection(connection);
            return 0L;
        }
    }

    /**
     * Issues {@code SET key value NX EX expireSeconds} through the native Lettuce connection.
     *
     * @return the server reply, or {@code null} when the key already exists or the native
     *         connection is not a supported Lettuce type
     * @throws IllegalArgumentException if {@code key} is null or empty
     */
    private String set(final String key, final String value, final long expireSeconds) {
        Assert.hasLength(key, "Invalid key");
        return redisTemplate.execute((RedisCallback<String>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            byte[] keyByte = key.getBytes(StandardCharsets.UTF_8);
            byte[] valueByte = value.getBytes(StandardCharsets.UTF_8);
            SetArgs args = SetArgs.Builder.nx().ex(expireSeconds);
            String result = null;
            if (nativeConnection instanceof RedisAsyncCommands) {
                @SuppressWarnings("unchecked")
                RedisAsyncCommands<byte[], byte[]> command =
                        (RedisAsyncCommands<byte[], byte[]>) nativeConnection;
                result = command.getStatefulConnection().sync().set(keyByte, valueByte, args);
            } else if (nativeConnection instanceof RedisAdvancedClusterAsyncCommands) {
                @SuppressWarnings("unchecked")
                RedisAdvancedClusterAsyncCommands<byte[], byte[]> clusterCommand =
                        (RedisAdvancedClusterAsyncCommands<byte[], byte[]>) nativeConnection;
                result = clusterCommand.getStatefulConnection().sync().set(keyByte, valueByte, args);
            } else {
                log.error(REDIS_LIB_MISMATCH);
            }
            return result;
        });
    }

    /** Closes the connection, logging (not propagating) any failure. */
    private void closeConnection(RedisConnection connection) {
        try {
            connection.close();
        } catch (Exception e2) {
            log.error("close connection fail.", e2);
        }
    }

    /**
     * Reads {@code key} through the native Lettuce connection.
     *
     * <p>The key is sent as {@code byte[]}; when the reply is a {@code byte[]} and a
     * {@code String} was requested, the bytes are decoded as UTF-8 before the cast, because a
     * direct {@code String.class.cast} of a {@code byte[]} would throw.
     *
     * @return the stored value converted to {@code clazz}, or {@code null} when the key is
     *         absent or the native connection is not a supported Lettuce type
     * @throws IllegalArgumentException if {@code key} is null or empty
     * @throws ClassCastException       if the reply cannot be represented as {@code clazz}
     */
    private <T> T get(final String key, Class<T> clazz) {
        Assert.hasLength(key, "Invalid key");
        return redisTemplate.execute((RedisCallback<T>) connection -> {
            Object nativeConnection = connection.getNativeConnection();
            byte[] keyByte = key.getBytes(StandardCharsets.UTF_8);
            Object result = null;
            if (nativeConnection instanceof RedisAsyncCommands) {
                @SuppressWarnings("unchecked")
                RedisAsyncCommands<byte[], byte[]> command =
                        (RedisAsyncCommands<byte[], byte[]>) nativeConnection;
                result = command.getStatefulConnection().sync().get(keyByte);
            } else if (nativeConnection instanceof RedisAdvancedClusterAsyncCommands) {
                @SuppressWarnings("unchecked")
                RedisAdvancedClusterAsyncCommands<byte[], byte[]> clusterCommand =
                        (RedisAdvancedClusterAsyncCommands<byte[], byte[]>) nativeConnection;
                result = clusterCommand.getStatefulConnection().sync().get(keyByte);
            }
            if (result instanceof byte[] && String.class.equals(clazz)) {
                result = new String((byte[]) result, StandardCharsets.UTF_8);
            }
            return clazz.cast(result);
        });
    }
}