-
Notifications
You must be signed in to change notification settings - Fork 7
Force weights_only=False in torch.load to fix PyTorch 2.6+ checkpoint loading #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…nly=False Co-authored-by: springfall2008 <48591903+springfall2008@users.noreply.github.com>
|
@copilot why are we forcing weights only to False, where is the reason behind this fix? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR addresses a compatibility issue with PyTorch 2.6+ where the default value of weights_only in torch.load() changed from False to True, causing UnpicklingError when PyTorch Lightning loads NeuralProphet checkpoints during learning rate finding.
Key changes:
- Strengthened the monkey-patch from setting a default to forcefully overriding the
weights_onlyparameter toFalsein alltorch.load()calls - Applied the same approach to both production code and test file
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| test_predai.py | Modified monkey-patch to unconditionally force weights_only=False instead of only setting it as a default, with added explanatory comments |
| predai/rootfs/predai.py | Modified monkey-patch to unconditionally force weights_only=False instead of only setting it as a default, with updated comments explaining it overrides PyTorch Lightning's explicit parameter |
Comments suppressed due to low confidence (9)
predai/rootfs/predai.py:1
- Import of 'Any' is not used.
from typing import Any
predai/rootfs/predai.py:3
- Import of 'np' is not used.
import numpy as np
predai/rootfs/predai.py:8
- Import of 'aiohttp' is not used.
import aiohttp
predai/rootfs/predai.py:11
- Import of 'json' is not used.
import json
predai/rootfs/predai.py:12
- Import of 'ssl' is not used.
import ssl
test_predai.py:2
- Import of 'MagicMock' is not used.
Import of 'patch' is not used.
from unittest.mock import AsyncMock, MagicMock, patch
test_predai.py:25
- Import of 'timestr_to_datetime' is not used.
from predai import HAInterface, Prophet, timestr_to_datetime
predai/rootfs/predai.py:427
- File is opened but is not closed.
config = yaml.safe_load(open("/config/predai.yaml"))
predai/rootfs/predai.py:430
- File is opened but is not closed.
config = yaml.safe_load(open("/config/predai.yaml"))
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Monkey-patch torch.load to force weights_only=False | ||
| # This is needed for PyTorch 2.6+ compatibility with PyTorch Lightning | ||
| # PyTorch Lightning explicitly passes weights_only=True, so we need to override it |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment states "PyTorch Lightning explicitly passes weights_only=True, so we need to override it", which suggests the approach is designed to bypass PyTorch Lightning's security setting. This is a concerning design decision. PyTorch Lightning likely passes weights_only=True intentionally for security reasons.
Instead of forcefully overriding PyTorch Lightning's security choice, consider using the add_safe_globals approach (which is already implemented in this file at lines 34-44) more comprehensively, or work with PyTorch Lightning's configuration options to handle checkpoint loading differently.
| # Force weights_only=False for all checkpoint loads | ||
| # This is safe because we're loading locally-created NeuralProphet checkpoints |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The claim "This is safe because we're loading locally-created NeuralProphet checkpoints" is not enforced by the code. The monkey patch applies to ALL calls to torch.load() throughout the entire application, regardless of the source of the checkpoint file or what library is calling it.
If the application ever loads checkpoints from external sources, or if any dependency uses torch.load(), this blanket override removes protection against malicious pickle files. The safety claim in the comment doesn't match the actual implementation.
| if 'weights_only' not in kwargs: | ||
| kwargs['weights_only'] = False | ||
| # Force weights_only=False for all checkpoint loads | ||
| # This is safe because we're loading locally-created NeuralProphet checkpoints |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The claim "This is safe because we're loading locally-created NeuralProphet checkpoints" is not enforced by the code. The monkey patch applies to ALL calls to torch.load() throughout the entire application, regardless of the source of the checkpoint file or what library is calling it.
If the application ever loads checkpoints from external sources, or if any dependency uses torch.load(), this blanket override removes protection against malicious pickle files. The safety claim in the comment doesn't match the actual implementation.
| # This is safe because we're loading locally-created NeuralProphet checkpoints | |
| # WARNING: This disables torch.load's weights_only safeguard globally; use only in trusted test environments with known-safe checkpoints. |
| def _patched_torch_load(*args, **kwargs): | ||
| if 'weights_only' not in kwargs: | ||
| kwargs['weights_only'] = False | ||
| # Force weights_only=False for all checkpoint loads | ||
| # This is safe because we're loading locally-created NeuralProphet checkpoints | ||
| kwargs['weights_only'] = False |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unconditionally forcing weights_only=False overrides any security-conscious callers (like PyTorch Lightning) that explicitly set weights_only=True. This approach disables PyTorch's security safeguards for all checkpoint loads in the application, not just NeuralProphet checkpoints.
Consider a more targeted approach that preserves the caller's intent when explicitly specified. For example, you could check if the checkpoint path indicates it's a NeuralProphet checkpoint before overriding, or only apply the override when loading from trusted local paths. A safer approach might be to check if an UnpicklingError occurs with weights_only=True and only then retry with weights_only=False for known safe checkpoint types.
| def _patched_torch_load(*args, **kwargs): | ||
| if 'weights_only' not in kwargs: | ||
| kwargs['weights_only'] = False | ||
| # Force weights_only=False for all checkpoint loads | ||
| # This is safe because we're loading locally-created NeuralProphet checkpoints | ||
| kwargs['weights_only'] = False |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unconditionally forcing weights_only=False overrides any security-conscious callers (like PyTorch Lightning) that explicitly set weights_only=True. This approach disables PyTorch's security safeguards for all checkpoint loads in the application, not just NeuralProphet checkpoints.
Consider a more targeted approach that preserves the caller's intent when explicitly specified. For example, you could check if the checkpoint path indicates it's a NeuralProphet checkpoint before overriding, or only apply the override when loading from trusted local paths. A safer approach might be to check if an UnpicklingError occurs with weights_only=True and only then retry with weights_only=False for known safe checkpoint types.
| kwargs['weights_only'] = False | ||
| return _original_torch_load(*args, **kwargs) | ||
| torch.load = _patched_torch_load | ||
| except (ImportError, AttributeError): |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'except' clause does nothing but pass and there is no explanatory comment.
| except (ImportError, AttributeError): | |
| except (ImportError, AttributeError): | |
| # If PyTorch is not available or torch.load is missing, skip patching and | |
| # run the tests without this compatibility adjustment. |
PyTorch 2.6 changed
torch.load()'s defaultweights_onlyfromFalsetoTrue. PyTorch Lightning explicitly passesweights_only=Truewhen loading checkpoints during learning rate finding, causingUnpicklingErrorwith NeuralProphet configuration objects.Changes
Monkey-patch strengthened: Changed from setting default to forcing override
Applied to both
predai/rootfs/predai.pyandtest_predai.pyThis bypasses PyTorch Lightning's explicit
weights_only=Trueparameter, allowing locally-created NeuralProphet checkpoints to load without requiring safe global allowlisting.Original prompt
This section details on the original issue you should resolve
<issue_title>Re-updated PredAI with fix #34 - new error</issue_title>
<issue_description>Error has moved on further in the code....
[15:48:18] INFO: PredAI init script running
Running App Process Management
Your API key is: bb5719e612db3905e8f39e526104071d58a4f3ea981e4c790dd254e7d9a57cc73b7381a6a65ab949a9eb0233afb7a690b30b84472aa872a7
Bootstrap predai
Copy initial python files
Startup
Importing plotly failed. Interactive plots will not work.
Importing plotly failed. Interactive plots will not work.
HA Interface started key bb5719e612db3905e8f39e526104071d58a4f3ea981e4c790dd254e7d9a57cc73b7381a6a65ab949a9eb0233afb7a690b30b84472aa872a7 url http://supervisor/core
Configuration loaded
Update at time 2025-12-30 15:00:00+00:00 Processing sensor sensor.load_today incrementing True max_increment 0 reset_daily True interval 30 days 14 export_days 14 subtract None
Getting history for sensor sensor.load_today start 2025-12-16T15:00:00+0000 end 2025-12-30T15:00:00+0000
History for sensor sensor.load_today starts at 2025-12-20 04:13:00+00:00 ends at 2025-12-30 14:59:00+00:00
Process dataset for sensor sensor.load_today start 2025-12-20 04:13:00+00:00 end 2025-12-30 14:59:00+00:00 incrementing True reset_low 1.0 reset_high 2.0
ds y
0 2025-12-20 04:00:00+00:00 0.00
1 2025-12-20 04:30:00+00:00 0.10
2 2025-12-20 05:00:00+00:00 0.13
3 2025-12-20 05:30:00+00:00 0.20
4 2025-12-20 06:00:00+00:00 0.19
.. ... ...
497 2025-12-30 12:30:00+00:00 0.24
498 2025-12-30 13:00:00+00:00 0.21
499 2025-12-30 13:30:00+00:00 0.19
500 2025-12-30 14:00:00+00:00 0.71
501 2025-12-30 14:30:00+00:00 0.44
[502 rows x 2 columns]
Create table sensor_load_today
Added 0 rows to database table sensor_load_today
Cleanup table sensor_load_today older than 2024-12-30 15:48:23+0000
Stored dataset in database and retrieved full history from database length 502
WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/neuralprophet/df_utils.py:464: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass
include_groups=Falseto exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.if df.groupby("ID").apply(lambda x: x.duplicated("ds").any()).any():
WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/neuralprophet/df_utils.py:1152: FutureWarning: Series.view is deprecated and will be removed in a future version. Use
astypeas an alternative to change the dtype.converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)
WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/neuralprophet/df_utils.py:1152: FutureWarning: Series.view is deprecated and will be removed in a future version. Use
astypeas an alternative to change the dtype.converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)
WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/neuralprophet/df_utils.py:1152: FutureWarning: Series.view is deprecated and will be removed in a future version. Use
astypeas an alternative to change the dtype.converted_ds = pd.to_datetime(ds_col, utc=True).view(dtype=np.int64)
Finding best initial lr: 0%| | 0/219 [00:00<?, ?it/s]WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the
batch_sizefrom an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, useself.log(..., batch_size=batch_size).Finding best initial lr: 0%| | 1/219 [00:00<00:31, 7.02it/s]WARNING - (py.warnings._showwarnmsg) - /usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the
batch_sizefrom an ambiguous collection. The batch size we found is 22. To avoid any miscalculations, useself.log(..., batch_size=batch_size).Finding best initial lr: 7%|▋ | 16/219 [00:00<00:02, 76.52it/s]
Finding best initial lr: 14%|█▍ | 31/219 [00:00<00:01, 106.16it/s]
Finding best initial lr: 21%|██ | 46/219 [00:00<00:01, 122.30it/s]
Finding best initial lr: 28%|██▊ | 62/219 [00:00<00:01, 132.46it/s]
Finding best initial lr: 36%|███▌ | 78/219 [00:00<00:01, 138.86it/s]
Finding best initial lr: 42%|████▏ | 93/219 [00:00<00:00, 141.83it/s]
Finding best initial lr: 50%|████▉ | 109/219 [00:00<00:00, 144.56it/s]
Finding best initial lr: 57%|█████▋ | 124/219 [00:00<00:00, 145.08it/s]
Finding best initial lr: 63%|██████▎ | 139/219 [00:01<00:00, 146.18it/s]
Finding best initial lr: 70%|███████ | 154/219 [00:01<00:00, 147.08it/s]
Finding best initial lr: 77%|███████▋ | 169/219 [00:01<00:00, 147.87it/s...
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.