use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
extract::State,
response::IntoResponse,
Router, routing::get,
};
use tokio::sync::broadcast;
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
/// State shared by every request handler (installed on the router with `with_state`).
struct AppState {
    /// Broadcast side of the chat channel: each connection subscribes to it to receive
    /// messages and sends every incoming text frame into it.
    tx: broadcast::Sender<String>,
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
async fn handle_socket(mut socket: WebSocket, state: Arc<AppState>) {
let mut rx = state.tx.subscribe();
let (mut sender, mut receiver) = socket.split();
let tx = state.tx.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
println!("收到: {}", text);
let _ = tx.send(text.to_string());
}
});
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
if sender.send(Message::Text(msg.into())).await.is_err() {
break;
}
}
});
tokio::select! {
_ = &mut recv_task => send_task.abort(),
_ = &mut send_task => recv_task.abort(),
}
}
/// Starts the chat server on port 3000 and serves the WebSocket endpoint at `/ws`.
///
/// # Errors
/// Returns the I/O error if the listener cannot be bound (e.g. the port is in use) or if
/// the server fails while running, instead of panicking.
#[tokio::main]
async fn main() -> std::io::Result<()> {
    /// Messages buffered by the broadcast channel. A subscriber that falls further behind
    /// than this misses the oldest messages.
    const CHANNEL_CAPACITY: usize = 100;

    // The initial receiver is dropped; each connection subscribes its own.
    let (tx, _) = broadcast::channel(CHANNEL_CAPACITY);
    let state = Arc::new(AppState { tx });
    let app = Router::new()
        .route("/ws", get(ws_handler))
        .with_state(state);
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    println!("聊天室运行在 ws://localhost:3000/ws");
    axum::serve(listener, app).await
}