For concepts, architecture, and usage guidance, see the Sprocket overview.

sprocket.Sprocket

Base class for inference workers.
| Method | Signature | Description |
| --- | --- | --- |
| setup | setup(self) -> None | Called once at startup. Load models and resources. |
| predict | predict(self, args: dict) -> dict | Called for each job. Process input and return output. |
| shutdown | shutdown(self) -> None | Called on graceful shutdown. Clean up resources. Optional. |
Class attributes:
| Attribute | Type | Default | Description |
| --- | --- | --- | --- |
| processor | Type[InputOutputProcessor] | InputOutputProcessor | Custom I/O processor class |
| warmup_inputs | list[dict] | [] | Inputs to run during cache warmup |
import sprocket


class MyModel(sprocket.Sprocket):
    def setup(self) -> None:
        self.model = load_model()

    def predict(self, args: dict) -> dict:
        result = self.model(args["input"])
        return {"output": result}

    def shutdown(self) -> None:
        self.model.cleanup()


if __name__ == "__main__":
    sprocket.run(MyModel(), "my-org/my-model")
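
The example above covers the lifecycle methods but not the class attributes. Here is a minimal sketch of setting warmup_inputs; the sample payload is hypothetical:

class WarmedModel(sprocket.Sprocket):
    # Each dict here is run through predict() during cache warmup
    warmup_inputs = [{"input": "warmup sample"}]

    def setup(self) -> None:
        self.model = load_model()

    def predict(self, args: dict) -> dict:
        return {"output": self.model(args["input"])}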

sprocket.run

Entry point for starting a Sprocket worker.
def run(sprocket: Sprocket, name: str, use_torchrun: bool = False) -> None:
| Parameter | Type | Description |
| --- | --- | --- |
| sprocket | Sprocket | Your Sprocket instance |
| name | str | Deployment name (used for queue routing) |
| use_torchrun | bool | Enable multi-GPU mode via torchrun. Default: False |
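For multi-GPU workers, a minimal sketch of the entry point, reusing the MyModel class from above; the torchrun command in the comment is illustrative:

if __name__ == "__main__":
    # Launch with torchrun, e.g.: torchrun --nproc-per-node=8 worker.py
    sprocket.run(MyModel(), "my-org/my-model", use_torchrun=True)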

sprocket.FileOutput

Wraps a local file path for automatic upload after predict() returns. Extends pathlib.PosixPath.
from sprocket import FileOutput


def predict(self, args):
    # `video` stands in for whatever object your model produced earlier
    video.save("output.mp4")
    return {"video": FileOutput("output.mp4"), "duration": 10.5}
The FileOutput is replaced with the public URL in the final job result.
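
Because FileOutput extends pathlib.PosixPath, ordinary path operations work on it before the upload happens; a quick illustration:

from sprocket import FileOutput

out = FileOutput("output.mp4")
print(out.suffix)           # ".mp4", path methods behave as usual
print(out.stat().st_size)   # inspect the local file before upload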

sprocket.emit_info

Report progress updates from inside predict(). Emitted data is available to clients via the info field on the job status endpoint.
from sprocket import emit_info

emit_info({"progress": 0.75, "current_frame": 45, "total_frames": 60})
| Parameter | Type | Description |
| --- | --- | --- |
| info | dict | Progress data to emit. Must serialize to under 4096 bytes of JSON. |
Updates are batched and merged (later values overwrite earlier ones for the same keys). When using use_torchrun=True, call emit_info() only from rank 0 to avoid duplicate updates.
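
A minimal sketch of the rank-0 guard for multi-GPU runs; it relies on the RANK environment variable that torchrun sets for each process:

import os

from sprocket import emit_info


def report_progress(done: int, total: int) -> None:
    # Only rank 0 emits, so replicas don't send duplicate updates
    if int(os.environ.get("RANK", "0")) == 0:
        emit_info({"progress": done / total})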

sprocket.InputOutputProcessor

Override for custom file download/upload behavior. Attach to your Sprocket via the processor class attribute.

Custom I/O Processing

| Method | Signature | Description |
| --- | --- | --- |
| process_input_file | process_input_file(self, resp: httpx.Response, dst: pathlib.Path) -> None | Called after downloading each input file. Write resp.content to dst. |
| finalize | async finalize(self, request_id: str, inputs: dict, outputs: dict) -> dict | Called after predict(), before FileOutput upload. Return modified outputs. |
Default behavior:
  • process_input_file: writes resp.content to dst
  • finalize: returns outputs unchanged
import gzip
import pathlib

import httpx
from sprocket import Sprocket, InputOutputProcessor


class CustomProcessor(InputOutputProcessor):
    def process_input_file(
        self, resp: httpx.Response, dst: pathlib.Path
    ) -> None:
        if dst.suffix == ".gz":
            # Write the decompressed copy to dst without its .gz suffix
            decompressed = gzip.decompress(resp.content)
            dst.with_suffix("").write_bytes(decompressed)
        else:
            dst.write_bytes(resp.content)

    async def finalize(
        self, request_id: str, inputs: dict, outputs: dict
    ) -> dict:
        # Example: upload to S3 instead of Together storage.
        # upload_to_s3 is a user-supplied helper, not part of sprocket.
        video_path = outputs.pop("video")
        url = await self.upload_to_s3(video_path, bucket="my-bucket")
        outputs["url"] = url
        return outputs


class MyModel(Sprocket):
    processor = CustomProcessor

    def setup(self):
        pass

    def predict(self, args):
        return {"result": "done"}

HTTP Endpoints

| Endpoint | Method | Response |
| --- | --- | --- |
| /health | GET | 200 {"status": "healthy"} or 503 {"status": "unhealthy"} |
| /metrics | GET | requests_inflight 0.0 or 1.0 (Prometheus format) |
| /generate | POST | Direct HTTP inference (non-queue mode) |
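
A quick way to probe a running worker, assuming it listens locally on the default port 8000:

import httpx

resp = httpx.get("http://localhost:8000/health")
print(resp.status_code, resp.json())  # 200 {"status": "healthy"} when ready

metrics = httpx.get("http://localhost:8000/metrics")
print(metrics.text)  # requests_inflight gauge in Prometheus text format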

CLI Arguments

| Argument | Default | Description |
| --- | --- | --- |
| --queue | false | Enable queue worker mode |
| --port | 8000 | HTTP server port |

Environment Variables

| Variable | Default | Description |
| --- | --- | --- |
| TOGETHER_API_KEY | Required | API key for queue authentication |
| TOGETHER_API_BASE_URL | https://api.together.ai | API base URL |
| TERMINATION_GRACE_PERIOD_SECONDS | 300 | Max time for graceful shutdown and prediction timeout |
| WORLD_SIZE | 1 | Number of GPU processes (set automatically by torchrun) |
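
A small sketch of reading these variables from worker code; the fallback defaults mirror the table above:

import os

grace_seconds = int(os.environ.get("TERMINATION_GRACE_PERIOD_SECONDS", "300"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

if world_size > 1:
    # Running under torchrun with multiple GPU processes
    pass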

Complete Examples

Image Classification

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
import sprocket


class ImageClassifier(sprocket.Sprocket):
    def setup(self) -> None:
        self.model = AutoModel.from_pretrained("model-name").to("cuda").eval()
        self.processor = AutoProcessor.from_pretrained("model-name")

    def predict(self, args: dict) -> dict:
        image = Image.open(args["image"])
        inputs = self.processor(images=image, return_tensors="pt").to("cuda")
        outputs = self.model(**inputs)
        return {"embeddings": outputs.last_hidden_state.mean(dim=1).tolist()}


if __name__ == "__main__":
    sprocket.run(ImageClassifier(), "my-org/classifier")

Video Generation with File Output

from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
import sprocket


class VideoGenerator(sprocket.Sprocket):
    def setup(self) -> None:
        self.pipe = DiffusionPipeline.from_pretrained("model-name").to("cuda")

    def predict(self, args: dict) -> dict:
        video_frames = self.pipe(args["prompt"], num_frames=16).frames[0]
        export_to_video(video_frames, "output.mp4", fps=8)
        return {"video": sprocket.FileOutput("output.mp4")}


if __name__ == "__main__":
    sprocket.run(VideoGenerator(), "my-org/video-gen")

Multi-Model Pipeline

import sprocket


class SpeechToSpeech(sprocket.Sprocket):
    def setup(self) -> None:
        # Placeholder loaders; substitute your actual ASR, chat, and TTS models
        self.asr = load_whisper_model()
        self.llm = load_chat_model()
        self.tts = load_tts_model()

    def predict(self, args: dict) -> dict:
        transcript = self.asr.transcribe(args["audio"])
        response = self.llm.chat(transcript)
        self.tts.synthesize(response).save("response.wav")
        return {"audio": sprocket.FileOutput("response.wav")}


if __name__ == "__main__":
    sprocket.run(SpeechToSpeech(), "my-org/speech-to-speech")