Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ func main() {
log.Debugf("AMD GPU detection failed: %v", err)
}

// Check if we have supported MTHREADS GPUs and set MUSA variant accordingly
if hasAMD, err := gpuInfo.HasSupportedMTHREADSGPU(); err == nil && hasAMD {
log.Info("Supported MTHREADS GPU detected, MUSA will be used automatically")
// This will be handled by the llama.cpp backend during server download
} else if err != nil {
log.Debugf("MTHREADS GPU detection failed: %v", err)
}

// Create llama.cpp configuration from environment variables
llamaCppConfig := createLlamaCppConfigFromEnv()

Expand Down
4 changes: 4 additions & 0 deletions pkg/gpuinfo/gpuinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ func (g *GPUInfo) GetVRAMSize() (uint64, error) {
// HasSupportedAMDGPU forwards to the package-level hasSupportedAMDGPU and
// returns its result and error unchanged.
func (g *GPUInfo) HasSupportedAMDGPU() (bool, error) {
	supported, err := hasSupportedAMDGPU()
	return supported, err
}

// HasSupportedMTHREADSGPU forwards to the package-level
// hasSupportedMTHREADSGPU and returns its result and error unchanged.
func (g *GPUInfo) HasSupportedMTHREADSGPU() (bool, error) {
	supported, err := hasSupportedMTHREADSGPU()
	return supported, err
}
7 changes: 6 additions & 1 deletion pkg/gpuinfo/gpuinfo_not_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@ package gpuinfo
// HasSupportedAMDGPU is the stub used by this build: it performs no detection
// and always returns false with a nil error.
func (g *GPUInfo) HasSupportedAMDGPU() (bool, error) {
	// AMD GPU detection is only supported on Linux
	const supported = false
	return supported, nil
}
}

// HasSupportedMTHREADSGPU is the stub used by this build: it performs no
// detection and always returns false with a nil error.
func (g *GPUInfo) HasSupportedMTHREADSGPU() (bool, error) {
	// MTHREADS GPU detection is only supported on Linux
	const supported = false
	return supported, nil
}
6 changes: 6 additions & 0 deletions pkg/gpuinfo/memory_darwin_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ func hasSupportedAMDGPU() (bool, error) {
// AMD GPU detection is only supported on Linux
return false, nil
}

// hasSupportedMTHREADSGPU reports whether the system has a supported MTHREADS
// GPU. This stub performs no detection and always returns false with a nil
// error.
func hasSupportedMTHREADSGPU() (bool, error) {
	// MTHREADS GPU detection is only supported on Linux
	const supported = false
	return supported, nil
}
6 changes: 6 additions & 0 deletions pkg/gpuinfo/memory_darwin_nocgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ func hasSupportedAMDGPU() (bool, error) {
// AMD GPU detection is only supported on Linux
return false, nil
}

// hasSupportedMTHREADSGPU reports whether the system has a supported MTHREADS
// GPU. This stub performs no detection and always returns false with a nil
// error.
func hasSupportedMTHREADSGPU() (bool, error) {
	// MTHREADS GPU detection is only supported on Linux
	var detected bool
	return detected, nil
}
5 changes: 5 additions & 0 deletions pkg/gpuinfo/memory_linux_nocgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ func getVRAMSize(_ string) (uint64, error) {
// hasSupportedAMDGPU is the stub for builds without cgo: it performs no
// detection and always returns false with an "unimplemented without cgo"
// error.
func hasSupportedAMDGPU() (bool, error) {
	errUnimplemented := errors.New("unimplemented without cgo")
	return false, errUnimplemented
}

// hasSupportedMTHREADSGPU reports whether the system has a supported MTHREADS
// GPU. This stub for builds without cgo performs no detection and always
// returns false with an "unimplemented without cgo" error.
func hasSupportedMTHREADSGPU() (bool, error) {
	errUnimplemented := errors.New("unimplemented without cgo")
	return false, errUnimplemented
}
6 changes: 6 additions & 0 deletions pkg/gpuinfo/memory_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ func hasSupportedAMDGPU() (bool, error) {
// AMD GPU detection is only supported on Linux
return false, nil
}

// hasSupportedMTHREADSGPU reports whether the system has a supported MTHREADS
// GPU. This stub performs no detection and always returns false with a nil
// error.
func hasSupportedMTHREADSGPU() (supported bool, err error) {
	// MTHREADS GPU detection is only supported on Linux
	return false, nil
}
82 changes: 82 additions & 0 deletions pkg/gpuinfo/mthreads_gpu_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//go:build linux

package gpuinfo

import (
"bufio"
"bytes"
"errors"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)

// Patterns that pick the driver version and compute capability out of muInfo
// output. Compiled once at package scope instead of on every detection call.
var (
	mthreadsDriverVersionRe     = regexp.MustCompile(`Driver Version:[ \t]+([0-9.]+)`)
	mthreadsComputeCapabilityRe = regexp.MustCompile(`compute capability:[ \t]+([0-9.]+)`)
)

// Minimum versions accepted by hasSupportedMTHREADSGPU, as dotted components
// (driver 4.3, compute capability 2.1).
var (
	mthreadsMinDriverVersion     = []int{4, 3}
	mthreadsMinComputeCapability = []int{2, 1}
)

// hasSupportedMTHREADSGPU reports whether the system has a supported MTHREADS
// GPU.
//
// It returns (false, nil) when /dev contains no entry whose name starts with
// "mtgpu". Otherwise it runs the muInfo tool and requires a driver version of
// at least 4.3 and a compute capability of at least 2.1.
//
// An error is returned when /dev cannot be read, when muInfo fails to run, or
// when its output does not contain parseable version information.
func hasSupportedMTHREADSGPU() (bool, error) {
	// Check if /dev contains mtgpu.* devices
	devEntries, err := os.ReadDir("/dev")
	if err != nil {
		return false, err
	}

	foundMTGPU := false
	for _, entry := range devEntries {
		if strings.HasPrefix(entry.Name(), "mtgpu") {
			// MTGPU driver should be properly installed and loaded
			foundMTGPU = true
			break
		}
	}
	if !foundMTGPU {
		return false, nil // no MTHREADS GPU device found
	}

	// Run muInfo to collect GPU information; stdout and stderr share one
	// buffer so a failure message can include everything the tool printed.
	cmd := exec.Command("muInfo")
	var out bytes.Buffer
	cmd.Stdout = &out
	cmd.Stderr = &out
	if err := cmd.Run(); err != nil {
		return false, errors.New("failed to execute muInfo: " + err.Error() + "\n" + out.String())
	}

	return mthreadsMuInfoSupported(out.String())
}

// mthreadsMuInfoSupported parses muInfo output and reports whether the driver
// version and compute capability it lists meet the minimum supported versions.
// When a field appears more than once, the last occurrence is the one checked.
func mthreadsMuInfoSupported(output string) (bool, error) {
	var driverVerStr, computeCapStr string

	scanner := bufio.NewScanner(strings.NewReader(output))
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		if m := mthreadsDriverVersionRe.FindStringSubmatch(line); len(m) == 2 {
			driverVerStr = m[1]
		}
		if m := mthreadsComputeCapabilityRe.FindStringSubmatch(line); len(m) == 2 {
			computeCapStr = m[1]
		}
	}
	// A scanner stops silently on errors such as an over-long line; surface
	// that instead of acting on partially read output.
	if err := scanner.Err(); err != nil {
		return false, errors.New("failed to read muInfo output: " + err.Error())
	}

	if driverVerStr == "" || computeCapStr == "" {
		return false, errors.New("failed to parse muInfo output for driver version or compute capability")
	}

	driverVer, err := mthreadsParseVersion(driverVerStr)
	if err != nil {
		return false, err
	}
	computeCap, err := mthreadsParseVersion(computeCapStr)
	if err != nil {
		return false, err
	}

	// Check minimum supported versions
	supported := mthreadsVersionAtLeast(driverVer, mthreadsMinDriverVersion) &&
		mthreadsVersionAtLeast(computeCap, mthreadsMinComputeCapability)
	return supported, nil
}

// mthreadsParseVersion converts a dotted numeric version such as "4.3" or
// "4.3.0" into its integer components. Versions are not parsed as floats
// because that rejects three-part versions and misorders "4.10" below "4.3".
func mthreadsParseVersion(s string) ([]int, error) {
	// The capture pattern admits dots, so tolerate a trailing sentence period.
	fields := strings.Split(strings.Trim(s, "."), ".")
	parts := make([]int, 0, len(fields))
	for _, field := range fields {
		n, err := strconv.Atoi(field)
		if err != nil {
			return nil, errors.New("invalid version " + strconv.Quote(s))
		}
		parts = append(parts, n)
	}
	return parts, nil
}

// mthreadsVersionAtLeast reports whether version v is greater than or equal to
// floor, comparing component by component; missing components count as zero.
func mthreadsVersionAtLeast(v, floor []int) bool {
	for i := 0; i < len(v) || i < len(floor); i++ {
		var have, want int
		if i < len(v) {
			have = v[i]
		}
		if i < len(floor) {
			want = floor[i]
		}
		if have != want {
			return have > want
		}
	}
	return true
}
20 changes: 16 additions & 4 deletions pkg/inference/backends/llamacpp/download_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger,
llamaCppPath, vendoredServerStoragePath string,
) error {
var hasAMD bool
var hasMTHREADS bool
var err error

ShouldUseGPUVariantLock.Lock()
defer ShouldUseGPUVariantLock.Unlock()
if ShouldUseGPUVariant {
Expand All @@ -33,17 +34,28 @@ func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger,
if err != nil {
log.Debugf("AMD GPU detection failed: %v", err)
}

hasMTHREADS, err = gpuInfo.HasSupportedMTHREADSGPU()
if err != nil {
log.Debugf("MTHREADS GPU detection failed: %v", err)
}
}

desiredVersion := GetDesiredServerVersion()
desiredVariant := "cpu"

// Use ROCm if supported AMD GPU is detected
if hasAMD {
log.Info("Supported AMD GPU detected, using ROCm variant")
desiredVariant = "rocm"
}


// Use MUSA if supported MTHREADS GPU is detected
if hasMTHREADS {
log.Info("Supported MTHREADS GPU detected, using MUSA variant")
desiredVariant = "musa"
}

l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant)
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
Expand Down