Reinforcement Learning Q-learning 算法学习-3
2015-12-31 21:00
483 查看
// Q-learning source-code walkthrough: a tabular Q-learning agent on a six-state graph.
import java.util.Random;

/**
 * Tabular Q-learning demo on a fixed six-state graph.
 *
 * <p>{@code R[s][a]} is the immediate reward for moving from state {@code s} to state {@code a};
 * {@code -1} marks a move that is not allowed. {@code GOAL_STATE} is the goal. {@code train()}
 * learns Q values by random exploration and prints the Q matrix; {@code test()} then prints the
 * greedy route from each initial state by repeatedly following the highest Q value.
 *
 * <p>Q values are stored as {@code int}; each update truncates {@code R + GAMMA * maxQ} toward
 * zero (see {@code reward(int)}).
 *
 * <p>Not thread-safe: all working state lives in static mutable fields.
 */
public class QLearning1 {
    private static final int Q_SIZE = 6;
    private static final double GAMMA = 0.8;
    private static final int ITERATIONS = 10;
    /** Index of the goal state; a training episode explores until the agent reaches it. */
    private static final int GOAL_STATE = 5;
    private static final int[] INITIAL_STATES = {1, 3, 5, 2, 4, 0};
    private static final int[][] R = {
        {-1, -1, -1, -1, 0, -1},
        {-1, -1, -1, 0, -1, 100},
        {-1, -1, -1, 0, -1, -1},
        {-1, 0, 0, -1, 0, -1},
        {0, -1, -1, 0, -1, 100},
        {-1, 0, -1, -1, 0, 100}
    };

    /** One shared generator instead of constructing a new Random for every draw. */
    private static final Random RANDOM = new Random();

    private static int[][] q = new int[Q_SIZE][Q_SIZE];
    private static int currentState = 0;

    /**
     * Resets the Q matrix, runs {@code ITERATIONS} passes of one episode per initial state,
     * then prints the learned Q matrix.
     */
    private static void train() {
        initialize();

        // Perform training, starting at all initial states.
        for (int iteration = 0; iteration < ITERATIONS; iteration++) {
            for (int initialState : INITIAL_STATES) {
                episode(initialState);
            }
        }

        System.out.println("Q Matrix values:");
        for (int i = 0; i < Q_SIZE; i++) {
            for (int j = 0; j < Q_SIZE; j++) {
                System.out.print(q[i][j] + ",\t");
            }
            System.out.print("\n");
        }
        System.out.print("\n");
    }

    /**
     * Prints, for each initial state, the route obtained by greedily following the highest Q
     * value. The walk always makes at least one move (so a route starting at the goal prints
     * its first greedy step too).
     *
     * <p>The greedy policy is deterministic, so a walk that has not reached the goal after
     * {@code Q_SIZE} moves has necessarily revisited a non-goal state and would cycle forever;
     * such a route is reported instead of hanging.
     */
    private static void test() {
        // Perform tests, starting at all initial states.
        System.out.println("Shortest routes from initial states:");
        for (int initialState : INITIAL_STATES) {
            currentState = initialState;
            int moves = 0;
            do {
                int newState = maximum(currentState, true);
                System.out.print(currentState + ", ");
                currentState = newState;
                moves++;
            } while (currentState != GOAL_STATE && moves < Q_SIZE);

            if (currentState == GOAL_STATE) {
                System.out.print(GOAL_STATE + "\n");
            } else {
                System.out.print("(goal not reached: greedy policy cycles)\n");
            }
        }
    }

    /**
     * Runs one training episode. It makes random allowed moves from {@code initialState} until
     * the goal state is reached. It then makes {@code Q_SIZE} further random moves starting from
     * the goal.
     *
     * @param initialState state the episode starts from
     */
    private static void episode(final int initialState) {
        currentState = initialState;

        // Travel from state to state until the goal state is reached (loop while NOT at goal).
        do {
            chooseAnAction();
        } while (currentState != GOAL_STATE);

        // At the goal, run through the set once more for convergence.
        for (int i = 0; i < Q_SIZE; i++) {
            chooseAnAction();
        }
    }

    /** Takes one random allowed move from {@code currentState} and updates that Q entry. */
    private static void chooseAnAction() {
        // getRandomAction only returns moves with R >= 0, so the move is always allowed.
        int possibleAction = getRandomAction(Q_SIZE);
        // reward() reads currentState, so compute it before moving.
        q[currentState][possibleAction] = reward(possibleAction);
        currentState = possibleAction;
    }

    /**
     * Picks a uniformly random action in {@code [0, upperBound)} that is allowed from
     * {@code currentState} (its R entry is above -1), using rejection sampling.
     * Requires the current row of R to contain at least one allowed action; every row of R does.
     *
     * @param upperBound exclusive upper bound on the action index
     * @return an allowed action index
     */
    private static int getRandomAction(final int upperBound) {
        int action;
        do {
            action = RANDOM.nextInt(upperBound);
        } while (R[currentState][action] <= -1);
        return action;
    }

    /** Zeroes every entry of the Q matrix. */
    private static void initialize() {
        for (int[] row : q) {
            for (int j = 0; j < row.length; j++) {
                row[j] = 0;
            }
        }
    }

    /**
     * Finds the largest Q value in row {@code state}; ties resolve to the lowest index.
     * Every column is considered, including moves that R marks as not allowed.
     *
     * @param state Q-matrix row to scan
     * @param returnIndexOnly if true, return the column index of the maximum; otherwise the value
     * @return the argmax column index or the maximum Q value, per {@code returnIndexOnly}
     */
    private static int maximum(final int state, final boolean returnIndexOnly) {
        int winner = 0;
        for (int i = 1; i < Q_SIZE; i++) {
            if (q[state][i] > q[state][winner]) {
                winner = i;
            }
        }
        return returnIndexOnly ? winner : q[state][winner];
    }

    /**
     * Computes the Q-learning target for moving from {@code currentState} to {@code action}:
     * {@code R[currentState][action] + GAMMA * max Q[action][*]}, truncated toward zero to int.
     *
     * @param action the state being moved to
     * @return the new Q value for {@code q[currentState][action]}
     */
    private static int reward(final int action) {
        return (int) (R[currentState][action] + (GAMMA * maximum(action, false)));
    }

    public static void main(String[] args) {
        train();
        test();
    }
}
相关文章推荐
- mfc进制转换
- 07 打印1到最大的n位数
- 成为数据专家,你只差一个Quick Insights的距离
- centos 6.5 配置LDAP服务器+客户端!
- 递归下降分析程序
- linux 编译内核树
- SQLdiag-配置文件-PerfmonCollector
- Selenium2Library库文件的使用和简析
- Android极光推送入门
- 微软自拍:让黑科技拯救不会拍照的你
- LA5031 Graph and Queries (Treap模版)
- VMware虚拟机从一台电脑转移复制到另一台电脑的方法
- HttpServlet详解
- THINKING OF DEATH
- NSDictionary使用 ... enumerateKeysAndObjectsUsingBlock
- 华为预期2015年收入3900亿元 研发超1000亿元
- Android Support Design Library之TabLayout
- poj 1830 开关问题 高斯消元
- SyntaxError: Non-ASCII character '\xef' in file deinstall_mysql_5.7.py on line 8, but no encoding declared
- java 枚举类型拾遗