import RecorderNode from "./recorder-node.js";
const { useState, useEffect, useCallback, useRef } = React;
const { createMachine, assign } = XState;
const { useMachine } = XStateReact;
// Delay used by both `after` transitions of the state machine's userSilent state.
const SILENT_DELAY = 4000; // in milliseconds
// When true, generateResponse clears the play queue once the first "audio"
// message of a new response arrives.
const CANCEL_OLD_AUDIO = false; // TODO: set this to true after cancellations don't terminate containers.
// Initial value of App's fullMessage state, i.e. the first message that gets typed out.
const INITIAL_MESSAGE =
"Hi! I'm a language model running on Modal. Talk to me using your microphone, and remember to turn your speaker volume up!";
// Per-message indicator values: PlayQueue reports GENERATING / TALKING / IDLE,
// and App derives TALKING / SILENT / IDLE for the user.
const INDICATOR_TYPE = {
TALKING: "talking",
SILENT: "silent",
GENERATING: "generating",
IDLE: "idle",
};
// Model options; App uses the first entry's id as its initial `model` state.
const MODELS = [
{ id: "zephyr-7b-beta-4bit", label: "Zephyr 7B beta (4-bit)" },
// { id: "vicuna-13b-4bit", label: "Vicuna 13B (4-bit)" },
// { id: "alpaca-lora-7b", label: "Alpaca LORA 7B" },
];
// Conversation state machine.
//
// States:
//   botDone       - initial state; left on TYPING_DONE or on the first audio segment.
//   userTalking   - sound is being recorded; segments and transcripts are counted.
//   userSilent    - no sound; after SILENT_DELAY either hands over to the bot
//                   (when canGenerate holds) or re-enters itself to wait again.
//   botGenerating - waits for GENERATION_DONE, then clears the transcript.
//
// Context:
//   pendingSegments - segments received minus transcripts received.
//   transcript      - concatenation of the transcripts received so far.
//   messages        - counter bumped by the incrementMessages action.
const chatMachine = createMachine(
  {
    initial: "botDone",
    context: { pendingSegments: 0, transcript: "", messages: 1 },
    states: {
      botGenerating: {
        on: {
          GENERATION_DONE: { target: "botDone", actions: ["resetTranscript"] },
        },
      },
      botDone: {
        on: {
          TYPING_DONE: {
            target: "userSilent",
            actions: ["resetPendingSegments", "incrementMessages"],
          },
          SEGMENT_RECVD: {
            target: "userTalking",
            actions: ["resetPendingSegments", "segmentReceive", "incrementMessages"],
          },
        },
      },
      userTalking: {
        on: {
          SILENCE: "userSilent",
          SEGMENT_RECVD: { actions: ["segmentReceive"] },
          TRANSCRIPT_RECVD: { actions: ["transcriptReceive"] },
        },
      },
      userSilent: {
        on: {
          SOUND: "userTalking",
          SEGMENT_RECVD: { actions: ["segmentReceive"] },
          TRANSCRIPT_RECVD: { actions: ["transcriptReceive"] },
        },
        // Order matters: the guarded transition is listed first, the
        // unguarded self-transition is the fallback for the same delay.
        after: [
          {
            delay: SILENT_DELAY,
            cond: "canGenerate",
            target: "botGenerating",
            actions: ["incrementMessages"],
          },
          { delay: SILENT_DELAY, target: "userSilent" },
        ],
      },
    },
  },
  {
    actions: {
      // One more segment is awaiting its transcript.
      segmentReceive: assign((context) => ({
        pendingSegments: context.pendingSegments + 1,
      })),
      // A transcript arrived: one fewer pending segment, text appended.
      transcriptReceive: assign((context, event) => {
        console.log(context, event);
        return {
          pendingSegments: context.pendingSegments - 1,
          transcript: context.transcript + event.transcript,
        };
      }),
      resetPendingSegments: assign(() => ({ pendingSegments: 0 })),
      incrementMessages: assign((context) => ({
        messages: context.messages + 1,
      })),
      resetTranscript: assign(() => ({ transcript: "" })),
    },
    guards: {
      // True only when every segment has been transcribed and there is text.
      canGenerate: (context) => {
        console.log(context);
        if (context.pendingSegments !== 0) {
          return false;
        }
        return context.transcript.length > 0;
      },
    },
  }
);
// Sidebar component. Props: `selected`, the `isTortoiseOn` / `isMicOn` flags
// with their setters, and the `onModelSelect` callback.
// NOTE(review): the returned markup is empty in this copy of the file, so how
// these props are rendered cannot be confirmed from here.
function Sidebar({
selected,
isTortoiseOn,
isMicOn,
setIsMicOn,
setIsTortoiseOn,
onModelSelect,
}) {
return (
);
}
// Icon component taking no props.
// NOTE(review): the returned markup is empty in this copy of the file.
function BotIcon() {
return (
);
}
// Icon component taking no props.
// NOTE(review): the returned markup is empty in this copy of the file.
function UserIcon() {
return (
);
}
// Icon component taking no props.
// NOTE(review): the returned markup is empty in this copy of the file.
function FaceIcon() {
return (
);
}
// Icon component taking no props.
// NOTE(review): the returned markup is empty in this copy of the file.
function MicOnIcon() {
return (
);
}
// Icon component taking no props.
// NOTE(review): the returned markup is empty in this copy of the file.
function MicOffIcon() {
return (
);
}
// Spinner component. The surviving fragment appends "bg-yellow-500" when
// `isUser` is truthy and "bg-primary" otherwise to a class string.
// NOTE(review): the surrounding markup is garbled in this copy of the file;
// the element that receives the class string is not visible here.
function TalkingSpinner({ isUser }) {
return (
span]:" + (isUser ? "bg-yellow-500" : "bg-primary")
}
>
{" "}
{" "}
);
}
// Spinner component that maps over a 12-element array.
// NOTE(review): the markup produced per element is missing from this copy of
// the file.
function LoadingSpinner() {
return (
{[...Array(12)].map((_, i) => (
))}
);
}
// Single chat message. Props: `text`, `isUser`, and `indicator` (compared
// against INDICATOR_TYPE.TALKING and INDICATOR_TYPE.GENERATING below).
// When `text` is falsy a placeholder is shown instead: a microphone prompt
// for the user, "Bot is typing..." for the bot.
// NOTE(review): the JSX elements are missing from this copy of the file; only
// the embedded expressions survive.
function ChatMessage({ text, isUser, indicator }) {
return (
{isUser ? : }
{indicator == INDICATOR_TYPE.TALKING && (
)}
{indicator == INDICATOR_TYPE.GENERATING &&
}
{text ||
(isUser
? "Speak into your microphone to talk to the bot..."
: "Bot is typing...")}
);
}
/**
 * Serial playback queue for bot speech.
 *
 * Each queued item is a tuple `[payload, idx, isTts]`:
 *   - `isTts === true`:  `payload` is a call id; audio is fetched from
 *     `/audio/<call id>` and played through the AudioContext.
 *   - `isTts === false`: `payload` is text spoken with the browser's
 *     speech synthesis.
 * `idx` keys the indicator reported through `onChange`
 * (GENERATING -> TALKING -> IDLE).
 *
 * Only one item plays at a time. Every failure path (server error, network
 * error, undecodable audio, synthesis error) marks the item IDLE and moves on
 * to the next item, so one bad item can never stall the queue.
 */
class PlayQueue {
  /**
   * @param {AudioContext} audioContext - context used to decode and play fetched audio.
   * @param {function(Object): void} onChange - called with a fresh snapshot of
   *   the `{ idx: indicator }` map after every indicator change.
   */
  constructor(audioContext, onChange) {
    this.call_ids = [];
    this.audioContext = audioContext;
    this._onChange = onChange;
    this._isProcessing = false;
    this._indicators = {};
  }

  /**
   * Enqueue an item and start playback if the queue is idle.
   * @param {Array} item - `[payload, idx, isTts]` tuple.
   */
  async add(item) {
    this.call_ids.push(item);
    // Not awaited on purpose: play() handles its own errors and never rejects
    // because of a fetch/decode/playback failure.
    this.play();
  }

  _updateState(idx, indicator) {
    this._indicators[idx] = indicator;
    // Hand out a copy: a React state setter ignores an update whose value is
    // the same object reference as the current state.
    this._onChange({ ...this._indicators });
  }

  _onEnd(idx) {
    this._updateState(idx, INDICATOR_TYPE.IDLE);
    this._isProcessing = false;
    this.play();
  }

  /**
   * Play the next queued item, if any and if nothing is currently playing.
   * Resolves once playback of that item has been started (or has failed).
   */
  async play() {
    if (this._isProcessing || this.call_ids.length === 0) {
      return;
    }
    this._isProcessing = true;
    const [payload, idx, isTts] = this.call_ids.shift();
    this._updateState(idx, INDICATOR_TYPE.GENERATING);

    // Completion must run exactly once per item, whichever path reports it.
    let finished = false;
    const finish = () => {
      if (finished) {
        return;
      }
      finished = true;
      this._onEnd(idx);
    };

    try {
      if (isTts) {
        await this._playRemoteAudio(payload, idx, finish);
      } else {
        this._speak(payload, idx, finish);
      }
    } catch (err) {
      console.error("Error occurred playing audio:", err);
      finish();
    }
  }

  // Speak `text` with the browser's speech synthesis.
  _speak(text, idx, finish) {
    const audio = new SpeechSynthesisUtterance(text);
    audio.onend = finish;
    // A failed utterance fires "error" instead of "end".
    audio.onerror = finish;
    this._updateState(idx, INDICATOR_TYPE.TALKING);
    window.speechSynthesis.speak(audio);
  }

  // Fetch, decode and play the audio produced by call `call_id`.
  async _playRemoteAudio(call_id, idx, finish) {
    console.log("Fetching audio for call", call_id, idx);
    const response = await this._fetchAudio(call_id);
    if (response === null) {
      finish();
      return;
    }
    const arrayBuffer = await response.arrayBuffer();
    const audioBuffer = await this.audioContext.decodeAudioData(arrayBuffer);
    const source = this.audioContext.createBufferSource();
    source.buffer = audioBuffer;
    source.connect(this.audioContext.destination);
    source.onended = finish;
    this._updateState(idx, INDICATOR_TYPE.TALKING);
    source.start();
  }

  // Poll `/audio/<call_id>` while the server answers 202. Returns the
  // successful response, or null when there is no audio (204) or the request
  // failed with a non-OK status.
  async _fetchAudio(call_id) {
    while (true) {
      const response = await fetch(`/audio/${call_id}`);
      if (response.status === 202) {
        continue;
      }
      if (response.status === 204) {
        console.error("No audio found for call: " + call_id);
        return null;
      }
      if (!response.ok) {
        // Give up instead of refetching forever on a failing endpoint.
        console.error("Error occurred fetching audio: " + response.status);
        return null;
      }
      return response;
    }
  }

  /**
   * Drop every queued item. For TTS items a DELETE request is sent for the
   * call id; the item that is currently playing is not affected.
   */
  clear() {
    for (const [call_id, _, isTts] of this.call_ids) {
      if (isTts) {
        fetch(`/audio/${call_id}`, { method: "DELETE" }).catch((err) => {
          console.error("Error occurred cancelling audio call: " + call_id, err);
        });
      }
    }
    this.call_ids = [];
  }
}
/**
 * POST a recorded audio buffer to `/transcribe` and return the parsed JSON
 * response body (the caller uses it as the transcript text).
 *
 * Failures are logged and reported as an empty transcript rather than as an
 * error payload or a rejection, so the caller can still deliver its
 * TRANSCRIPT_RECVD event for the segment.
 *
 * @param {ArrayBufferView|ArrayBuffer} buffer - raw audio samples, sent with
 *   content type "audio/float32".
 * @returns {Promise<*>} the parsed response body, or "" when the request,
 *   the response status, or the JSON parsing failed.
 */
async function fetchTranscript(buffer) {
  const blob = new Blob([buffer], { type: "audio/float32" });
  try {
    const response = await fetch("/transcribe", {
      method: "POST",
      body: blob,
      headers: { "Content-Type": "audio/float32" },
    });
    if (!response.ok) {
      console.error("Error occurred during transcription: " + response.status);
      // Do not parse an error body as if it were a transcript.
      return "";
    }
    return await response.json();
  } catch (err) {
    console.error("Error occurred during transcription:", err);
    return "";
  }
}
/**
 * POST a generation request to `/generate` and stream back its messages.
 *
 * The response body is a stream of JSON records separated by the ASCII record
 * separator (0x1E). Each record `{ type, value }` is yielded as
 * `{ type, payload }`.
 *
 * Records and multi-byte characters may be split across network chunks, so
 * the stream is decoded incrementally and only complete records are parsed.
 *
 * @param {boolean} noop - when true, send `{ noop: true, tts }` and yield
 *   nothing (the response body is not read).
 * @param {string} input - user input for the request.
 * @param {Array} history - conversation history for the request.
 * @param {boolean} isTortoiseOn - sent as the `tts` flag.
 * @yields {{type: string, payload: *}} one object per streamed record.
 *   Nothing is yielded when the server answers with a non-OK status.
 */
async function* fetchGeneration(noop, input, history, isTortoiseOn) {
  const body = noop
    ? { noop: true, tts: isTortoiseOn }
    : { input, history, tts: isTortoiseOn };
  const response = await fetch("/generate", {
    method: "POST",
    body: JSON.stringify(body),
    headers: { "Content-Type": "application/json" },
  });
  if (!response.ok) {
    console.error("Error occurred during submission: " + response.status);
    // An error body is not a record stream; do not try to parse it.
    return;
  }
  if (noop) {
    return;
  }

  const toMessage = (record) => {
    const { type, value: payload } = JSON.parse(record);
    return { type, payload };
  };

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffered = "";
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      // stream: true keeps an incomplete UTF-8 sequence for the next chunk.
      buffered += decoder.decode(value, { stream: true });
      const records = buffered.split("\x1e");
      // The last piece is either "" or an unfinished record; keep it for later.
      buffered = records.pop();
      for (const record of records) {
        if (record.length === 0) {
          continue;
        }
        yield toMessage(record);
      }
    }
    // Flush the decoder and emit a final record that had no trailing separator.
    buffered += decoder.decode();
    if (buffered.length > 0) {
      yield toMessage(buffered);
    }
  } finally {
    reader.releaseLock();
  }
}
// Root component: wires the chat state machine, the microphone recorder, the
// transcription / generation requests and the audio play queue together.
// NOTE(review): the JSX returned at the bottom is missing from this copy of
// the file; only the `history.map(...)` expression survives.
function App() {
// Finished messages, in order. Its length is compared against the machine's
// `messages` counter to detect when a new message starts.
const [history, setHistory] = useState([]);
// Complete text of the message currently in progress.
const [fullMessage, setFullMessage] = useState(INITIAL_MESSAGE);
// Prefix of fullMessage revealed so far by `tick`.
const [typedMessage, setTypedMessage] = useState("");
// NOTE(review): `model` is set by onModelSelect but not read anywhere in the
// visible code (it may be used by the missing JSX).
const [model, setModel] = useState(MODELS[0].id);
// Latest indicator map reported by the PlayQueue.
const [botIndicators, setBotIndicators] = useState({});
const [state, send, service] = useMachine(chatMachine);
const [isMicOn, setIsMicOn] = useState(true);
const [isTortoiseOn, setIsTortoiseOn] = useState(false);
// Both refs are populated asynchronously by onMount and are null until then.
const recorderNodeRef = useRef(null);
const playQueueRef = useRef(null);
// Log every machine transition and append each received transcript to the
// in-progress message (leading whitespace is trimmed when the message is empty).
useEffect(() => {
const subscription = service.subscribe((state, event) => {
console.log("Transitioned to state:", state.value, state.context);
if (event && event.type == "TRANSCRIPT_RECVD") {
setFullMessage(
(m) => m + (m ? event.transcript : event.transcript.trimStart())
);
}
});
// NOTE(review): the method is returned unbound; this relies on
// `unsubscribe` not using `this` — confirm against the XState version in use.
return subscription.unsubscribe;
}, [service]);
// Request a bot response and feed the streamed messages to the UI and the
// play queue. With `noop` true the request is sent but the recorder and the
// state machine are left untouched.
const generateResponse = useCallback(
async (noop, input = "") => {
if (!noop) {
// Recording is paused while the bot responds and resumed at the end.
recorderNodeRef.current.stop();
}
console.log("Generating response", input, history);
let firstAudioRecvd = false;
// The first history entry is not sent (presumably the initial greeting —
// verify).
for await (let { type, payload } of fetchGeneration(
noop,
input,
history.slice(1),
isTortoiseOn
)) {
if (type === "text") {
setFullMessage((m) => m + payload);
} else if (type === "audio") {
if (!firstAudioRecvd && CANCEL_OLD_AUDIO) {
playQueueRef.current.clear();
firstAudioRecvd = true;
}
// TTS item: payload is used by PlayQueue as the call id to fetch audio for.
playQueueRef.current.add([payload, history.length + 1, true]);
} else if (type === "sentence") {
// Non-TTS item: payload is spoken by PlayQueue via speech synthesis.
playQueueRef.current.add([payload, history.length + 1, false]);
}
}
// Without Tortoise, wait (polling every 100 ms) until the play queue has
// finished everything before handing the turn back.
if (!isTortoiseOn && playQueueRef.current) {
while (
playQueueRef.current.call_ids.length ||
playQueueRef.current._isProcessing
) {
await new Promise((r) => setTimeout(r, 100));
}
}
console.log("Finished generating response");
if (!noop) {
recorderNodeRef.current.start();
send("GENERATION_DONE");
}
},
[history, isTortoiseOn]
);
// When the machine's message counter runs ahead of the history, archive the
// in-progress message and start a fresh one; if the machine has just entered
// botGenerating, also kick off generation with the archived text as input.
useEffect(() => {
const transition = state.context.messages > history.length + 1;
if (transition && state.matches("botGenerating")) {
generateResponse(/* noop = */ false, fullMessage);
}
if (transition) {
setHistory((h) => [...h, fullMessage]);
setFullMessage("");
setTypedMessage("");
}
}, [state, history, fullMessage]);
// Called with each recorded audio buffer: announces the segment, requests
// its transcript, then delivers the transcript to the machine. An empty
// buffer still triggers the request but sends no events.
const onSegmentRecv = useCallback(
async (buffer) => {
if (buffer.length) {
send("SEGMENT_RECVD");
}
// TODO: these can get reordered
const data = await fetchTranscript(buffer);
if (buffer.length) {
send({ type: "TRANSCRIPT_RECVD", transcript: data });
}
},
[history]
);
// One-time setup: warm-up requests, microphone capture, recorder node and
// play queue.
// NOTE(review): this runs from an effect with an empty dependency list, so it
// captures the first render's onSegmentRecv / generateResponse closures.
async function onMount() {
// Warm up GPU functions.
onSegmentRecv(new Float32Array());
generateResponse(/* noop = */ true);
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
const context = new AudioContext();
const source = context.createMediaStreamSource(stream);
await context.audioWorklet.addModule("processor.js");
const recorderNode = new RecorderNode(
context,
onSegmentRecv,
() => send("SILENCE"),
() => send("SOUND")
);
recorderNodeRef.current = recorderNode;
source.connect(recorderNode);
recorderNode.connect(context.destination);
playQueueRef.current = new PlayQueue(context, setBotIndicators);
}
useEffect(() => {
onMount();
}, []);
// Typewriter step: reveal one more character of fullMessage and send
// TYPING_DONE when the last character is revealed. Does nothing until the
// recorder node exists.
const tick = useCallback(() => {
if (!recorderNodeRef.current) {
return;
}
if (typedMessage.length < fullMessage.length) {
const n = 1; // Math.round(Math.random() * 3) + 3;
setTypedMessage(fullMessage.substring(0, typedMessage.length + n));
if (typedMessage.length + n == fullMessage.length) {
send("TYPING_DONE");
}
}
}, [typedMessage, fullMessage]);
// Run `tick` every 20 ms; the interval is recreated whenever `tick` changes.
useEffect(() => {
const intervalId = setInterval(tick, 20);
return () => clearInterval(intervalId);
}, [tick]);
const onModelSelect = (id) => {
setModel(id);
};
// Start or stop the recorder when the mic toggle changes (once it exists).
useEffect(() => {
if (recorderNodeRef.current) {
console.log("Mic", isMicOn);
if (isMicOn) {
recorderNodeRef.current.start();
} else {
recorderNodeRef.current.stop();
}
}
}, [isMicOn]);
// When Tortoise is off, drop whatever is still queued for playback.
useEffect(() => {
if (playQueueRef.current && !isTortoiseOn) {
console.log("Canceling future audio calls");
playQueueRef.current.clear();
}
}, [isTortoiseOn]);
// An odd history length means the message in progress belongs to the user
// (the first archived message is the initial fullMessage, INITIAL_MESSAGE).
const isUserLast = history.length % 2 == 1;
let userIndicator = INDICATOR_TYPE.IDLE;
if (isUserLast) {
userIndicator = state.matches("userTalking")
? INDICATOR_TYPE.TALKING
: INDICATOR_TYPE.SILENT;
}
useEffect(() => {
console.log("Bot indicator changed", botIndicators);
}, [botIndicators]);
// NOTE(review): markup missing below — see the note at the top of App.
return (
{history.map((msg, i) => (
))}
);
}
// Mount the application into the element with id "react".
// NOTE(review): the element passed to render() is missing from this copy of
// the file.
const container = document.getElementById("react");
ReactDOM.createRoot(container).render();