add ship wreck detection

This commit is contained in:
sparshg
2024-09-19 21:05:25 +05:30
parent 0242a92ab2
commit 68764fc461
9 changed files with 244 additions and 151 deletions

122
src/board.rs Normal file
View File

@@ -0,0 +1,122 @@
use std::ops::{Deref, DerefMut};
use axum::Json;
use rand::Rng;
use serde::Deserialize;
#[derive(Debug, Deserialize)]
pub struct Board(pub [[char; 10]; 10]);
impl From<Board> for Vec<String> {
fn from(board: Board) -> Self {
board.iter().map(|row| row.iter().collect()).collect()
}
}
impl From<Vec<String>> for Board {
fn from(board: Vec<String>) -> Self {
let mut arr = [['e'; 10]; 10];
for (i, row) in board.iter().enumerate() {
for (j, cell) in row.chars().enumerate() {
arr[i][j] = cell;
}
}
Board(arr)
}
}
impl Deref for Board {
type Target = [[char; 10]; 10];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Board {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Board {
const SHIPS: [i32; 5] = [5, 4, 3, 3, 2];
pub fn from_json(Json(board): Json<Board>) -> Self {
board
}
pub fn randomize() -> Self {
let mut board = Board([['e'; 10]; 10]);
for &length in Self::SHIPS.iter() {
loop {
let dir = rand::thread_rng().gen_bool(0.5);
let x = rand::thread_rng().gen_range(0..(if dir { 10 } else { 11 - length }));
let y = rand::thread_rng().gen_range(0..(if dir { 11 - length } else { 10 }));
if board.is_overlapping(x, y, length, dir) {
continue;
}
for i in 0..length {
let (tx, ty) = if dir { (x, y + i) } else { (x + i, y) };
board[tx as usize][ty as usize] = 's';
}
break;
}
}
board
}
fn is_overlapping(&self, x: i32, y: i32, length: i32, dir: bool) -> bool {
for i in -1..2 {
for j in -1..=length {
let (tx, ty) = if dir { (x + i, y + j) } else { (x + j, y + i) };
if !(0..10).contains(&tx) || !(0..10).contains(&ty) {
continue;
}
if self[tx as usize][ty as usize] != 'e' {
return true;
}
}
}
false
}
pub fn has_sunk(&self, (i, j): (usize, usize)) -> Option<[(usize, usize); 2]> {
let mut queue = vec![(i, j)];
let mut visited = vec![vec![false; 10]; 10];
let mut bounds = [(i, j), (i, j)];
visited[i][j] = true;
while let Some((x, y)) = queue.pop() {
if self[x][y] == 's' {
return None;
}
bounds[0].0 = bounds[0].0.min(x);
bounds[0].1 = bounds[0].1.min(y);
bounds[1].0 = bounds[1].0.max(x);
bounds[1].1 = bounds[1].1.max(y);
for (dx, dy) in [(-1, 0), (1, 0), (0, -1), (0, 1)].iter() {
let (tx, ty) = ((x as i32 + dx) as usize, (y as i32 + dy) as usize);
if (0..10).contains(&tx)
&& (0..10).contains(&ty)
&& !visited[tx][ty]
&& matches!(self[tx][ty], 'h' | 's')
{
visited[tx][ty] = true;
queue.push((tx, ty));
}
}
}
Some(bounds)
}
// fn validate_syntax(&self) -> bool {
// self
// .iter()
// .all(|row| row.iter().all(|cell| matches!(cell, 'e' | 'h' | 'm' | 's')))
// }
}
// pub async fn create_board_route(board: Json<Board>) -> Json<String> {
// let board = Board::from_json(board).await;
// Json(format!("{:?}", board))
// }

View File

@@ -1,11 +1,8 @@
use std::convert::Infallible;
use axum::Json;
use rand::Rng;
use serde::Deserialize;
use socketioxide::socket::Sid;
use thiserror::Error;
use crate::board::Board;
pub const ROOM_CODE_LENGTH: usize = 4;
pub type Result<T> = std::result::Result<T, Error>;
@@ -20,6 +17,8 @@ pub enum Error {
AlreadyInRoom,
#[error("Not in room")]
NotInRoom,
#[error("Invalid Move")]
InvalidMove,
#[error("SQL Error\n{0:?}")]
Sqlx(#[from] sqlx::Error),
}
@@ -85,25 +84,14 @@ pub async fn join_room(sid: Sid, code: String, pool: &sqlx::PgPool) -> Result<()
}
pub async fn add_board(sid: Sid, board: Board, pool: &sqlx::PgPool) -> Result<()> {
let query = format!(
"UPDATE players SET board = ARRAY[{}] WHERE id = '{}'",
board
.0
.map(|row| {
format!(
"ARRAY[{}]",
row.map(|x| format!("'{x}'"))
.into_iter()
.collect::<Vec<_>>()
.join(",")
)
})
.into_iter()
.collect::<Vec<String>>()
.join(","),
let board: Vec<String> = board.into();
sqlx::query!(
"UPDATE players SET board = $1 WHERE id = $2",
&board,
sid.as_str()
);
sqlx::query(&query).execute(pool).await?;
)
.execute(pool)
.await?;
Ok(())
}
@@ -137,7 +125,11 @@ pub async fn start(sid: Sid, code: String, pool: &sqlx::PgPool) -> Result<()> {
Ok(())
}
pub async fn attack(sid: Sid, (i, j): (usize, usize), pool: &sqlx::PgPool) -> Result<bool> {
pub async fn attack(
sid: Sid,
(i, j): (usize, usize),
pool: &sqlx::PgPool,
) -> Result<(bool, Option<[(usize, usize); 2]>)> {
let player = sqlx::query!(r"SELECT room_code FROM players WHERE id = $1", sid.as_str())
.fetch_one(pool)
.await?;
@@ -159,42 +151,42 @@ pub async fn attack(sid: Sid, (i, j): (usize, usize), pool: &sqlx::PgPool) -> Re
_ => return Err(Error::RoomNotFull), // room not full
};
let mut board: Board = sqlx::query!(r"SELECT board FROM players WHERE id = $1", other)
.fetch_one(pool)
.await?
.board
.unwrap()
.into();
let hit = match board[i][j] {
's' => true,
'e' => false,
_ => return Err(Error::InvalidMove),
};
board[i][j] = if hit { 'h' } else { 'm' };
let mut txn = pool.begin().await?;
let turn = sqlx::query!(
r"SELECT board[$1][$2] as HIT FROM players WHERE id = $3",
i as i32 + 1,
j as i32 + 1,
other
)
.fetch_one(&mut *txn)
.await?;
sqlx::query!(
r#"UPDATE players
SET board[$1][$2] = CASE
WHEN board[$1][$2] = 's' THEN 'h'
WHEN board[$1][$2] = 'e' THEN 'm'
ELSE board[$1][$2]
END
WHERE id = $3"#,
r#"UPDATE players SET board[$1] = $2 WHERE id = $3"#,
i as i32 + 1,
j as i32 + 1,
board[i].iter().collect::<String>(),
other
)
.execute(&mut *txn)
.await?;
sqlx::query!(
r#"UPDATE rooms SET stat = $1 WHERE code = $2"#,
to_status as Status,
player.room_code
)
.execute(&mut *txn)
.await?;
if !hit {
sqlx::query!(
r#"UPDATE rooms SET stat = $1 WHERE code = $2"#,
to_status as Status,
player.room_code
)
.execute(&mut *txn)
.await?;
}
txn.commit().await?;
Ok(turn.hit.unwrap() == "s")
Ok((hit, if hit { board.has_sunk((i, j)) } else { None }))
}
pub async fn disconnect(sid: Sid, pool: &sqlx::PgPool) -> Result<()> {
@@ -211,60 +203,3 @@ enum Status {
P1Turn,
P2Turn,
}
#[derive(Debug, Deserialize)]
pub struct Board(pub [[char; 10]; 10]);
impl Board {
const SHIPS: [i32; 5] = [5, 4, 3, 3, 2];
pub fn from_json(Json(board): Json<Board>) -> Self {
board
}
pub fn randomize() -> Self {
let mut board = Board([['e'; 10]; 10]);
for &length in Self::SHIPS.iter() {
loop {
let dir = rand::thread_rng().gen_bool(0.5);
let x = rand::thread_rng().gen_range(0..(if dir { 10 } else { 11 - length }));
let y = rand::thread_rng().gen_range(0..(if dir { 11 - length } else { 10 }));
if board.is_overlapping(x, y, length, dir) {
continue;
}
for i in 0..length {
let (tx, ty) = if dir { (x, y + i) } else { (x + i, y) };
board.0[tx as usize][ty as usize] = 's';
}
break;
}
}
board
}
fn is_overlapping(&self, x: i32, y: i32, length: i32, dir: bool) -> bool {
for i in -1..2 {
for j in -1..=length {
let (tx, ty) = if dir { (x + i, y + j) } else { (x + j, y + i) };
if !(0..10).contains(&tx) || !(0..10).contains(&ty) {
continue;
}
if self.0[tx as usize][ty as usize] != 'e' {
return true;
}
}
}
false
}
// fn validate_syntax(&self) -> bool {
// self.0
// .iter()
// .all(|row| row.iter().all(|cell| matches!(cell, 'e' | 'h' | 'm' | 's')))
// }
}
// pub async fn create_board_route(board: Json<Board>) -> Json<String> {
// let board = Board::from_json(board).await;
// Json(format!("{:?}", board.0))
// }

View File

@@ -1,9 +1,10 @@
mod board;
mod game;
use axum::Router;
use board::Board;
use dotenv::dotenv;
use futures_util::stream::StreamExt;
use game::{add_board, add_room, attack, disconnect, join_room, start, Board, ROOM_CODE_LENGTH};
use game::{add_board, add_room, attack, disconnect, join_room, start, ROOM_CODE_LENGTH};
use rand::Rng;
use socketioxide::{
extract::{Data, SocketRef, State},
@@ -33,7 +34,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.layer(layer)
.layer(CorsLayer::very_permissive());
let listener = TcpListener::bind("127.0.0.1:3000").await?;
let listener = TcpListener::bind("0.0.0.0:3000").await?;
println!("listening on {}", listener.local_addr()?);
axum::serve(listener, app).await?;
Ok(())
@@ -98,7 +99,6 @@ fn on_connect(socket: SocketRef) {
if let Err(e) = add_board(id, ack.data.pop().unwrap(), &pool).await
{
tracing::error!("{:?}", e);
return;
}
}
Err(err) => tracing::error!("Ack error, {}", err),
@@ -121,19 +121,19 @@ fn on_connect(socket: SocketRef) {
socket.on(
"attack",
|socket: SocketRef, Data::<[usize; 2]>([i, j]), pool: State<PgPool>| async move {
let res = match attack(socket.id, (i, j), &pool).await {
let (hit, sunk) = match attack(socket.id, (i, j), &pool).await {
Ok(res) => res,
Err(e) => {
tracing::error!("{:?}", e);
return;
}
};
tracing::info!("Attacking at: ({}, {}), result: {}", i, j, res);
tracing::info!("Attacking at: ({}, {}), result: {:?}", i, j, hit);
socket
.within(socket.rooms().unwrap().first().unwrap().clone())
.emit(
"attacked",
serde_json::json!({"by": socket.id.as_str(), "at": [i, j], "res": res}),
serde_json::json!({"by": socket.id.as_str(), "at": [i, j], "hit": hit, "sunk": sunk}),
)
.unwrap();
},
@@ -144,7 +144,6 @@ fn on_connect(socket: SocketRef) {
socket.leave_all().unwrap();
if let Err(e) = disconnect(socket.id, &pool).await {
tracing::error!("{:?}", e);
return;
}
});
}