Commit d11e3de9 authored by Maël Donnart's avatar Maël Donnart
Browse files

Added SARSA

parent 54474af7
......@@ -12,6 +12,9 @@
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Output of the graph generation
*.html
# Dependency directories (remove the comment below to include it)
# vendor/
......@@ -9,7 +9,7 @@ import (
// Entry point of our application
func main() {
// Generate seed for random package
// Generate seed for package random
rand.Seed(time.Now().UnixNano())
// Launch interactive menu
......
......@@ -7,6 +7,7 @@ import (
"strings"
)
// Print the board in ascii art
func PrintBoard(b [9]string) {
Clear()
......@@ -34,6 +35,7 @@ func PrintBoard(b [9]string) {
fmt.Println()
}
// Print the board in tic tac toe mode
func printBasicBoard(b [9]string) {
fmt.Println(b[1], "|", b[2], "|", b[3])
fmt.Println(b[0], "|", b[8], "|", b[4])
......@@ -41,6 +43,7 @@ func printBasicBoard(b [9]string) {
fmt.Println()
}
// Print the help (clears the screen first)
func PrintHelp() {
Clear()
......@@ -55,14 +58,16 @@ func PrintHelp() {
fmt.Println()
}
// Print the help in tic tac toe mode
func printBasicHelp() {
fmt.Println("1|2|3")
fmt.Println("0|8|4")
fmt.Println("7|6|5")
fmt.Println("1 | 2 | 3")
fmt.Println("0 | 8 | 4")
fmt.Println("7 | 6 | 5")
fmt.Println()
}
func Reset(b *[9]string) {
// Reset the board
func ResetBoard(b *[9]string) {
*b = [9]string{
"O", "O", "O",
"O", "X", "X",
......@@ -70,8 +75,8 @@ func Reset(b *[9]string) {
}
}
// Get the position of the empty square
func GetEmptyPostionIndex(board [9]string) (int, error) {
for i, v := range board {
if v == " " {
return i, nil
......@@ -79,3 +84,25 @@ func GetEmptyPostionIndex(board [9]string) (int, error) {
}
return 0, errors.New("there is no empty square")
}
// BoardToState flattens the board into a single string by concatenating its
// nine squares in index order. The result is the textual state key stored in
// the Q table (see intializeQ).
//
// The previous comment said the board was converted "to an array"; the
// function actually returns a string.
func BoardToState(board [9]string) (state string) {
	// strings.Join builds the result in one allocation instead of growing
	// the string square by square.
	return strings.Join(board[:], "")
}
// intializeQ builds the initial Q table: one entry per board permutation
// returned by Permutations, each pairing the encoded state with the action
// list produced by PermutationsActions.
func intializeQ(board [9]string) (stateActions []Q) {
	for _, permutation := range Permutations(board[:]) {
		// Copy the first nine squares into a fixed-size array so the
		// permutation can be encoded with BoardToState.
		var squares [9]string
		copy(squares[:], permutation[:9])

		stateActions = append(stateActions, Q{
			state:   BoardToState(squares),
			actions: PermutationsActions(),
		})
	}
	return stateActions
}
......@@ -3,7 +3,6 @@ package mutorere
import (
"errors"
"fmt"
"math"
"os"
"strings"
......@@ -43,11 +42,11 @@ func ModeSelection() {
}
}
// Mode to train two artificial intelligences
// Mode to train the sarsa agent and plot statistics
func trainArtificialIntelligence() {
validate := func(input string) error {
if len(input) < 3 {
if len(input) < 2 {
return errors.New("you should specify a larger number for best results")
}
if ConvertStringToInteger(input) == -1 {
......@@ -59,7 +58,7 @@ func trainArtificialIntelligence() {
prompt := promptui.Prompt{
Label: "Number of games",
Validate: validate,
Default: "10000",
Default: "150",
}
result, err := prompt.Run()
......@@ -71,70 +70,58 @@ func trainArtificialIntelligence() {
// Create the game
var game Game
game.CreateNewGame(AiVersusAi)
game.CreateNewGame(AiVersusRandom)
total := ConvertStringToInteger(result)
var winrate []int
var epsilons []float64
var turns []int
var badMoves []int
var rewards []int
// Launch "total" games
for i := 0; i < total; i++ {
fmt.Println("Game ", i+1, " / ", total)
game.Start()
winrate = append(winrate, game.GetWinner())
turns = append(turns, game.turn)
epsilons = append(epsilons, game.firstPlayer.epsilon)
game.firstPlayer.epsilon = math.Max(game.firstPlayer.epsilon*(1.0-(1.0/float64(total))), 0.05)
badMoves = append(badMoves, game.firstPlayer.badMoves)
rewards = append(rewards, game.firstPlayer.rewards)
}
PlotEpsilonEvolution(epsilons)
PlotDuration(turns)
PlotWinrate(winrate)
PlotRewards(rewards)
PlotBadMoves(badMoves)
}
// Mode to play against a trained artificial intelligence
func challengeArtificialIntelligence() {
prompt := promptui.Select{
Label: "Select AI level",
Items: []string{"Easy", "Medium", "Hard"},
}
_, level, err := prompt.Run()
if err != nil {
fmt.Printf("Prompt failed %v\n", err)
os.Exit(1)
}
fmt.Print(level)
var n int
switch level {
case "Easy":
n = 100
case "Medium":
n = 1000
case "Hard":
n = 10000
default:
n = 1000
}
// Train the algorithm
// Create the training
var training Game
training.CreateNewGame(AiVersusAi)
training.CreateNewGame(AiVersusRandom)
for i := 0; i < n; i++ {
// Train the algorithm in 150 games
for i := 0; i < 150; i++ {
fmt.Println("Training ", i+1, " / ", 150)
training.Start()
training.firstPlayer.epsilon = math.Max(training.firstPlayer.epsilon, 0.05)
}
// Play against human
var game Game
game.secondPlayer = training.firstPlayer
game.secondPlayer.epsilon = 0.05
game.CreateNewGame(HumanVersusAi)
// Setting to AI what first AI have learned
game.firstPlayer.epsilonDecay = training.firstPlayer.epsilonDecay
game.firstPlayer.alpha = training.firstPlayer.alpha
game.firstPlayer.gamma = training.firstPlayer.gamma
game.firstPlayer.Q = training.firstPlayer.Q
// Force the sarsa agent to be in exploitation mode
game.firstPlayer.epsilon = 0.05
// Play the game
game.Start()
}
......
package mutorere
import (
"errors"
"fmt"
)
......@@ -15,7 +14,7 @@ const (
const (
HumanVersusHuman GameType = 0
AiVersusAi GameType = 1
AiVersusRandom GameType = 1
HumanVersusAi GameType = 2
)
......@@ -36,46 +35,38 @@ func (game *Game) Start() {
panic("Players are not defined")
}
// Reset parmeters before starting the actual game
Reset(&game.board)
game.turn = 0
game.state = InProgress
// Reset parameters before starting the actual game
game.Reset()
// Play moves while game is in progress
for game.state == InProgress {
err := game.Play()
if err != nil && game.isHumanPlayer() {
fmt.Println(err)
continue
}
game.Play()
}
// Print the result of the game if there is a human player
if game.isHumanPlayer() {
PrintBoard(game.board)
printBasicBoard(game.board)
fmt.Printf("Congrats, %s is the winner ! The game lasted %d turns.\n", game.actualPlayer.mark, game.turn)
}
}
// Play a movement
func (game *Game) Play() error {
func (game *Game) Play() {
// Recover all valid moves
moves := GetValidesMoves(game)
// If moves is empty then the game is finished
if len(moves) <= 0 {
game.state = Finished
game.SwitchPlayers()
return nil
return
}
move := game.actualPlayer.Play(game, moves)
if IsValideMove(moves, move) {
game.board[move.initialPosition] = " "
game.board[move.finalPosition] = game.actualPlayer.mark
game.SwitchPlayers()
game.turn++
return nil
}
return errors.New("try another move")
// Play a move and switch player turn
game.actualPlayer.Play(moves)
game.SwitchPlayers()
game.turn++
}
// Change turn
......@@ -89,6 +80,7 @@ func (game *Game) SwitchPlayers() {
// CreateNewGame prepares the game for the given game type. It first calls
// Reset, then configures both players through setPlayers.
func (game *Game) CreateNewGame(gameType GameType) {
// Reset runs first, so setPlayers sees a freshly reset board.
game.Reset()
game.setPlayers(gameType)
}
......@@ -102,31 +94,32 @@ func (game *Game) setPlayers(gameType GameType) {
case HumanVersusHuman:
// First Player
game.firstPlayer.entity = HumanAgent
game.firstPlayer.game = game
// Second Player
game.secondPlayer.entity = HumanAgent
case AiVersusAi:
game.secondPlayer.game = game
case AiVersusRandom:
// First Player
game.firstPlayer.entity = SarsaAgent
game.firstPlayer.epsilon = 1
game.firstPlayer.alpha = 0.85
game.firstPlayer.gamma = 0.95
game.firstPlayer.game = game
game.firstPlayer.epsilon = 0.7
game.firstPlayer.epsilonDecay = 0.0001
game.firstPlayer.alpha = 0.35
game.firstPlayer.gamma = 0.9
game.firstPlayer.Q = intializeQ(game.board)
// Second Player
game.secondPlayer.entity = SarsaAgent
game.secondPlayer.epsilon = 0.4
game.secondPlayer.alpha = 0.1
game.secondPlayer.gamma = 0.3
game.secondPlayer.entity = RandomAgent
game.secondPlayer.game = game
case HumanVersusAi:
// First Player
game.firstPlayer.entity = HumanAgent
game.firstPlayer.entity = SarsaAgent
game.firstPlayer.game = game
// Second Player
game.secondPlayer.entity = SarsaAgent
game.secondPlayer.epsilon = 0.5
game.secondPlayer.alpha = 0.1
game.secondPlayer.gamma = 0.3
game.secondPlayer.entity = HumanAgent
game.secondPlayer.game = game
default:
panic("No game type provided")
}
......@@ -134,6 +127,13 @@ func (game *Game) setPlayers(gameType GameType) {
game.actualPlayer = &game.firstPlayer
}
// Reset puts the game back in its starting configuration: the board is
// restored with ResetBoard, the turn counter returns to 0 and the state
// becomes InProgress. The players are not modified.
func (game *Game) Reset() {
ResetBoard(&game.board)
game.turn = 0
game.state = InProgress
}
// Get the winner of the actual game
func (game *Game) GetWinner() int {
if game.state == Finished {
......@@ -150,3 +150,9 @@ func (game *Game) GetWinner() int {
// isHumanPlayer reports whether at least one of the two players is a
// HumanAgent.
func (game *Game) isHumanPlayer() bool {
	if game.firstPlayer.entity == HumanAgent {
		return true
	}
	return game.secondPlayer.entity == HumanAgent
}
// UpdateBoard applies move to the board: the square at move.initialPosition
// is emptied (" ") and the square at move.finalPosition receives the mark of
// the player whose turn it is. The move is not validated here.
func (game *Game) UpdateBoard(move Move) {
// Clear the origin first; if both positions are equal the mark written
// below is what remains on the square.
game.board[move.initialPosition] = " "
game.board[move.finalPosition] = game.actualPlayer.mark
}
......@@ -39,7 +39,7 @@ func PlotEpsilonEvolution(epsilons []float64) {
}
// Plot the duration in turns according to the number of games
func PlotDuration(turns []int) {
func PlotRewards(rewards []int) {
// create a new line instance
line := charts.NewLine()
......@@ -49,43 +49,53 @@ func PlotDuration(turns []int) {
Theme: types.ThemeInfographic,
}),
charts.WithTitleOpts(opts.Title{
Title: "Duration",
Subtitle: "Number of turns over " + strconv.Itoa(len(turns)) + " games",
Title: "Rewards",
Subtitle: "Evolution of rewards in " + strconv.Itoa(len(rewards)) + " games",
}),
)
// Create x axis
var xAxis []string
for i := 0; i < len(turns); i++ {
for i := 0; i < len(rewards); i++ {
xAxis = append(xAxis, strconv.Itoa(i))
}
// Put data into instance
line.SetXAxis(xAxis).
AddSeries("turns", generateLineItemsWithInteger(turns)).
AddSeries("reward", generateLineItemsWithInteger(rewards)).
SetSeriesOptions(charts.WithLineChartOpts(opts.LineChart{Smooth: true}))
f, _ := os.Create("duration.html")
f, _ := os.Create("rewards.html")
_ = line.Render(f)
}
// Plot the winrate according to the number of games
func PlotWinrate(winrate []int) {
// create a new bar instance
bar := charts.NewBar()
// Plot the number of bad moves according to the number of games
func PlotBadMoves(badMoves []int) {
// create a new line instance
line := charts.NewLine()
// set some global options like Title/Legend/ToolTip or anything else
bar.SetGlobalOptions(charts.WithTitleOpts(opts.Title{
Title: "Winrate",
Subtitle: "Winrate over " + strconv.Itoa(len(winrate)) + " games",
}))
line.SetGlobalOptions(
charts.WithInitializationOpts(opts.Initialization{
Theme: types.ThemeInfographic,
}),
charts.WithTitleOpts(opts.Title{
Title: "Bad Moves",
Subtitle: "Evolution of bad moves in " + strconv.Itoa(len(badMoves)) + " games",
}),
)
// Put data into instance
bar.SetXAxis([]string{"Win", "Loose"}).
AddSeries("Winrate", generateBarItems(winrate))
// Create x axis
var xAxis []string
for i := 0; i < len(badMoves); i++ {
xAxis = append(xAxis, strconv.Itoa(i))
}
// Where the magic happens
f, _ := os.Create("winrate.html")
bar.Render(f)
// Put data into instance
line.SetXAxis(xAxis).
AddSeries("bad moves", generateLineItemsWithInteger(badMoves)).
SetSeriesOptions(charts.WithLineChartOpts(opts.LineChart{Smooth: true}))
f, _ := os.Create("badmoves.html")
_ = line.Render(f)
}
// Generate line items with an array of floats
......@@ -105,22 +115,3 @@ func generateLineItemsWithInteger(data []int) []opts.LineData {
}
return items
}
// generateBarItems tallies the outcomes in data and returns exactly two bar
// items: the count of entries equal to 1 (wins) followed by the count of
// entries equal to -1 (losses). Any other value is ignored.
func generateBarItems(data []int) []opts.BarData {
	var wins, losses int
	for _, outcome := range data {
		switch outcome {
		case 1:
			wins++
		case -1:
			losses++
		}
	}
	return []opts.BarData{
		{Value: wins},
		{Value: losses},
	}
}
package mutorere
import (
"errors"
"fmt"
"math/rand"
)
......@@ -13,55 +14,170 @@ const (
SarsaAgent Entity = 2
)
// Integer values associated with the outcome of a move chosen by the agent.
// NOTE(review): these look like SARSA rewards (IllegalMove is passed as the
// last argument of player.update when the chosen action is invalid) — confirm
// against update's signature.
const (
// IllegalMove is used when the selected action is not a valid move.
IllegalMove int = -2
// LegalMove is presumably the counterpart used for a valid move — TODO confirm.
LegalMove int = 100
)
type Player struct {
mark string
entity Entity
epsilon float64
alpha float32 // Learning rate
gamma float32 // Decay rate
Qmatrix [86][86]float32
mark string
entity Entity
game *Game
epsilon float64
epsilonDecay float64
alpha float64 // Learning rate
gamma float64 // Decay rate
Q []Q
rewards int
badMoves int
}
// Q is one entry of the agent's Q table: a board state and the actions
// attached to it.
type Q struct {
// state is the board encoded as a string by BoardToState.
state string
// actions holds the candidate actions for this state with their values.
actions []Action
}
// Action pairs a move with a numeric value.
type Action struct {
// value is presumably the learned estimate for taking action in the
// owning state — TODO confirm against the update routine.
value float64
// action is the move this entry refers to.
action Move
}
// Ask a movement to the player
func (player *Player) askforplay() Move {
func (player *Player) askforplay(moves []Move) {
var move Move
var kewai int
var newPosition int
valid := false
fmt.Printf("It's %s turn to play. ", player.mark)
for !valid {
if Help() {
PrintHelp()
}
// Display the map
printBasicBoard(player.game.board)
// Display help
if Help() {
printBasicHelp()
}
fmt.Println("Enter Kewai to move:")
fmt.Scan(&kewai)
fmt.Printf("It's %s turn to play. ", player.mark)
fmt.Println("Enter a new position for kewai:")
fmt.Scan(&newPosition)
fmt.Println("Enter Kewai to move:")
fmt.Scan(&kewai)
return Move{kewai, newPosition}
fmt.Println("Enter a new position for kewai:")
fmt.Scan(&newPosition)
move = Move{kewai, newPosition}
// if not valid print an error and ask again for a move
if IsValideMove(moves, move) {
valid = true
} else {
fmt.Println(errors.New("try another move"))
}
}
// Update the board
player.game.UpdateBoard(move)
}
// Play a random legal move
func (player *Player) playRandomMove(moves []Move) Move {
return moves[rand.Intn(len(moves))]
func (player *Player) playRandomMove(moves []Move) {
player.game.UpdateBoard(moves[rand.Intn(len(moves))])
}
// Make agent learn rules with SARSA algorithm
func (player *Player) learn(moves []Move) {
valid := false
var state string
var newState string
var action Move
var newAction Move
// Decrease epsilon
player.decreaseEpsilon(player.epsilon > 0.1, player.epsilonDecay)
// Determine actual state and action
state = BoardToState(player.game.board)
action = player.epsilonGreedy(state)
// while the determined action is not valid
for !valid {
// If it's a valid move
if IsValideMove(moves, action) {
player.game.UpdateBoard(action)
valid = true
}
// Determine next state and action
newState = BoardToState(player.game.board)
newAction := player.epsilonGreedy(newState)
// If it's not a valid move increase epsilon and update Q
if !valid {
player.badMoves++
player.decreaseEpsilon(player.epsilon < 0.9, float64(IllegalMove)*player.epsilonDecay)
player.update(state, newState, action, newAction, IllegalMove)
}
state = newState
action = newAction
}
// If it's a valid move decrease epsilon and update Q
player.decreaseEpsilon(player