ComfyUI-CutomNodes/repeat_node.py

import os
import subprocess
import tempfile
from shutil import which

import numpy as np


def _extract_audio(audio):
    """
    Returns (samples_ch_first, sr)
    samples_ch_first: np.ndarray shape (channels, frames)
    """
    sr = None
    samples = None
    if isinstance(audio, dict):
        # sample rate key variants
        sr = audio.get("sample_rate", None)
        if sr is None:
            sr = audio.get("sr", None)
        if sr is None:
            sr = audio.get("rate", None)
        # avoid tensor truthiness checks (no `or` chaining)
        samples = audio.get("samples", None)
        if samples is None:
            samples = audio.get("waveform", None)
        if samples is None:
            samples = audio.get("audio", None)
    elif isinstance(audio, (tuple, list)) and len(audio) >= 2:
        samples, sr = audio[0], audio[1]
    else:
        raise TypeError(f"Unsupported AUDIO type: {type(audio)}")

    if sr is None or samples is None:
        raise ValueError(
            f"Could not extract samples/sample_rate from AUDIO: "
            f"keys={list(audio.keys()) if isinstance(audio, dict) else 'n/a'}"
        )

    # torch -> numpy
    try:
        import torch
        if isinstance(samples, torch.Tensor):
            samples = samples.detach().cpu().numpy()
    except Exception:
        pass

    samples = np.asarray(samples)

    # Normalize to (channels, frames). Possible shapes:
    #   (frames,)                 -> (1, frames)
    #   (channels, frames)        -> ok
    #   (frames, channels)        -> transpose
    #   (batch, channels, frames) -> take batch[0]
    #   (batch, frames, channels) -> take batch[0] then transpose
    if samples.ndim == 1:
        samples = samples[None, :]
    elif samples.ndim == 3:
        # take first batch
        samples = samples[0]
    elif samples.ndim > 3:
        # super defensive: reduce until <= 3
        while samples.ndim > 3:
            samples = samples[0]
        if samples.ndim == 3:
            samples = samples[0]

    if samples.ndim != 2:
        raise ValueError(f"Unsupported samples shape after normalization: {samples.shape}")

    # If it's (frames, channels) transpose -> (channels, frames)
    if samples.shape[0] > 8 and samples.shape[1] <= 8:
        samples = samples.T

    return samples, int(sr)
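
# Example with hypothetical values: ComfyUI loaders typically hand this node a dict
# shaped like {"waveform": torch.Tensor of shape (1, 2, 441000), "sample_rate": 44100};
# _extract_audio would return a (2, 441000) numpy array plus the integer 44100.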


def _write_wav_int16(path, samples_ch_first, sample_rate):
    import wave

    s = np.clip(samples_ch_first, -1.0, 1.0)
    s_i16 = (s * 32767.0).astype(np.int16)
    if s_i16.ndim != 2:
        raise ValueError(f"_write_wav_int16 expects 2D (ch, frames), got {s_i16.shape}")
    channels, frames = s_i16.shape
    interleaved = s_i16.T.reshape(-1)
    with wave.open(path, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)
        wf.setframerate(int(sample_rate))
        wf.writeframes(interleaved.tobytes())
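
# Minimal usage sketch (hypothetical path): write one second of stereo silence
# as 16-bit PCM at 44.1 kHz.
#   _write_wav_int16("/tmp/silence.wav", np.zeros((2, 44100), dtype=np.float32), 44100)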


def _load_audio_to_comfy(path, target_sr=44100, target_channels=2):
    """
    Decode audio file -> ComfyUI AUDIO dict.
    ComfyUI save nodes expect:
        audio["waveform"]    : torch.Tensor [batch, channels, frames]
        audio["sample_rate"] : int
    """
    ffmpeg = which("ffmpeg")
    if ffmpeg is None:
        raise RuntimeError("ffmpeg not found in PATH inside the ComfyUI container.")
    # Force known SR/ch so parsing is deterministic
    cmd = [
        ffmpeg,
        "-i", path,
        "-f", "f32le",
        "-ac", str(int(target_channels)),
        "-ar", str(int(target_sr)),
        "pipe:1",
    ]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        raise RuntimeError(
            "ffmpeg decode failed:\n" + proc.stderr.decode("utf-8", errors="ignore")
        )
    raw = np.frombuffer(proc.stdout, dtype=np.float32)
    ch = int(target_channels)
    if raw.size % ch != 0:
        raw = raw[: raw.size - (raw.size % ch)]
    # frombuffer arrays are read-only; copy so torch.from_numpy gets a writable,
    # contiguous (channels, frames) buffer
    samples = raw.reshape(-1, ch).T.copy()

    import torch
    waveform = torch.from_numpy(samples).unsqueeze(0)  # (1, channels, frames)
    return {"waveform": waveform, "sample_rate": int(target_sr)}


class AudioRepeatFromAudioNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
                "repeat_count": ("INT", {"default": 20, "min": 1, "max": 500}),
                "output_audio_path": ("STRING", {"default": "/basedir/output/repeated.mp3"}),
            },
            "optional": {
                "overwrite": ("BOOLEAN", {"default": True}),
                "mp3_quality": ("INT", {"default": 0, "min": 0, "max": 9}),
                "crossfade_seconds": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 5.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("AUDIO", "STRING")
    RETURN_NAMES = ("audio_out", "output_audio_path")
    FUNCTION = "repeat_audio"
    CATEGORY = "audio"

    def repeat_audio(self, audio, repeat_count, output_audio_path, overwrite=True, mp3_quality=0, crossfade_seconds=0.15):
        ffmpeg = which("ffmpeg")
        if ffmpeg is None:
            raise RuntimeError("ffmpeg not found in PATH inside the ComfyUI container.")

        repeat_count = int(repeat_count)
        if repeat_count < 1:
            raise ValueError("repeat_count must be >= 1")
        crossfade_seconds = float(crossfade_seconds or 0.0)

        out_dir = os.path.dirname(output_audio_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        if os.path.exists(output_audio_path) and not overwrite:
            audio_out = _load_audio_to_comfy(output_audio_path, target_sr=44100, target_channels=2)
            return (audio_out, output_audio_path)

        samples, sr = _extract_audio(audio)
        frames = int(samples.shape[1])
        duration_sec = frames / float(sr)

        # Safety: crossfade must be shorter than clip
        if crossfade_seconds >= duration_sec:
            crossfade_seconds = max(0.0, duration_sec * 0.25)

        with tempfile.TemporaryDirectory() as td:
            in_wav = os.path.join(td, "input.wav")
            _write_wav_int16(in_wav, samples, sr)
            ext = os.path.splitext(output_audio_path)[1].lower()

            # No crossfade (or only 1 repeat): loop the input with -stream_loop and
            # transcode, so repeat_count is still honored without an acrossfade chain
            if repeat_count == 1 or crossfade_seconds <= 0.0:
                cmd = [ffmpeg, "-y" if overwrite else "-n"]
                if repeat_count > 1:
                    cmd += ["-stream_loop", str(repeat_count - 1)]
                cmd += ["-i", in_wav]
                if ext == ".mp3":
                    cmd += ["-c:a", "libmp3lame", "-q:a", str(int(mp3_quality))]
                elif ext == ".wav":
                    cmd += ["-c:a", "pcm_s16le"]
                else:
                    cmd += ["-c:a", "libmp3lame", "-q:a", str(int(mp3_quality))]
                cmd += [output_audio_path]
            else:
                # Chain acrossfade between repeated inputs; each join overlaps by
                # crossfade_seconds, so the result is roughly
                # repeat_count * clip_duration - (repeat_count - 1) * crossfade_seconds
                cmd = [ffmpeg, "-y" if overwrite else "-n"]
                for _ in range(repeat_count):
                    cmd += ["-i", in_wav]
                xf = crossfade_seconds
                parts = [f"[0:a][1:a]acrossfade=d={xf}:c1=tri:c2=tri[a1]"]
                for i in range(2, repeat_count):
                    parts.append(f"[a{i-1}][{i}:a]acrossfade=d={xf}:c1=tri:c2=tri[a{i}]")
                aout = f"a{repeat_count-1}"
                filter_complex = ";".join(parts)
                cmd += ["-filter_complex", filter_complex, "-map", f"[{aout}]"]
                if ext == ".mp3":
                    cmd += ["-c:a", "libmp3lame", "-q:a", str(int(mp3_quality))]
                elif ext == ".wav":
                    cmd += ["-c:a", "pcm_s16le"]
                else:
                    cmd += ["-c:a", "libmp3lame", "-q:a", str(int(mp3_quality))]
                cmd += [output_audio_path]

            proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            if proc.returncode != 0:
                raise RuntimeError(
                    "ffmpeg processing failed.\n"
                    f"Command: {' '.join(cmd)}\n\n"
                    f"STDERR:\n{proc.stderr}"
                )

        # Return a Comfy-compatible AUDIO dict
        audio_out = _load_audio_to_comfy(output_audio_path, target_sr=sr, target_channels=int(samples.shape[0]))
        return (audio_out, output_audio_path)
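

# Registration sketch, assuming this module is loaded as a standard ComfyUI custom-node
# package and does not already register the class elsewhere; the display name below is
# an illustrative choice, not taken from the original source.
NODE_CLASS_MAPPINGS = {
    "AudioRepeatFromAudioNode": AudioRepeatFromAudioNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "AudioRepeatFromAudioNode": "Audio Repeat (from AUDIO)",
}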