
Fix inference precision. #802

Open
egeonur wants to merge 4 commits into PriorLabs:main from egeonur:ege/fix-fit-mode

Conversation

@egeonur

@egeonur egeonur commented Mar 1, 2026

Issue

Fixes #631.
This PR attempts to fix the precision issue. The first fix comes from #784, which skips adding thinking tokens so that single_eval_pos stays zero and the KV cache can be used during prediction. On top of that, I cast all tensors to the requested dtype, after which the results became:

no_cache vs repeat: 0.0
no_cache vs fit_preprocessors: 0.0
no_cache vs fit_with_cache: 0.0

The only caveat is that when I ran the script with float32 there were still some small inconsistencies:

no_cache vs repeat: 5.3390077e-06
no_cache vs fit_preprocessors: 5.3390077e-06
no_cache vs fit_with_cache: 5.5486857e-06

but I am guessing these are float32 rounding effects (float64 vs float32 precision).
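
For scale: a single rounding step at float32 precision already costs on the order of 1e-8 relative error, so accumulated diffs of a few 1e-6 after many float32 operations are plausible. A pure-stdlib sketch comparing the two machine epsilons (illustrative only, not TabPFN code):

```python
import struct

def to_float32(x: float) -> float:
    # Round-trip a Python float (IEEE-754 binary64) through binary32.
    return struct.unpack("f", struct.pack("f", x))[0]

eps32 = 2.0 ** -23  # float32 machine epsilon, ~1.19e-07
eps64 = 2.0 ** -52  # float64 machine epsilon, ~2.22e-16

x = 1.0 / 3.0
rel_err = abs(to_float32(x) - x) / x
# One cast to float32 already loses ~8 decimal digits relative to float64.
print(rel_err < eps32, rel_err > eps64)  # True True
```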

Also, tests/test_consistency.py fails locally, but I assume those snapshots are stored for future comparison in case something deviates.

Motivation and Context

Script to reproduce the above results on a local machine:

import random
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import numpy as np
import torch

from tabpfn import TabPFNRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)

def _set_seeds() -> None:
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

_set_seeds()
reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_no_cache = reg.predict(X_test)

reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_no_cache_repeat = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_preprocessors", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_cache_preproc = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_with_cache", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_kv_cache = reg.predict(X_test)

def _max_diff(a: np.ndarray, b: np.ndarray) -> float:
    return np.max(np.abs(a - b) / np.abs(a))

print("max relative diffs")
print("no_cache vs no_cache_repeat:", _max_diff(preds_no_cache, preds_no_cache_repeat))
print("no_cache vs cache_preproc:", _max_diff(preds_no_cache, preds_cache_preproc))
print("no_cache vs kv_cache:", _max_diff(preds_no_cache, preds_kv_cache))

Public API Changes

  • [x] No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?

Tested locally on a MacBook CPU only (no GPU).
Collecting system and dependency information...
PyTorch version: 2.10.0
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.3 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 3.31.1
Libc version: N/A

Python version: 3.11.9 (main, Nov 22 2024, 14:33:40) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-15.7.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Dependency Versions:

tabpfn: 6.4.1
torch: 2.10.0
numpy: 2.4.2
scipy: 1.17.1
pandas: 2.3.3
scikit-learn: 1.8.0
typing_extensions: 4.15.0
einops: 0.8.2
huggingface-hub: 1.5.0

Checklist

  • [x] The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • A changelog entry has been added (see changelog/README.md), or "no changelog needed" label requested.
  • [x] The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

Copilot AI review requested due to automatic review settings March 1, 2026 21:44
@egeonur egeonur requested a review from a team as a code owner March 1, 2026 21:44
@egeonur egeonur requested review from klemens-floege and removed request for a team March 1, 2026 21:44

@CLAassistant

CLAassistant commented Mar 1, 2026

CLA assistant check
All committers have signed the CLA.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses an inference precision issue by ensuring that the specified inference_precision is respected throughout the InferenceEngineCacheKV. The changes correctly cast tensors to the desired data type, which resolves the inconsistencies noted. Additionally, the modification to conditionally add "thinking tokens" only when not using a KV cache is a logical improvement for consistency. The code is well-structured, and I have a couple of minor suggestions to enhance conciseness.

Comment thread src/tabpfn/inference.py
Comment on lines +777 to +781
inference_dtype = (
force_inference_dtype
if force_inference_dtype is not None
else torch.float32
)
Contributor


medium

This block for determining inference_dtype can be made more concise. Since torch.dtype objects are not falsy, you can use the or operator to simplify this assignment.

            inference_dtype = force_inference_dtype or torch.float32
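
For what it's worth, the `or` idiom is safe for dtypes precisely because torch.dtype objects are always truthy, unlike integer seeds where 0 is falsy. A torch-free sketch of the pattern (the Dtype class below is a made-up stand-in, not torch's):

```python
class Dtype:
    """Stand-in for torch.dtype: a plain object with no __bool__/__len__,
    so every instance is truthy."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name

float32, float64 = Dtype("float32"), Dtype("float64")

def pick_dtype(forced):
    # Safe here: a Dtype can never be falsy, so `or` only fires on None.
    return forced or float32

print(pick_dtype(float64), pick_dtype(None))  # float64 float32
```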

Contributor


would you make this change? This way it will match what we do in _prepare_model_inputs().

Comment thread src/tabpfn/inference.py
Comment on lines +835 to +839
inference_dtype = (
self.force_inference_dtype
if self.force_inference_dtype is not None
else torch.float32
)
Contributor


medium

This block for determining inference_dtype can be simplified for better readability and conciseness. Using the or operator is a more idiomatic way to provide a default value in this case.

            inference_dtype = self.force_inference_dtype or torch.float32

Contributor


same again.

Contributor

Copilot AI left a comment


Pull request overview

Fixes prediction inconsistencies across fit_modes by aligning inference-time dtype handling and preventing KV-cache inference from adding extra “thinking” tokens that would change context length / cache behavior.

Changes:

  • Make preprocessing reproducible across repeated predict() calls by overriding preprocessing random state in the on-demand inference engine.
  • Force model parameters and input tensors to the requested inference_precision dtype for KV-cache inference.
  • Skip adding thinking tokens during KV-cache prediction (single_eval_pos == 0) to keep cacheable context stable.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Changed files:

  • src/tabpfn/inference.py: Adjusts preprocessing seeding override and forces inference dtype casting for the KV-cache path.
  • src/tabpfn/architectures/base/transformer.py: Avoids adding thinking tokens during KV-cache prediction to preserve cache consistency.


Comment thread src/tabpfn/inference.py
Comment on lines 365 to 369
y_train=self.y_train,
feature_schema=self.feature_schema,
parallel_mode="in-order",
override_random_state=np.random.default_rng(self.static_seed),
override_random_state=self.static_seed,
)

Copilot AI Mar 1, 2026


override_random_state is now passed as an int (self.static_seed). In TabPFNEnsemblePreprocessor.fit_transform_ensemble_members_iterator the random_state is selected via override_random_state or self.random_state, which will ignore an override of 0 (since 0 is falsy) and fall back to self.random_state, reintroducing non-deterministic preprocessing across predict calls. Prefer either passing a truthy override (e.g., a np.random.Generator like before) or (better) changing the downstream selection to override_random_state if override_random_state is not None else self.random_state so that seed 0 is respected.
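
The falsy-zero pitfall described above can be demonstrated in isolation (the helper names below are invented for illustration, not actual TabPFN code):

```python
def pick_seed_or(override, default):
    # Buggy pattern: a seed of 0 is falsy, so it is silently discarded
    # and the default is used instead.
    return override or default

def pick_seed(override, default):
    # Correct pattern: fall back only when no override was given at all.
    return override if override is not None else default

print(pick_seed_or(0, 42))  # 42 -- the explicit seed 0 is ignored
print(pick_seed(0, 42))     # 0
```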

Comment on lines +526 to 530
is_kv_cache_prediction = (
self.cache_trainset_representation and single_eval_pos == 0
)
if self.add_thinking_tokens is not None and not is_kv_cache_prediction:
embedded_input, single_eval_pos = self.add_thinking_tokens(

Copilot AI Mar 1, 2026


This change alters when thinking tokens are added (they’re skipped for KV-cache prediction when single_eval_pos==0). There’s currently no test covering this specific behavior/contract (e.g., that fit_with_cache prediction path doesn’t append thinking tokens and stays consistent with other fit modes for a fixed seed). Please add/adjust a unit/integration test to lock this in—re-enabling the existing skipped “fit modes return equal results” tests (or adding a targeted regression test for #631) would help prevent regressions.

@egeonur
Author

egeonur commented Mar 10, 2026

@klemens-floege hey, I checked the failing tests, but they are not caused by anything I changed in my PR. After #757 this test can fail, but my local changes don't touch modality_detection.py or test_modality_detection.py. What would you suggest? The macOS test I can fix: it is the consistency test, so no big issue, and I already mentioned it in the PR description. But the other platform errors shouldn't depend on my changes; the test seems flaky, or at least environment-sensitive.

@klemens-floege
Contributor

@egeonur could you please try a simple rebase onto main? For the changelog test you need to add an .md file in the changelog folder. Thank you :)

@egeonur
Author

egeonur commented Mar 11, 2026

@klemens-floege I rebased, fixed the consistency test, and added the .md file. Can you re-run CI to see whether the tests pass? 🤞

@egeonur
Author

egeonur commented Mar 11, 2026

@klemens-floege I get the same errors again, but #815 had the same failing tests. My local Python version is 3.11, and I suspect this error happens because the CI paths use Python 3.14, so the behaviour of pd.to_numeric or s.isna() may have changed under Python 3.14.

Contributor

@klemens-floege klemens-floege left a comment


Hi, thanks again for contributing! I ran the reproduction script locally from your branch (on CPU) and the inconsistencies are still present:

  • no_cache vs cache_preproc: 0.013
  • no_cache vs kv_cache: 1.423

A few thoughts:

On the consistency tests: I'd prefer not to adjust the random dataset inside the estimators to fix the test failures. Those tests exist to verify that the internal model behaviour hasn't changed, so relaxing them isn't a great signal. If adjusting the random state inside the estimators is truly unavoidable to fix the inconsistency, that may be acceptable, but it should be a deliberate decision.

What I think would be valuable: Could you add a test to test_regressor.py and test_classifier.py that explicitly verifies the different fit_mode options produce consistent predictions? Something like:

def test_fit_mode_consistency(regressor_or_classifier):
      # assert predictions from low_memory, fit_preprocessors, fit_with_cache
      # are all within float tolerance of each other

This would both document the expected behaviour and catch regressions going forward.

@egeonur
Author

egeonur commented Mar 12, 2026

import random
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import numpy as np
import torch

from tabpfn import TabPFNRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)

def _set_seeds() -> None:
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

_set_seeds()
reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64, device="cpu")
reg.fit(X_train, y_train)
preds_no_cache = reg.predict(X_test)

reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64, device="cpu")
reg.fit(X_train, y_train)
preds_no_cache_repeat = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_preprocessors", inference_precision=torch.float64, device="cpu")
reg.fit(X_train, y_train)
preds_cache_preproc = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_with_cache", inference_precision=torch.float64, device="cpu")
reg.fit(X_train, y_train)
preds_kv_cache = reg.predict(X_test)

def _max_diff(a: np.ndarray, b: np.ndarray) -> float:
    return np.max(np.abs(a - b) / np.abs(a))

print("max relative diffs")
print("no_cache vs no_cache_repeat:", _max_diff(preds_no_cache, preds_no_cache_repeat))
print("no_cache vs cache_preproc:", _max_diff(preds_no_cache, preds_cache_preproc))
print("no_cache vs kv_cache:", _max_diff(preds_no_cache, preds_kv_cache))

When I run this with uv run on my local CPU I get:
no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc: 0.0
no_cache vs kv_cache: 0.0
@klemens-floege If I may ask, which script did you run? egeonur:ege/fix-fit-mode is the branch where I fixed the issue. If I run the script on main I get the same errors as you do, but that is before my fix. If it is another script, can I try it as well?

Also, the override_random_state change isn't there to make the consistency tests pass; it's there to fix the underlying inconsistency between fit modes. Without that change:

no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc: 0.013375835
no_cache vs kv_cache: 0.013375835

they were still inconsistent, so I changed it to fix the issue, not to pass the consistency test.

Comparison                  | main   | dtype fix only | dtype fix + random state fix
no_cache vs no_cache_repeat | 0.0%   | 0.0%           | 0.0%
no_cache vs cache_preproc   | 1.3%   | 1.3%           | 0.0%
no_cache vs kv_cache        | 142.3% | 1.3%           | 0.0%

@klemens-floege
Contributor

@egeonur I ran your script with device set to cpu:

max relative diffs
no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc:   0.013377033
no_cache vs kv_cache:        1.4230262
tabpfn: 6.4.1 
numpy: 2.3.3
pandas: 2.3.3
scikit-learn: 1.6.1
scipy: 1.16.2
torch: 2.10.0

It could be that the pandas version makes a difference; I will do more digging. I think a concrete good next step in this PR is to add a new test to test_consistency.py that specifically tests this behavior:

  @pytest.mark.parametrize("estimator_cls,data_fn", [
      (TabPFNRegressor, _get_tiny_regression_data),
      (TabPFNClassifier, _get_tiny_classification_data),
  ])
  def test__fit_modes__produce_consistent_predictions(estimator_cls, data_fn):
      """All fit_mode values should produce numerically equivalent predictions."""
      X_train, y_train, X_test = data_fn()
      fit_modes = ["low_memory", "fit_preprocessors", "fit_with_cache"]
      preds = {}
      for mode in fit_modes:
          model = estimator_cls(**DEFAULT_CONFIG, fit_mode=mode)
          model.fit(X_train, y_train)
          if isinstance(model, TabPFNClassifier):
              preds[mode] = model.predict_proba(X_test)
          else:
              preds[mode] = model.predict(X_test)

      reference = preds["low_memory"]
      for mode in ["fit_preprocessors", "fit_with_cache"]:
          np.testing.assert_allclose(
              preds[mode], reference, rtol=1e-3, atol=1e-3,
              err_msg=f"fit_mode='{mode}' predictions differ from 'low_memory'",
          )

@oscarkey
Contributor

oscarkey commented Apr 2, 2026

hey @egeonur, I'm just merging #852, which makes fit_preprocessors and fit_with_cache consistent. However, low_memory is still inconsistent; do you think this change might help with that?

Re the tests, we have test__fit_preprocessors_and_low_memory_produce_equal_results in both test_regressor_interface.py and test_classifier_interface.py. These are currently disabled, because the fit modes are not consistent. So if this PR works we should be able to re-enable them!

@egeonur
Author

egeonur commented Apr 2, 2026

Hey @oscarkey, I was about to write the same update.

I tested my fork again on a different MacBook with an M3 chip yesterday, and I still got zero diffs. That matches what I saw before.

What confuses me is that @klemens-floege said he still did not see improvements, and reported these max relative diffs:

no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc:   0.013377033
no_cache vs kv_cache:        1.4230262

So from his side it looks like the fix is not taking effect.

I am wondering whether this could be a branch/version mismatch somewhere, although I am not certain. I tested with the same library versions he mentioned:

  • tabpfn: 6.4.1
  • numpy: 2.3.3
  • pandas: 2.3.3
  • scikit-learn: 1.6.1
  • scipy: 1.16.2
  • torch: 2.10.0

On my side, with float64 precision, I again got zero diffs.

The key thing I found is this change in PR #802:
https://github.com/PriorLabs/TabPFN/pull/802/changes#diff-261bbedda6d582b2e6a903d83c193c83d082dd4d6b8513305b998ca732f6ce73L366

That part seems to solve the issue by overriding the random state setting to self.static_seed.

So I am not sure why Klemens still sees no improvement with my local version, especially because my solution was built on top of your solution. Even if it did not fully fix everything on his side, I would have expected at least some improvement. If you have time to check whether switching to this PR changes the results when you run it, I’d really appreciate it. Then I can pull your latest changes from main and try again with the seed improvement.

@oscarkey
Contributor

oscarkey commented Apr 2, 2026

hey @egeonur . I applied your static seed fix and ran test_regressor_interface.py:test__fit_preprocessors_and_low_memory_produce_equal_results and test_classifier_interface.py:test__fit_preprocessors_and_low_memory_produce_equal_results, and they passed. So I think you should be good to rebase the PR!

I think it should also be safe to update tests/test_consistency.py to stop specifying the fit_mode entirely and just test the default one, as we'll now have tests asserting that all the fit modes produce the same answers.

@egeonur egeonur force-pushed the ege/fix-fit-mode branch from cea20d8 to 080cf3b Compare April 5, 2026 21:27
@egeonur
Author

egeonur commented Apr 5, 2026

@oscarkey I removed the skip markers from the fit_preprocessors vs low_memory equivalence tests in tests/test_classifier_interface.py and tests/test_regressor_interface.py, and they now pass locally. I also regenerated the consistency snapshots, and tests/test_consistency.py passes locally.

On your earlier comment about simplifying tests/test_consistency.py: do you mean removing the fit-mode-specific cases in TEST_CASES and just testing
the default estimator behavior there, while relying on the interface tests for fit-mode equivalence? The blocks I think you mean are at:

  • tests/test_consistency.py:109
  • tests/test_consistency.py:124
  • tests/test_consistency.py:137
  • tests/test_consistency.py:152

If that’s what you had in mind, I’m happy to make that cleanup too. I just wanted to confirm exactly what you’d prefer to remove.

Other than that, I think this issue is fixed. All relevant tests passed locally. I'm still seeing some unrelated mypy errors in tests/test_browser_auth.py, but they don't seem related to this PR.

Contributor

@oscarkey oscarkey left a comment


do you mean removing the fit-mode-specific cases in TEST_CASES

yes exactly! To update the reference predictions, I would suggest renaming [prefix]_fit_preprocessors.json to just [prefix].json, and deleting the low_memory/fit_with_cache reference predictions. You can also revert all other reference prediction changes, as the predictions should not have changed.

Comment thread src/tabpfn/inference.py
Comment on lines +884 to +893
inference_dtype = (
self.force_inference_dtype
if self.force_inference_dtype is not None
else torch.float32
)
X_test = torch.as_tensor(
X_test,
dtype=inference_dtype,
device=self.device,
)
Contributor


Same again? This already happens below?

Author

@egeonur egeonur Apr 7, 2026


The lower cast only overlaps partially. My change makes the KV-cache path normalize inputs up front with inference_dtype = force_inference_dtype or torch.float32, even when X and y already arrive as tensors. Before that, this branch only cast non-tensor inputs to float32, so tensor inputs could retain an arbitrary dtype, and force_inference_dtype was not applied consistently during cache building. The later block only re-applies the forced dtype after GPU preprocessing, so it does not replace the up-front normalization or the default float32 path.
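
A rough numpy analogue of the up-front normalization described above (illustrative only; the real code path uses torch.as_tensor inside the inference engine, and the function name here is made up):

```python
import numpy as np

def normalize_inputs(X, y, force_inference_dtype=None):
    # Cast both inputs to one dtype up front, whether or not they already
    # arrive as arrays, so no input retains an arbitrary dtype.
    dtype = force_inference_dtype if force_inference_dtype is not None else np.float32
    return np.asarray(X, dtype=dtype), np.asarray(y, dtype=dtype)

# An input that already arrives as float16 no longer slips through:
X = np.zeros((4, 3), dtype=np.float16)
y = np.zeros(4, dtype=np.float16)
Xc, yc = normalize_inputs(X, y, force_inference_dtype=np.float64)
print(Xc.dtype, yc.dtype)  # float64 float64
```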

With the seed held fixed, that difference was observable in output consistency. Before this change, I saw:

max relative diffs
no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc: 0.0
no_cache vs kv_cache: 1.191185e-06

After normalizing the KV-cache path to use self.force_inference_dtype instead of defaulting to float32, the discrepancy went to zero:

max relative diffs
no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc: 0.0
no_cache vs kv_cache: 0.0

When I last checked, TabPFN 2.6 doesn't currently support the fit_with_cache fit_mode, so those numbers are from TabPFN 2.5. But if you say these casts are redundant I can revert them; without that change, the small difference will remain, imo.

Contributor


ahhh I see my mistake! I think there might still be some redundancy (see my other comments), but if we fix that I think we're ready to go :)

@egeonur
Author

egeonur commented Apr 12, 2026

@oscarkey I updated as you suggested. I simplified tests/test_consistency.py to snapshot only the default estimator behavior, removed the fit-mode-specific consistency cases, renamed the surviving *_fit_preprocessors.json references to the plain filenames, and deleted the obsolete low_memory / fit_with_cache reference predictions. I also rebased onto current main, preserved the deterministic seed fix in inference.py, and re-ran uv run pytest tests/test_consistency.py tests/test_inference.py; both pass. Also, for the other change, I replied to your PR review; can you look into that as well, so that maybe we can merge this PR? Otherwise I can revert that change and see if any cleanup is required. Thanks :)

Contributor

@oscarkey oscarkey left a comment


hey, thanks for making the changes, I think we're nearly there!

I think the changes in reference_predictions can be reverted? The only thing that should change is *_fit_preprocessors.json should be renamed to *.json, but the contents of the files shouldn't change.


Comment thread src/tabpfn/inference.py

if self.force_inference_dtype is not None:
model.type(self.force_inference_dtype)
X_test = X_test.type(self.force_inference_dtype)
Contributor


Same as above, hopefully we can delete this line.


Comment thread src/tabpfn/inference.py
Comment on lines 895 to 896
X = X.type(force_inference_dtype)
y = y.type(force_inference_dtype)
Contributor


Hopefully we can delete these lines as we do it above? I guess there's a risk _maybe_run_gpu_preprocessing() doesn't preserve the type. But, judging by the other inference engines, it looks like it does?
