diff --git a/plugins/stashAI/stashai.js b/plugins/stashAI/stashai.js index 759af2df..a018784b 100644 --- a/plugins/stashAI/stashai.js +++ b/plugins/stashAI/stashai.js @@ -1,7 +1,8 @@ (function () { "use strict"; - let STASHMARKER_API_URL = "https://cc1234-stashtag.hf.space/api/predict"; + let STASHMARKER_API_URL = "https://cc1234-stashtag-onnx.hf.space/gradio_api/call/predict_tags"; + let STASHMARKER_API_MARKER = "https://cc1234-stashtag-onnx.hf.space/gradio_api/call/predict_markers"; var OPTIONS = [ "Anal", @@ -2543,6 +2544,80 @@ }); } + async function gradioCall(url, image, vtt, threshold, retries = 3) { + for (let attempt = 0; attempt < retries; attempt++) { + try { + return await _gradioCall(url, image, vtt, threshold); + } catch (err) { + if (attempt === retries - 1) throw err; + await new Promise((r) => setTimeout(r, 3000)); + } + } + } + + async function _gradioCall(url, image, vtt, threshold) { + const body = { + data: [ + { url: image, meta: { _type: "gradio.FileData" } }, + vtt, + threshold, + ], + }; + + const response = await fetch(url, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + + if (!response.ok) { + throw new Error("HTTP " + response.status); + } + + const { event_id } = await response.json(); + const sseUrl = url + "/" + event_id; + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 120000); + + let text; + try { + const resp = await fetch(sseUrl, { signal: controller.signal }); + if (!resp.ok) { + throw new Error("HTTP " + resp.status); + } + text = await resp.text(); + } finally { + clearTimeout(timeout); + } + + let currentEvent = ""; + let currentData = ""; + + for (const line of text.split("\n")) { + if (line.startsWith("event: ")) { + currentEvent = line.slice(7).trim(); + } else if (line.startsWith("data: ")) { + currentData = line.slice(6); + } else if (line === "") { + if (currentEvent === "complete") { + try { + return JSON.parse(currentData); + } catch (e) { + throw new Error("Failed to parse result"); + } + } + if (currentEvent === "error") { + throw new Error(currentData || "API error"); + } + currentEvent = ""; + currentData = ""; + } + } + + throw new Error("No result received"); + } + function instance$3($$self, $$props, $$invalidate) { let { $$slots: slots = {}, $$scope } = $$props; validate_slots("MarkerButton", slots, []); @@ -2569,51 +2644,23 @@ let vtt = await download(vtt_url); - // query the api with a threshold of 0.4 as we want to do the filtering ourselves - var data = { data: [image, vtt, 0.4] }; - - fetch(STASHMARKER_API_URL + "_1", { - method: "POST", - headers: { - "Content-Type": "application/json; charset=utf-8", - }, - body: JSON.stringify(data), - }) - .then((response) => { - if (response.status !== 200) { - $$invalidate(0, (scanner = false)); - alert( - "Something went wrong. It's likely a server issue, Please try again later." - ); - return; - } - - return response.json(); - }) - .then((data) => { - $$invalidate(0, (scanner = false)); - let frames = data.data[0]; - $$invalidate(0, (scanner = false)); - - if (frames.length === 0) { - alert("No tags found"); - return; - } + try { + let result = await gradioCall(STASHMARKER_API_MARKER, image, vtt, 0.4); + let frames = result[0]; - // find a div with class row - let row = document.querySelector(".row"); + $$invalidate(0, (scanner = false)); - new MarkerMatches({ target: row, props: { frames, url } }); - }) - .catch((error) => { - $$invalidate(0, (scanner = false)); + if (!frames || frames.length === 0) { + alert("No tags found"); + return; + } - if (error.message === "") { - alert("Error: Service may be down. please try again later."); - } else { - alert("Error: " + error.message); - } - }); + let row = document.querySelector(".row"); + new MarkerMatches({ target: row, props: { frames, url } }); + } catch (error) { + $$invalidate(0, (scanner = false)); + alert("Error: " + (error.message || "Service may be down. Please try again later.")); + } } const writable_props = []; @@ -4027,52 +4074,28 @@ reader.readAsDataURL(vblob); }); - // query the api with a threshold of 0.2 as we want to do the filtering ourselves - var data = { data: [image, vtt, 0.2] }; - - fetch(STASHMARKER_API_URL, { - method: "POST", - headers: { - "Content-Type": "application/json; charset=utf-8", - }, - body: JSON.stringify(data), - }) - .then((response) => { - if (response.status !== 200) { - $$invalidate(0, (scanner = false)); - alert( - "Something went wrong. It's likely a server issue, Please try again later." - ); - return; - } + try { + let result = await gradioCall(STASHMARKER_API_URL, image, vtt, 0.2); + let tags = {}; + result.forEach((item) => Object.assign(tags, item)); - return response.json(); - }) - .then((data) => { - $$invalidate(0, (scanner = false)); + $$invalidate(0, (scanner = false)); - if (data.data[0].length === 0) { - alert("No tags found"); - return; - } + if (Object.keys(tags).length === 0) { + alert("No tags found"); + return; + } - // grab stash-tag-threshold from local storage or set to default - let threshold = localStorage.getItem("stash-tag-threshold") || 0.4; + let threshold = localStorage.getItem("stash-tag-threshold") || 0.4; - new TagMatches({ - target: document.body, - props: { matches: data.data[0], url, threshold }, - }); - }) - .catch((error) => { - $$invalidate(0, (scanner = false)); - - if (error.message === "") { - alert("Error: Service may be down. please try again later."); - } else { - alert("Error: " + error.message); - } + new TagMatches({ + target: document.body, + props: { matches: tags, url, threshold }, }); + } catch (error) { + $$invalidate(0, (scanner = false)); + alert("Error: " + (error.message || "Service may be down. Please try again later.")); + } } const writable_props = []; diff --git a/plugins/stashAI/stashai.yml b/plugins/stashAI/stashai.yml index 15975dd6..7111176b 100644 --- a/plugins/stashAI/stashai.yml +++ b/plugins/stashAI/stashai.yml @@ -12,4 +12,4 @@ ui: - stashai.css csp: connect-src: - - "https://cc1234-stashtag.hf.space" + - "https://cc1234-stashtag-onnx.hf.space"