您的位置:首页 > 数据库

java代码,使用sql语句操作mongo数据库

2017-08-25 18:29 1076 查看
        如果直接使用mongo原生的查询方式查询内容,对于不熟悉mongo的同学来说,是一件相对比较繁琐的事情,所以就想到用sql语句的方式来查询mongo的结果集。druid可以很好地解析SQL语句,所以使用它来解析sql是再好不过了。

以下是已完成的部分功能代码,其他功能后续会慢慢补充。

package com.quark.util;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.mongodb.BasicDBObject;
import com.mongodb.MongoClient;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import org.bson.Document;
import org.bson.types.Binary;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static com.quark.util.MyLog.logger;
import static java.util.Arrays.asList;

public class MongoUtil extends Connection {
  // SQL comparison operators as they appear in parsed WHERE conditions; see convert().
  private static final String EQUAL = "=";
  private static final String GT = ">";
  private static final String GTE = ">=";
  private static final String LT = "<";
  private static final String LTE = "<=";
  private static final String IN = "IN";

  // Trimmed arguments of the most recent connMongo() call.
  private static String env;
  private static String system;
  // Connection settings loaded by initProperties() and read by connect().
  private static Properties jdbcProp;

  // Shared client and database handle; assigned in connect() and used by every operation.
  public static MongoClient mongoClient;
  public static MongoDatabase mongoDB;

  /** Creates an instance without a bound SQL statement ({@code sql} stays {@code null}). */
  public MongoUtil() {
  }

  // SQL statement executed by query().
  private String sql;

  /**
   * Creates an instance bound to the given SQL statement.
   *
   * @param sql the statement that {@link #query()} will run
   */
  public MongoUtil(String sql) {
    this.sql = sql;
  }

  /**
   * 连接mongo数据库
   *
   * @return
   */
  public boolean connMongo(String env, String system) {
    try {
      this.env = env.trim();
      this.system = system.trim();
      initProperties();
      connect();
      return true;
    } catch (Exception e) {
      return false;
    }
  }

  private void connect() {
    ip = jdbcProp.getProperty("ip");
    port = jdbcProp.getProperty("port");
    userName = jdbcProp.getProperty("user");
    passwd = jdbcProp.getProperty("password");
    dataBase = jdbcProp.getProperty("database");
    ServerAddress serverAddress = new ServerAddress(ip, Integer.valueOf(port));
    List<ServerAddress> addrs = new ArrayList<>();
    addrs.add(serverAddress);
    MongoCredential credential = MongoCredential.createScramSha1Credential(userName, dataBase, passwd.toCharArray());
    List<MongoCredential> credentials = new ArrayList<>();
    credentials.add(credential);
    //通过连接认证获取MongoDB连接
    mongoClient = new MongoClient(addrs, credentials);
    mongoDB = mongoClient.getDatabase(dataBase);
  }

  /**
   * 读取mongo的配置文件xxx.properties
   */
  private void initProperties() {
    jdbcProp = new Properties();
    FileInputStream fr = null;
    try {
      fr = new FileInputStream(new File("xxx.properties"));
      jdbcProp.load(fr);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  private static String convert(String syntax) {
    switch (syntax) {
      case GT:
        return "$gt";
      case GTE:
        return "$gte";
      case LT:
        return "$lt";
      case LTE:
        return "$lte";
      case IN:
        return "$in";
    }
    return null;
  }

  private int convertSort(String sort) {
    if ("DESC".equalsIgnoreCase(sort)) {
      return -1;
    }
    return 1;
  }

  private class FunctionItem {
    private Function function;
    private String name;
    private List<SQLExpr> value;
    private String operation;

    public FunctionItem(String name, String token, List<SQLExpr> value, Function function) {
      this.name = name;
      this.operation = token;
      this.value = value;
      this.function = function;
    }

    public FunctionItem(List<SQLExpr> value, Function function) {
      this.value = value;
      this.function = function;
    }

    public Function getFunction() {
      return function;
    }

    public void setFunction(Function function) {
      this.function = function;
    }

    public String getName() {
      return name;
    }

    public void setName(String name) {
      this.name = name;
    }

    public List<SQLExpr> getValue() {
      return value;
    }

    public void setValue(List<SQLExpr> value) {
      this.value = value;
    }

    public String getOperation() {
      return operation;
    }

    public void setOperation(String operation) {
      this.operation = operation;
    }
  }

  private enum Function {
    NONE,
    TO_DATE,
    COMPRESS;

    private static final Map<String, Function> stringToEnum = Maps.newHashMap();

    public static Function fromString(String symbol) {
      return stringToEnum.get(symbol);
    }

    static {
      for (Function function : values()) {
        stringToEnum.put(function.toString(), function);
      }
    }
  }

  public List<FunctionItem> getFunctionItem(String collectionName, SchemaStatVisitor visitor) {
    String type = visitor.getTableStat(collectionName).toString();
    if (visitor == null || visitor.getFunctions() == null || visitor.getFunctions().size() == 0)
      return null;
    List<FunctionItem> items = Lists.newArrayList();
    List<SQLMethodInvokeExpr> functions = visitor.getFunctions();
    for (SQLMethodInvokeExpr function : functions) {
      String method = function.getMethodName();
      Function fun = Function.fromString(method.toUpperCase());
      String name = null;
      String operation = null;
      if ("Insert".equals(type)) {
        name = function.toString();
      } else {
        if (function.getParent() != null) {
          SQLBinaryOpExpr expr = ((SQLBinaryOpExpr) function.getParent());
          name = expr.getLeft().toString();
          operation = expr.getOperator().getName();
        }
      }
      items.add(new FunctionItem(name, operation, function.getParameters(), fun));
    }
    return items;
  }

  private Object getFunctionItemValue(FunctionItem item) {
    Object value = null;
    switch (item.getFunction()) {
      case COMPRESS:
        try {
          value = CommonUtil.compress(((SQLCharExpr) item.getValue().get(0)).getText());
        } catch (IOException e) {
          e.printStackTrace();
        }
        break;
      case TO_DATE:
        String date = ((SQLCharExpr) item.getValue().get(0)).getText();
        String format = ((SQLCharExpr) item.getValue().get(1)).getText().replaceAll("HH24", "HH").replaceAll("mi", "mm");
        value = DateUtil.parse(date, format);
        break;
    }
    return value;
  }

  public long count(String sql) {
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
    SQLStatement stmt = stmtList.get(0);
    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);
    SQLSelectQuery query = ((SQLSelectStatement) stmt).getSelect().getQuery();
    MySqlSelectQueryBlock queryBlock = ((MySqlSelectQueryBlock) query);
    String collectionName = queryBlock.getFrom().toString();
    MongoCollection<Document> collection = mongoDB.getCollection(collectionName);
    BasicDBObject where = getWhereDBO(visitor, collectionName);
    logger.info("mongo select sql >> " + sql);
    logger.info("where conditions >> " + sql);
    logger.info(String.format("mongo ip [[%s]], port [[%s], db [[%s]], collection [[%s]]", mongoClient.getAddress().getHost(), mongoClient.getAddress().getPort(), mongoDB.getName(), collectionName));
    return collection.count(where);
  }

  public FindIterable<Document> select(String sql) {
    BasicDBObject projection = new BasicDBObject();
    BasicDBObject sort = new BasicDBObject();
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
    SQLStatement stmt = stmtList.get(0);
    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);

    SQLSelectQuery query = ((SQLSelectStatement) stmt).getSelect().getQuery();
    MySqlSelectQueryBlock queryBlock = ((MySqlSelectQueryBlock) query);
    String collectionName = queryBlock.getFrom().toString();
    MongoCollection<Document> collection = mongoDB.getCollection(collectionName);
    List<SQLSelectItem> selectList = queryBlock.getSelectList();
    for (SQLSelectItem item : selectList) {
      if ("*".equals(item.toString())) break;
      projection.put(item.toString(), 1);
    }
    BasicDBObject where = getWhereDBO(visitor, collectionName);
    FindIterable<Document> result = collection.find(where);

    List<TableStat.Column> orders = visitor.getOrderByColumns();
    for (TableStat.Column order : orders) {
      int direction = convertSort(order.getAttributes().get("orderBy.type").toString());
      sort.put(order.getName(), direction);
    }
    if (sort.size() > 0) {
      result.sort(sort);
    }

    if (projection.size() > 0) {
      result.projection(projection);
    }
    if (queryBlock.getLimit() != null) {
      result.limit(((SQLIntegerExpr) queryBlock.getLimit().getRowCount()).getNumber().intValue());
    }
    logger.info(String.format("mongo ip [[%s]], port [[%s], db [[%s]], collection [[%s]]", mongoClient.getAddress().getHost(), mongoClient.getAddress().getPort(), mongoDB.getName(), collectionName));
    logger.info("mongo select sql >> " + sql);
    logger.info("projection conditions >> " + sort.toJson());
    logger.info("where conditions >> " + where.toJson());
    logger.info("sort conditions >> " + sort.toJson());
    return result;
  }

  // Called by FitNesse's QueryTable fixture.
  public List<Object> query() {
    StringBuffer column = new StringBuffer();
    StringBuffer value = new StringBuffer();
    List<Object> list = new ArrayList<Object>();
    FindIterable<Document> result = select(sql);
    Iterator<Document> it = result.iterator();
    boolean draw = true;

    while (it.hasNext()) {
      List<Object> listObject = new ArrayList<>();
      Document doc = it.next();
      Set<String> keys = doc.keySet();
      for (String key : keys) {
        String strValue = null;
        Object obj = doc.get(key);
        if (obj != null) {
          if (obj instanceof Binary) {
            strValue = JSONObject.toJSONString(obj);
          } else {
            strValue = String.valueOf(obj);
          }
        }
        listObject.add(asList(key.toLowerCase(), strValue));
        if (draw) {
          column.append("|");
          value.append("|");
          column.append(key.toLowerCase());
        }
      }
      draw = false;
      list.add(listObject);
    }

    logger.info(column.toString());
    logger.info(value.toString());
    return list;
  }

  private void putPair(BasicDBObject where, String operator, String key, Object value) {
    if (operator.equals(EQUAL)) {
      where.append(key, value);
    } else {
      if (where.containsField(key)) {
        BasicDBObject obj = (BasicDBObject) where.get(key);
        obj.append(convert(operator), value);
      } else {
        where.append(key, new BasicDBObject(convert(operator), value));
      }
    }
  }

  public void update(String sql) {
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
    MySqlUpdateStatement stmt = (MySqlUpdateStatement) stmtList.get(0);
    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);
    List<SQLUpdateSetItem> items = stmt.getItems();
    Document updateDoc = new Document();
    for (SQLUpdateSetItem item : items) {
      if (item.getValue() instanceof SQLCharExpr) {
        updateDoc.put(item.getColumn().toString(), ((SQLCharExpr) item.getValue()).getText());
      } else if (item.getValue() instanceof SQLMethodInvokeExpr) {
        updateDoc.put(item.getColumn().toString(), getFunctionItemValue(new FunctionItem(((SQLMethodInvokeExpr) item.getValue()).getParameters(), Function.fromString(((SQLMethodInvokeExpr) item.getValue()).getMethodName().toUpperCase()))));
      }
    }
    Document update = new Document("$set", updateDoc);
    String collectionName = visitor.getCurrentTable();
    BasicDBObject where = getWhereDBO(visitor, collectionName);
    logger.info(String.format("mongo ip [[%s]], port [[%s], db [[%s]], collection [[%s]]", mongoClient.getAddress().getHost(), mongoClient.getAddress().getPort(), mongoDB.getName(), collectionName));
    logger.info("mongo update sql >> " + sql);
    logger.info("update items >> " + update.toJson());
    logger.info("where conditions >> " + where.toJson());
    MongoCollection<Document> collection = mongoDB.getCollection(collectionName);
    collection.updateMany(where, update);
  }

  public void delete(String sql) {
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
    MySqlDeleteStatement stmt = (MySqlDeleteStatement) stmtList.get(0);
    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);
    String collectionName = visitor.getCurrentTable();
    BasicDBObject where = getWhereDBO(visitor, collectionName);
    logger.info(String.format("mongo ip [[%s]], port [[%s], db [[%s]], collection [[%s]]", mongoClient.getAddress().getHost(), mongoClient.getAddress().getPort(), mongoDB.getName(), collectionName));
    logger.info("mongo delete sql >> " + sql);
    logger.info("delete conditions >> " + where.toJson().toString());
    MongoCollection<Document> collection = mongoDB.getCollection(collectionName);
    collection.deleteMany(where);
  }

  private BasicDBObject getWhereDBO(SchemaStatVisitor visitor, String collectionName) {
    BasicDBObject where = new BasicDBObject();
    List<FunctionItem> functionItems = getFunctionItem(collectionName, visitor);
    List<TableStat.Condition> conds = visitor.getConditions();
    for (TableStat.Condition cond : conds) {
      String operation = cond.getOperator();
      String key = cond.getColumn().getName();
      Predicate<FunctionItem> match = x -> x.getName() != null && x.getOperation() != null && x.getName().equals(key) && x.getOperation().equals(operation);
      putConditionToWhereDBO(where, functionItems, cond, match, key);
    }
    return where;
  }

  private void putConditionToWhereDBO(BasicDBObject where, List<FunctionItem> functionItems, TableStat.Condition cond, Predicate<FunctionItem> match, String key) {
    String operator = cond.getOperator();
    Object value = null;
    boolean exists = functionItems != null && functionItems.stream().anyMatch(match);
    if (exists) {
      FunctionItem item = functionItems.stream().filter(match).collect(Collectors.toList()).get(0);
      value = getFunctionItemValue(item);
    }

    if (operator.equalsIgnoreCase("in")) {
      convert(operator);
      if (exists) {
        //TODO
      } else {
        where.append(key, new BasicDBObject(convert(operator), cond.getValues()));
      }
    } else {
      if (false == exists) {
        value = cond.getValues().get(0);
      }
      putPair(where, operator, key, value);
    }
  }

  public void insert(String sql) {
    List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, "mysql");
    MySqlInsertStatement stmt = (MySqlInsertStatement) stmtList.get(0);
    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);
    List<SQLExpr> columns = stmt.getColumns();
    List<SQLExpr> values = stmt.getValuesList().get(0).getValues();
    String collectionName = visitor.getCurrentTable();
    Document doc = listToDocument(collectionName, visitor, columns, values);
    logger.info(String.format("mongo ip [[%s]], port [[%s], db [[%s]], collection [[%s]]", mongoClient.getAddress().getHost(), mongoClient.getAddress().getPort(), mongoDB.getName(), collectionName));
    logger.info("mongo insert sql >> " + sql);
    logger.info("insert items >> " + doc.toJson());
    MongoCollection<Document> collection = mongoDB.getCollection(collectionName);
    collection.insertOne(doc);
  }

  private Document listToDocument(String collectionName, SchemaStatVisitor visitor, List<SQLExpr> columns, List<SQLExpr> values) {
    List<FunctionItem> functionItems = getFunctionItem(collectionName, visitor);
    Document doc = new Document();
    for (int i = 0; i < columns.size(); i++) {
      Object value = values.get(i);
      Predicate<FunctionItem> matchName = x -> x.getName().equals(value.toString());
      boolean exsit = functionItems.stream().anyMatch(matchName);
      if (exsit) {
        FunctionItem item = functionItems.stream().filter(matchName).collect(Collectors.toList()).get(0);
        doc.put(columns.get(i).toString(), getFunctionItemValue(item));
      } else {
        if (value instanceof SQLCharExpr) {
          doc.put(columns.get(i).toString(), ((SQLCharExpr) value).getValue());
        } else if (value instanceof SQLNullExpr) {
          doc.put(columns.get(i).toString(), null);
        } else {
          doc.put(columns.get(i).toString(), value.toString());
        }

      }
    }
    return doc;
  }

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息