Skip to content

Commit

Permalink
refactor: switch to agents
Browse files Browse the repository at this point in the history
  • Loading branch information
cs50victor committed Feb 22, 2024
1 parent c40ac79 commit 21dc61f
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 188 deletions.
3 changes: 2 additions & 1 deletion lkgpt/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ pub fn run_llm(
.build()
.unwrap();

async_runtime.rt.block_on(async {
let rt = async_runtime.rt.clone();
rt.block_on(async {
let mut gpt_resp_stream =
llm_channel.client.chat().create_stream(request).await.unwrap();
while let Some(result) = gpt_resp_stream.next().await {
Expand Down
212 changes: 57 additions & 155 deletions lkgpt/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod controls;
mod frame_capture;
mod llm;
mod room_events;
mod server;
mod stt;
mod tts;
Expand Down Expand Up @@ -31,7 +32,8 @@ use livekit::{

use bevy::{
app::ScheduleRunnerPlugin, core::Name, core_pipeline::tonemapping::Tonemapping, log::LogPlugin,
prelude::*, render::renderer::RenderDevice, time::common_conditions::on_timer,
prelude::*, render::renderer::RenderDevice, tasks::AsyncComputeTaskPool,
time::common_conditions::on_timer,
};
use bevy_panorbit_camera::{PanOrbitCamera, PanOrbitCameraPlugin};

Expand All @@ -51,7 +53,8 @@ use serde::{Deserialize, Serialize};
use stt::STT;

use crate::{
controls::WorldControlChannel, llm::LLMChannel, server::RoomData, tts::TTS, video::VideoChannel,
controls::WorldControlChannel, llm::LLMChannel, room_events::handle_room_events,
server::RoomData, tts::TTS, video::VideoChannel,
};

pub const LIVEKIT_API_SECRET: &str = "LIVEKIT_API_SECRET";
Expand All @@ -69,7 +72,8 @@ pub struct AsyncRuntime {
impl FromWorld for AsyncRuntime {
fn from_world(_world: &mut World) -> Self {
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
AsyncRuntime { rt: std::sync::Arc::new(rt) }

Self { rt: std::sync::Arc::new(rt) }
}
}

Expand Down Expand Up @@ -140,140 +144,6 @@ pub struct LivekitRoom {
room_events: tokio::sync::mpsc::UnboundedReceiver<RoomEvent>,
}

// SYSTEM
pub fn handle_room_events(
async_runtime: Res<AsyncRuntime>,
llm_channel: Res<llm::LLMChannel>,
stt_client: ResMut<STT>,
_video_channel: Res<video::VideoChannel>,
audio_syncer: ResMut<AudioSync>,
mut room_events: ResMut<LivekitRoom>,
_single_frame_data: ResMut<crate::StreamingFrameData>,
) {
while let Ok(event) = room_events.room_events.try_recv() {
println!("\n\n🤡received room event {:?}", event);
match event {
RoomEvent::TrackSubscribed { track, publication: _, participant: _user } => {
match track {
RemoteTrack::Audio(audio_track) => {
let audio_rtc_track = audio_track.rtc_track();
let mut audio_stream = NativeAudioStream::new(audio_rtc_track);
let audio_should_stop = audio_syncer.should_stop.clone();
let stt_client = stt_client.clone();
async_runtime.rt.spawn(async move {
while let Some(frame) = audio_stream.next().await {
if audio_should_stop.load(Ordering::Relaxed) {
continue;
}

let audio_buffer = frame
.data
.iter()
.map(|sample| sample.to_sample::<u8>())
.collect::<Vec<u8>>();

if audio_buffer.is_empty() {
warn!("empty audio frame | {:#?}", audio_buffer);
continue;
}

if let Err(e) = stt_client.send(audio_buffer) {
error!("Couldn't send audio frame to stt {e}");
};
}
});
},
RemoteTrack::Video(video_track) => {
let video_rtc_track = video_track.rtc_track();
let pixel_size = 4;
let mut video_stream = NativeVideoStream::new(video_rtc_track);

async_runtime.rt.spawn(async move {
// every 10 video frames
let mut i = 0;
while let Some(frame) = video_stream.next().await {
log::error!("🤡received video frame | {:#?}", frame);
// VIDEO FRAME BUFFER (i420_buffer)
let video_frame_buffer = frame.buffer.to_i420();
let width = video_frame_buffer.width();
let height = video_frame_buffer.height();
let rgba_stride = video_frame_buffer.width() * pixel_size;

let (stride_y, stride_u, stride_v) = video_frame_buffer.strides();
let (data_y, data_u, data_v) = video_frame_buffer.data();

let rgba_buffer = RgbaImage::new(width, height);
let rgba_raw = unsafe {
std::slice::from_raw_parts_mut(
rgba_buffer.as_raw().as_ptr() as *mut u8,
rgba_buffer.len(),
)
};

livekit::webrtc::native::yuv_helper::i420_to_rgba(
data_y,
stride_y,
data_u,
stride_u,
data_v,
stride_v,
rgba_raw,
rgba_stride,
video_frame_buffer.width() as i32,
video_frame_buffer.height() as i32,
);

if let Err(e) = rgba_buffer.save(format!("camera/{i}.png")) {
log::error!("Couldn't save video frame {e}");
};
i += 1;
}
info!("🤡ended video thread");
});
},
};
},
RoomEvent::DataReceived { payload, kind, topic: _, participant: _ } => {
if kind == DataPacketKind::Reliable {
if let Some(payload) = payload.as_ascii() {
let room_text: serde_json::Result<RoomText> =
serde_json::from_str(payload.as_str());
match room_text {
Ok(room_text) => {
if let Err(e) =
llm_channel.tx.send(format!("[chat]{} ", room_text.message))
{
error!("Couldn't send the text to gpt {e}")
};
},
Err(e) => {
warn!("Couldn't deserialize room text. {e:#?}");
},
}

info!("text from room {:#?}", payload.as_str());
}
}
},
// ignoring the participant for now, currently assuming there is only one participant
RoomEvent::TrackMuted { participant: _, publication: _ } => {
audio_syncer.should_stop.store(true, Ordering::Relaxed);
},
RoomEvent::TrackUnmuted { participant: _, publication: _ } => {
audio_syncer.should_stop.store(false, Ordering::Relaxed);
},
// RoomEvent::ActiveSpeakersChanged { speakers } => {
// if speakers.is_empty() {
// audio_syncer.should_stop.store(true, Ordering::Relaxed);
// }
// let is_main_participant_muted = speakers.iter().any(|speaker| speaker.name() != "kitt");
// audio_syncer.should_stop.store(is_main_participant_muted, Ordering::Relaxed);
// }
_ => info!("received room event {:?}", event),
}
}
}

pub struct TracksPublicationData {
pub video_src: NativeVideoSource,
pub video_pub: LocalTrackPublication,
Expand Down Expand Up @@ -335,7 +205,7 @@ fn setup_gaussian_cloud(
) {
// let remote_file = Some("https://huggingface.co/datasets/cs50victor/splats/resolve/main/train/point_cloud/iteration_7000/point_cloud.gcloud");
// TODO: figure out how to load remote files later
let splat_file = "splats/train/point_cloud/iteration_7000/point_cloud.gcloud";
let splat_file = "splats/bonsai/point_cloud/iteration_7000/point_cloud.gcloud";
log::info!("loading {}", splat_file);
let cloud = asset_server.load(splat_file.to_string());

Expand All @@ -350,11 +220,16 @@ fn setup_gaussian_cloud(
String::from("main_scene"),
);

commands.spawn((GaussianSplattingBundle { cloud, ..default() }, Name::new("gaussian_cloud")));
let gs = GaussianSplattingBundle { cloud, ..default() };
commands.spawn((gs, Name::new("gaussian_cloud")));

commands.spawn((
Camera3dBundle {
transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
transform: Transform {
translation: Vec3::new(-0.59989005, -0.88360703, -2.0863006),
rotation: Quat::from_xyzw(-0.97177905, -0.026801618, 0.13693734, -0.1901983),
scale: Vec3::new(1.0, 1.0, 1.0),
},
tonemapping: Tonemapping::None,
camera: Camera { target: render_target, ..default() },
..default()
Expand All @@ -375,6 +250,7 @@ pub fn sync_bevy_and_server_resources(
mut server_state_clone: ResMut<AppStateSync>,
mut set_app_state: ResMut<NextState<AppState>>,
scene_controller: Res<SceneController>,
audio_syncer: Res<AudioSync>,
) {
if !server_state_clone.dirty {
let participant_room_name = &(server_state_clone.state.lock().0).clone();
Expand All @@ -399,22 +275,46 @@ pub fn sync_bevy_and_server_resources(

info!("initializing required bevy resources");

let tts = async_runtime.rt.block_on(TTS::new(audio_src)).unwrap();
let llm_channel = LLMChannel::new();
let llm_tx = llm_channel.tx.clone();
let llm_channel_tx = llm_tx.clone();

let tts = async_runtime.rt.block_on(TTS::new(audio_src)).unwrap();
let stt = async_runtime.rt.block_on(STT::new(llm_tx)).unwrap();

let video_channel = VideoChannel::new();
commands.insert_resource(llm_channel);
commands.init_resource::<WorldControlChannel>();

let stt = async_runtime.rt.block_on(STT::new(llm_tx)).unwrap();
commands.insert_resource(stt);
commands.insert_resource(stt.clone());

commands.init_resource::<VideoChannel>();
commands.insert_resource(tts);
commands.insert_resource(stream_frame_data);
commands.insert_resource(livekit_room);
// commands.insert_resource(livekit_room);

set_app_state.set(AppState::Active);

let audio_syncer = audio_syncer.should_stop.clone();
let rt = async_runtime.rt.clone();
async_runtime.rt.spawn(handle_room_events(
rt,
llm_channel_tx,
stt,
video_channel,
audio_syncer,
livekit_room,
4,
));
/*
async_runtime: Res<AsyncRuntime>,
llm_channel: Res<llm::LLMChannel>,
stt_client: ResMut<STT>,
_video_channel: Res<video::VideoChannel>,
audio_syncer: ResMut<AudioSync>,
mut room_events: ResMut<LivekitRoom>,
single_frame_data: ResMut<crate::StreamingFrameData>,
*/
server_state_clone.dirty = true;
},
Err(e) => {
Expand Down Expand Up @@ -488,18 +388,18 @@ fn main() {
app.init_resource::<frame_capture::scene::SceneController>();
app.add_event::<frame_capture::scene::SceneController>();

// app.add_systems(Update, move_camera);
app.add_systems(Update, move_camera);

app.add_systems(Update, server::shutdown_bevy_remotely);

app.add_systems(
Update,
handle_room_events
.run_if(resource_exists::<llm::LLMChannel>())
.run_if(resource_exists::<stt::STT>())
.run_if(resource_exists::<video::VideoChannel>())
.run_if(resource_exists::<LivekitRoom>()),
);
// app.add_systems(
// Update,
// room_events::handle_room_events
// .run_if(resource_exists::<llm::LLMChannel>())
// .run_if(resource_exists::<stt::STT>())
// .run_if(resource_exists::<video::VideoChannel>())
// .run_if(resource_exists::<LivekitRoom>()),
// );

app.add_systems(
Update,
Expand All @@ -514,13 +414,15 @@ fn main() {
sync_bevy_and_server_resources.run_if(on_timer(std::time::Duration::from_secs(2))),
);

// app.add_systems(OnEnter(AppState::Active), setup_gaussian_cloud);
app.add_systems(OnEnter(AppState::Active), setup_gaussian_cloud);

app.run();
}

fn move_camera(mut camera: Query<&mut Transform, With<Camera>>) {
for mut transform in camera.iter_mut() {
transform.translation.x += 5.0;
transform.translation.x += 0.0005;
transform.translation.y += 0.0005;
transform.translation.z += 0.0005;
}
}
Loading

0 comments on commit 21dc61f

Please sign in to comment.