19 changes: 0 additions & 19 deletions main.go
@@ -11,13 +11,11 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/mlx"
 	"github.com/docker/model-runner/pkg/inference/backends/vllm"
 	"github.com/docker/model-runner/pkg/inference/config"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -65,15 +63,6 @@ func main() {
 		llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
 	}
 
-	gpuInfo := gpuinfo.New(llamaServerPath)
-
-	sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
-	if err != nil {
-		log.Fatalf("unable to initialize system memory info: %v", err)
-	}
-
-	memEstimator := memory.NewEstimator(sysMemInfo)
-
 	// Create a proxy-aware HTTP transport
 	// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
 	var baseTransport *http.Transport
@@ -93,7 +82,6 @@ func main() {
 		log,
 		clientConfig,
 		nil,
-		memEstimator,
 	)
 	modelManager := models.NewManager(log.WithFields(logrus.Fields{"component": "model-manager"}), clientConfig)
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
@@ -118,12 +106,6 @@ func main() {
 		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
 	}
 
-	if os.Getenv("MODEL_RUNNER_RUNTIME_MEMORY_CHECK") == "1" {
-		memory.SetRuntimeMemoryCheck(true)
-	}
-
-	memEstimator.SetDefaultBackend(llamaCppBackend)
-
 	vllmBackend, err := vllm.New(
 		log,
 		modelManager,
@@ -160,7 +142,6 @@ func main() {
 			"",
 			false,
 		),
-		sysMemInfo,
 	)
 
 	// Create the HTTP handler for the scheduler
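Taken together, the main.go deletions remove one initialization chain: probe GPU info, derive system memory info, build the estimator, and (once the llama.cpp backend exists) bind it as the estimator's default backend. Reassembled from the hunks above for readability (a sketch of the removed wiring, not a compilable excerpt of the file):

	gpuInfo := gpuinfo.New(llamaServerPath)

	sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
	if err != nil {
		log.Fatalf("unable to initialize system memory info: %v", err)
	}

	memEstimator := memory.NewEstimator(sysMemInfo)

	// Later in main(), after the llama.cpp backend is constructed:
	if os.Getenv("MODEL_RUNNER_RUNTIME_MEMORY_CHECK") == "1" {
		memory.SetRuntimeMemoryCheck(true)
	}
	memEstimator.SetDefaultBackend(llamaCppBackend)

With this wiring gone, the MODEL_RUNNER_RUNTIME_MEMORY_CHECK environment variable no longer has any effect.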
53 changes: 0 additions & 53 deletions pkg/inference/memory/estimator.go

This file was deleted.

18 changes: 0 additions & 18 deletions pkg/inference/memory/settings.go

This file was deleted.

64 changes: 0 additions & 64 deletions pkg/inference/memory/system.go

This file was deleted.
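The contents of the three deleted files are collapsed in this view, but the shape of the central abstraction can be recovered from the test mock removed in handler_test.go below. estimator.go evidently defined a memory.MemoryEstimator interface along these lines (a reconstruction from the mock's method set and the handler's proceed/required/total usage; parameter names and comments are assumptions, not the deleted source):

	// MemoryEstimator, as implied by mockMemoryEstimator below.
	type MemoryEstimator interface {
		// SetDefaultBackend sets the backend used when no explicit backend is given.
		SetDefaultBackend(backend MemoryEstimatorBackend)
		// GetRequiredMemoryForModel estimates the RAM/VRAM a model needs at runtime.
		GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (inference.RequiredMemory, error)
		// HaveSufficientMemoryForModel reports whether the model fits, together
		// with the required and the total available memory.
		HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error)
	}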

22 changes: 3 additions & 19 deletions pkg/inference/models/handler_test.go
@@ -17,23 +17,10 @@ import (
 	"github.com/docker/model-runner/pkg/distribution/builder"
 	reg "github.com/docker/model-runner/pkg/distribution/registry"
 	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/inference/memory"
 
 	"github.com/sirupsen/logrus"
 )
 
-type mockMemoryEstimator struct{}
-
-func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
-
-func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) {
-	return inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
-}
-
-func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, inference.RequiredMemory, inference.RequiredMemory, error) {
-	return true, inference.RequiredMemory{}, inference.RequiredMemory{}, nil
-}
-
 // getProjectRoot returns the absolute path to the project root directory
 func getProjectRoot(t *testing.T) string {
 	// Start from the current test file's directory
@@ -123,11 +110,10 @@ func TestPullModel(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			log := logrus.NewEntry(logrus.StandardLogger())
-			memEstimator := &mockMemoryEstimator{}
 			handler := NewHTTPHandler(log, ClientConfig{
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
-			}, nil, memEstimator)
+			}, nil)
 
 			r := httptest.NewRequest(http.MethodPost, "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
 			if tt.acceptHeader != "" {
Expand Down Expand Up @@ -234,13 +220,12 @@ func TestHandleGetModel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
log := logrus.NewEntry(logrus.StandardLogger())
memEstimator := &mockMemoryEstimator{}
handler := NewHTTPHandler(log, ClientConfig{
StoreRootPath: tempDir,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
Transport: http.DefaultTransport,
UserAgent: "test-agent",
}, nil, memEstimator)
}, nil)

// First pull the model if we're testing local access
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
Expand Down Expand Up @@ -315,11 +300,10 @@ func TestCors(t *testing.T) {
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
t.Parallel()
memEstimator := &mockMemoryEstimator{}
discard := logrus.New()
discard.SetOutput(io.Discard)
log := logrus.NewEntry(discard)
m := NewHTTPHandler(log, ClientConfig{}, []string{"*"}, memEstimator)
m := NewHTTPHandler(log, ClientConfig{}, []string{"*"})
req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
req.Header.Set("Origin", "docker.com")
w := httptest.NewRecorder()
Expand Down
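Each test-side construction changes the same way: drop the trailing estimator argument. Any external caller of NewHTTPHandler would need the equivalent one-line edit (hypothetical out-of-repo caller, shown for illustration):

	// Before: handler := models.NewHTTPHandler(log, cfg, allowedOrigins, memEstimator)
	handler := models.NewHTTPHandler(log, cfg, allowedOrigins)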
30 changes: 5 additions & 25 deletions pkg/inference/models/http_handler.go
@@ -15,7 +15,6 @@ import (
 	"github.com/docker/model-runner/pkg/distribution/distribution"
 	"github.com/docker/model-runner/pkg/distribution/registry"
 	"github.com/docker/model-runner/pkg/inference"
-	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/internal/utils"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/middleware"
@@ -38,8 +37,6 @@ type HTTPHandler struct {
 	httpHandler http.Handler
 	// lock is used to synchronize access to the models manager's router.
 	lock sync.RWMutex
-	// memoryEstimator is used to calculate runtime memory requirements for models.
-	memoryEstimator memory.MemoryEstimator
 	// manager handles business logic for model operations.
 	manager *Manager
 }
@@ -56,13 +53,12 @@ type ClientConfig struct {
 }
 
 // NewHTTPHandler creates a new model's handler.
-func NewHTTPHandler(log logging.Logger, c ClientConfig, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *HTTPHandler {
+func NewHTTPHandler(log logging.Logger, c ClientConfig, allowedOrigins []string) *HTTPHandler {
 	// Create the manager.
 	m := &HTTPHandler{
-		log:             log,
-		router:          http.NewServeMux(),
-		memoryEstimator: memoryEstimator,
-		manager:         NewManager(log.WithFields(logrus.Fields{"component": "service"}), c),
+		log:     log,
+		router:  http.NewServeMux(),
+		manager: NewManager(log.WithFields(logrus.Fields{"component": "service"}), c),
 	}
 
 	// Register routes.
@@ -163,23 +159,7 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request)
 	// Normalize the model name to add defaults
 	request.From = NormalizeModelName(request.From)
 
-	// Pull the model. In the future, we may support additional operations here
-	// besides pulling (such as model building).
-	if memory.RuntimeMemoryCheckEnabled() && !request.IgnoreRuntimeMemoryCheck {
-		h.log.Infof("Will estimate memory required for %q", request.From)
-		proceed, req, totalMem, err := h.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
-		if err != nil {
-			h.log.Warnf("Failed to validate sufficient system memory for model %q: %s", request.From, err)
-			// Prefer staying functional in case of unexpected estimation errors.
-			proceed = true
-		}
-		if !proceed {
-			errstr := fmt.Sprintf("Runtime memory requirement for model %q exceeds total system memory: required %d RAM %d VRAM, system %d RAM %d VRAM", request.From, req.RAM, req.VRAM, totalMem.RAM, totalMem.VRAM)
-			h.log.Warnf(errstr)
-			http.Error(w, errstr, http.StatusInsufficientStorage)
-			return
-		}
-	}
+	// Pull the model
 	if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			h.log.Infof("Request canceled/timed out while pulling model %q", request.From)
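One externally visible effect of this last hunk: POST /models/create no longer returns 507 Insufficient Storage, since the pre-pull memory gate is gone, and the request's IgnoreRuntimeMemoryCheck flag (if the field remains on the request type) is no longer consulted. A client branch like the following becomes dead code (hypothetical caller; modelRunnerURL and body are placeholders):

	resp, err := http.Post(modelRunnerURL+"/models/create", "application/json", body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusInsufficientStorage {
		// Previously: estimated RAM/VRAM for the model exceeded total system
		// memory. After this change the server never sends this status here.
	}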