OSDev.org
https://forum.osdev.org/

Looking for advice on minmax algorithm for Connect-4 Game
https://forum.osdev.org/viewtopic.php?f=13&t=56455
Page 1 of 1

Author:  Schol-R-LEA [ Tue Aug 30, 2022 8:25 pm ]
Post subject:  Looking for advice on minmax algorithm for Connect-4 Game

I've been working on the GUI portion of my C++ Connect-N game project, and have gotten it to what could charitably be called a working state. However, the game AI, not to put too fine a point on it, stinks.

I was wondering if anyone here has more experience in implementing a minimax algorithm, and can offer any advice on how to get it to play a suitably challenging game.

The minmax code is:
Code:
#ifndef __CONNECTN_AI_H__
#define __CONNECTN_AI_H__ 1

#include <vector>
#include "connectnboard.h"

namespace ConnectN
{

    typedef intmax_t Weight;
    typedef uint8_t Depth;
    const Weight infinity = INTMAX_MAX, neg_infinity = INTMAX_MIN;

    struct Node
    {
        Board board;
        Grid_Size column;
        Weight weight;
        Depth depth;
        Node(const Board& b, Grid_Size c, Weight w, Depth d): board(b), column(c), weight(w), depth(d) {};
        Node(const Node& n);
        void operator=(const Node& n);
    };


    class Solver
    {
    private:
        Board& board;
        Depth search_depth;
        Node negamax(Node node, Depth depth, Weight alpha, Weight beta, Weight color);
        std::vector<Node> enumerate_moves(Node node);
        void order_moves(std::vector<Node>& children);
        Weight evaluate_move(Node& node);
        Weight evaluate_node(Node& node);
        Weight weight_column(Grid_Size c);

    public:
        Solver(Board& b, Depth d);
        Grid_Size move();

        friend class Board;
    };
}

#endif

and
Code:
#include <algorithm>
#include <chrono>
#include <cmath>
#include <random>
#include "connectnboard.h"
#include "connectn_AI.h"

namespace ConnectN
{

    /**
     *
     *
     */
    Node::Node(const Node& n)
    {
        this->board = n.board;
        this->column = n.column;
        this->weight = n.weight;
        this->depth = n.depth;
    }


    /**
     *
     *
     */
    void Node::operator=(const Node& n)
    {
        this->board = n.board;
        this->column = n.column;
        this->weight = n.weight;
        this->depth = n.depth;
    }


    /**
     *
     *
     */
    Solver::Solver(Board& b, Depth d = 3): board(b), search_depth(d)
    {
   
    }


    /**
     * 
     *
     */
    Grid_Size Solver::move()
    {
        Node top = {this->board, 0, neg_infinity, this->search_depth};
        Node found = this->negamax(top, this->search_depth, neg_infinity, infinity, 1);
        return found.column;
    }


    /**
     *
     *
     */
    Node Solver::negamax(Node node, Depth depth, Weight alpha, Weight beta, Weight color)
    {

        if (depth == 0 || board.win() == COMPUTER)
        {
            return (Node {node.board, node.column, color * node.weight, depth});
        }
        std::vector<Node> children = enumerate_moves(node);

        order_moves(children);
        Node value = {node.board, node.column, neg_infinity, depth};

        for (auto child : children)
        {
            value.column = child.column;
            Node recursive_value = negamax(child, depth - 1, -beta, -alpha, -color);
            value.weight = std::max(value.weight, -recursive_value.weight);
            alpha = std::max(value.weight, alpha);
            if (alpha >= beta)
            {
                break;
            }
        }

        return value;
    }


    /**
     *
     *
     */
    std::vector<Node> Solver::enumerate_moves(Node node)
    {
        std::vector<Node> children;
        for (Grid_Size i = 0; i < node.board.size(); i++)
        {
            Node child = {node.board, i, 0, static_cast<Depth>(node.depth-1)};
            if (!child.board.add_at(i))
            {
                continue;
            }
            if (child.depth != this->search_depth)
            {
                child.board.switch_player();
            }
            child.weight = evaluate_move(child);
            children.push_back(child);
        }
        return children;
    }


    /**
     *
     *
     */
    void Solver::order_moves(std::vector<Node>& children)
    {
        std::sort(children.begin(),
                  children.end(),
                  [](Node first, Node second)
                      {
                          return (first.weight < second.weight);
                      });
    }


    /**
     *
     *
     */
    Weight Solver::evaluate_move(Node& node)
    {
        Grid_Size size = node.board.size();
        Weight weight = 0;
        Grid_Size column = node.column;
        Grid_Size midpoint = std::ceil(size/2);

        // add weight depending on whether the token is on the
        // center column(s).
        if (size % 2)
        {
            if (node.column == midpoint)
            {
                weight += 4;
            }
        }
        else
        {
            // if there are two central columns, weight half as much
            // as you would if there is only one
            if (column == midpoint-1 || column == midpoint)
            {
                weight += 2;
            }
        }

        weight += evaluate_node(node);
        return weight;
    }

    /**
     * Weight Solver::weight_column(Grid_Size c) -
     * @param c - total number of
     */
    Weight Solver::weight_column(Grid_Size c)
    {
        Weight weight = 0;
        Grid_Size max = this->board.winning_count;
        Grid_Size midpoint = std::floor(max/2);

        if (c >= this->board.winning_count)
        {
            // winning move, weight it the maximum amount
            weight = 1000;
        }
        else
        {
            // for every token beyond half the winning amount,
            // increase the weighting by 2
            Weight weighted_count = c - midpoint;
            weight = std::max(0, static_cast<int>(weighted_count+1)) * 2;
        }

        return weight;
    }
       
    /**
     * Weight Solver::evaluate_column(Node& node)
     * - scan the neighboring square of a given square to see if there
     *   is a winning sequence.
     * @param node - version of board to evaluate
     */
    Weight Solver::evaluate_node(Node& node)
    {
        Weight weight = 0;
        Player p = node.board.current_player();
        Grid_Size size = node.board.size();   
        Grid_Size offset = node.board.winning_count,
            target = static_cast<Grid_Size>(node.board.winning_count - 1),
            row = std::max(0, node.board.column_top(node.column) - 1),
            column = node.column;

        bool check_up = (row < (size - target)),
            check_left = (column > offset),
            check_right = (column < (size - offset));

        if (check_right)
        {
            Grid_Size count = 1;
            for (Grid_Size c = column; c < std::min(size, column + offset) && (node.board.grid[row][c] == p); c++, count++)
            {
                // iterate through
            }
            weight += weight_column(count);
        }

        if (check_up)
        {
            Grid_Size count = 1;

            for (Grid_Size r = row; (r < std::min(size, row + offset)) && (node.board.grid[r][column] == p); r++, count++)
            {
                // iterate through
            }
            weight += weight_column(count);
        }

        if (check_left)
        {
            Grid_Size count = 1;
            for (Grid_Size r = row, c = column, count = 0;
                 (r < (row + offset)) && (c >= (column - target)) && (node.board.grid[r][c] == p);
                 r++, c--, count++)
            {
                // iterate through
            }
            weight += weight_column(count);
        }

        if (check_right)
        {
            Grid_Size count = 1;
            for (Grid_Size r = row, c = column, count = 0;
                 (r < std::min(size, row + offset)) && (c < std::min(size, column + offset)) && (node.board.grid[r][c] == p);
                 r++, c++, count++)
            {
                // iterate through
            }
            weight += weight_column(count);
        }

        return weight;
    }

}

Page 1 of 1 All times are UTC - 6 hours
Powered by phpBB © 2000, 2002, 2005, 2007 phpBB Group
http://www.phpbb.com/