import RecorderNode from "./recorder-node.js"; const { useState, useEffect, useCallback, useRef } = React; const { createMachine, assign } = XState; const { useMachine } = XStateReact; const SILENT_DELAY = 4000; // in milliseconds const CANCEL_OLD_AUDIO = false; // TODO: set this to true after cancellations don't terminate containers. const INITIAL_MESSAGE = "Hi! I'm a language model running on Modal. Talk to me using your microphone, and remember to turn your speaker volume up!"; const INDICATOR_TYPE = { TALKING: "talking", SILENT: "silent", GENERATING: "generating", IDLE: "idle", }; const MODELS = [ { id: "zephyr-7b-beta-4bit", label: "Zephyr 7B beta (4-bit)" }, // { id: "vicuna-13b-4bit", label: "Vicuna 13B (4-bit)" }, // { id: "alpaca-lora-7b", label: "Alpaca LORA 7B" }, ]; const chatMachine = createMachine( { initial: "botDone", context: { pendingSegments: 0, transcript: "", messages: 1, }, states: { botGenerating: { on: { GENERATION_DONE: { target: "botDone", actions: "resetTranscript" }, }, }, botDone: { on: { TYPING_DONE: { target: "userSilent", actions: ["resetPendingSegments", "incrementMessages"], }, SEGMENT_RECVD: { target: "userTalking", actions: [ "resetPendingSegments", "segmentReceive", "incrementMessages", ], }, }, }, userTalking: { on: { SILENCE: { target: "userSilent" }, SEGMENT_RECVD: { actions: "segmentReceive" }, TRANSCRIPT_RECVD: { actions: "transcriptReceive" }, }, }, userSilent: { on: { SOUND: { target: "userTalking" }, SEGMENT_RECVD: { actions: "segmentReceive" }, TRANSCRIPT_RECVD: { actions: "transcriptReceive" }, }, after: [ { delay: SILENT_DELAY, target: "botGenerating", actions: "incrementMessages", cond: "canGenerate", }, { delay: SILENT_DELAY, target: "userSilent", }, ], }, }, }, { actions: { segmentReceive: assign({ pendingSegments: (context) => context.pendingSegments + 1, }), transcriptReceive: assign({ pendingSegments: (context) => context.pendingSegments - 1, transcript: (context, event) => { console.log(context, event); return context.transcript + event.transcript; }, }), resetPendingSegments: assign({ pendingSegments: 0 }), incrementMessages: assign({ messages: (context) => context.messages + 1, }), resetTranscript: assign({ transcript: "" }), }, guards: { canGenerate: (context) => { console.log(context); return context.pendingSegments === 0 && context.transcript.length > 0; }, }, } ); function Sidebar({ selected, isTortoiseOn, isMicOn, setIsMicOn, setIsTortoiseOn, onModelSelect, }) { return ( ); } function BotIcon() { return ( {/*! Font Awesome Pro 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc.*/} ); } function UserIcon() { return ( {/*! Font Awesome Pro 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc.*/} ); } function FaceIcon() { return ( {/*! Font Awesome Pro 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc.*/} ); } function MicOnIcon() { return ( {/*! Font Awesome Pro 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc.*/} ); } function MicOffIcon() { return ( {/*! Font Awesome Pro 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc.*/} ); } function TalkingSpinner({ isUser }) { return (
span]:" + (isUser ? "bg-yellow-500" : "bg-primary") } > {" "} {" "}
); } function LoadingSpinner() { return (
{[...Array(12)].map((_, i) => (
))}
); } function ChatMessage({ text, isUser, indicator }) { return (
{isUser ? : }
{indicator == INDICATOR_TYPE.TALKING && ( )} {indicator == INDICATOR_TYPE.GENERATING && }
{text || (isUser ? "Speak into your microphone to talk to the bot..." : "Bot is typing...")}
); } class PlayQueue { constructor(audioContext, onChange) { this.call_ids = []; this.audioContext = audioContext; this._onChange = onChange; this._isProcessing = false; this._indicators = {}; } async add(item) { this.call_ids.push(item); this.play(); } _updateState(idx, indicator) { this._indicators[idx] = indicator; this._onChange(this._indicators); } _onEnd(idx) { this._updateState(idx, INDICATOR_TYPE.IDLE); this._isProcessing = false; this.play(); } async play() { if (this._isProcessing || this.call_ids.length === 0) { return; } this._isProcessing = true; const [payload, idx, isTts] = this.call_ids.shift(); this._updateState(idx, INDICATOR_TYPE.GENERATING); if (!isTts) { const audio = new SpeechSynthesisUtterance(payload); audio.onend = () => this._onEnd(idx); this._updateState(idx, INDICATOR_TYPE.TALKING); window.speechSynthesis.speak(audio); return; } const call_id = payload; console.log("Fetching audio for call", call_id, idx); let response; let success = false; while (true) { response = await fetch(`/audio/${call_id}`); if (response.status === 202) { continue; } else if (response.status === 204) { console.error("No audio found for call: " + call_id); break; } else if (!response.ok) { console.error("Error occurred fetching audio: " + response.status); } else { success = true; break; } } if (!success) { this._onEnd(idx); return; } const arrayBuffer = await response.arrayBuffer(); const audioBuffer = await this.audioContext.decodeAudioData(arrayBuffer); const source = this.audioContext.createBufferSource(); source.buffer = audioBuffer; source.connect(this.audioContext.destination); source.onended = () => this._onEnd(idx); this._updateState(idx, INDICATOR_TYPE.TALKING); source.start(); } clear() { for (const [call_id, _, isTts] of this.call_ids) { if (isTts) { fetch(`/audio/${call_id}`, { method: "DELETE" }); } } this.call_ids = []; } } async function fetchTranscript(buffer) { const blob = new Blob([buffer], { type: "audio/float32" }); const response = await fetch("/transcribe", { method: "POST", body: blob, headers: { "Content-Type": "audio/float32" }, }); if (!response.ok) { console.error("Error occurred during transcription: " + response.status); } return await response.json(); } async function* fetchGeneration(noop, input, history, isTortoiseOn) { const body = noop ? { noop: true, tts: isTortoiseOn } : { input, history, tts: isTortoiseOn }; const response = await fetch("/generate", { method: "POST", body: JSON.stringify(body), headers: { "Content-Type": "application/json" }, }); if (!response.ok) { console.error("Error occurred during submission: " + response.status); } if (noop) { return; } const readableStream = response.body; const decoder = new TextDecoder(); const reader = readableStream.getReader(); while (true) { const { done, value } = await reader.read(); if (done) { break; } for (let message of decoder.decode(value).split("\x1e")) { if (message.length === 0) { continue; } const { type, value: payload } = JSON.parse(message); yield { type, payload }; } } reader.releaseLock(); } function App() { const [history, setHistory] = useState([]); const [fullMessage, setFullMessage] = useState(INITIAL_MESSAGE); const [typedMessage, setTypedMessage] = useState(""); const [model, setModel] = useState(MODELS[0].id); const [botIndicators, setBotIndicators] = useState({}); const [state, send, service] = useMachine(chatMachine); const [isMicOn, setIsMicOn] = useState(true); const [isTortoiseOn, setIsTortoiseOn] = useState(false); const recorderNodeRef = useRef(null); const playQueueRef = useRef(null); useEffect(() => { const subscription = service.subscribe((state, event) => { console.log("Transitioned to state:", state.value, state.context); if (event && event.type == "TRANSCRIPT_RECVD") { setFullMessage( (m) => m + (m ? event.transcript : event.transcript.trimStart()) ); } }); return subscription.unsubscribe; }, [service]); const generateResponse = useCallback( async (noop, input = "") => { if (!noop) { recorderNodeRef.current.stop(); } console.log("Generating response", input, history); let firstAudioRecvd = false; for await (let { type, payload } of fetchGeneration( noop, input, history.slice(1), isTortoiseOn )) { if (type === "text") { setFullMessage((m) => m + payload); } else if (type === "audio") { if (!firstAudioRecvd && CANCEL_OLD_AUDIO) { playQueueRef.current.clear(); firstAudioRecvd = true; } playQueueRef.current.add([payload, history.length + 1, true]); } else if (type === "sentence") { playQueueRef.current.add([payload, history.length + 1, false]); } } if (!isTortoiseOn && playQueueRef.current) { while ( playQueueRef.current.call_ids.length || playQueueRef.current._isProcessing ) { await new Promise((r) => setTimeout(r, 100)); } } console.log("Finished generating response"); if (!noop) { recorderNodeRef.current.start(); send("GENERATION_DONE"); } }, [history, isTortoiseOn] ); useEffect(() => { const transition = state.context.messages > history.length + 1; if (transition && state.matches("botGenerating")) { generateResponse(/* noop = */ false, fullMessage); } if (transition) { setHistory((h) => [...h, fullMessage]); setFullMessage(""); setTypedMessage(""); } }, [state, history, fullMessage]); const onSegmentRecv = useCallback( async (buffer) => { if (buffer.length) { send("SEGMENT_RECVD"); } // TODO: these can get reordered const data = await fetchTranscript(buffer); if (buffer.length) { send({ type: "TRANSCRIPT_RECVD", transcript: data }); } }, [history] ); async function onMount() { // Warm up GPU functions. onSegmentRecv(new Float32Array()); generateResponse(/* noop = */ true); const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); const context = new AudioContext(); const source = context.createMediaStreamSource(stream); await context.audioWorklet.addModule("processor.js"); const recorderNode = new RecorderNode( context, onSegmentRecv, () => send("SILENCE"), () => send("SOUND") ); recorderNodeRef.current = recorderNode; source.connect(recorderNode); recorderNode.connect(context.destination); playQueueRef.current = new PlayQueue(context, setBotIndicators); } useEffect(() => { onMount(); }, []); const tick = useCallback(() => { if (!recorderNodeRef.current) { return; } if (typedMessage.length < fullMessage.length) { const n = 1; // Math.round(Math.random() * 3) + 3; setTypedMessage(fullMessage.substring(0, typedMessage.length + n)); if (typedMessage.length + n == fullMessage.length) { send("TYPING_DONE"); } } }, [typedMessage, fullMessage]); useEffect(() => { const intervalId = setInterval(tick, 20); return () => clearInterval(intervalId); }, [tick]); const onModelSelect = (id) => { setModel(id); }; useEffect(() => { if (recorderNodeRef.current) { console.log("Mic", isMicOn); if (isMicOn) { recorderNodeRef.current.start(); } else { recorderNodeRef.current.stop(); } } }, [isMicOn]); useEffect(() => { if (playQueueRef.current && !isTortoiseOn) { console.log("Canceling future audio calls"); playQueueRef.current.clear(); } }, [isTortoiseOn]); const isUserLast = history.length % 2 == 1; let userIndicator = INDICATOR_TYPE.IDLE; if (isUserLast) { userIndicator = state.matches("userTalking") ? INDICATOR_TYPE.TALKING : INDICATOR_TYPE.SILENT; } useEffect(() => { console.log("Bot indicator changed", botIndicators); }, [botIndicators]); return (
{history.map((msg, i) => ( ))}
); } const container = document.getElementById("react"); ReactDOM.createRoot(container).render();