This is the official code base for the paper "GPS Masked Trajectory Models for Human Life Travel Pattern Learning".
Set up the environment:

```
conda env create -f environment.yml
```

Example commands can be found in `train_examples.sh`. To get started, run the last command in that file.
The main code is located in the `mtm` folder.
- The config file for mtm is located at `research/mtm/config.yaml`.
- Some key parameters:
  - `traj_length`: the length of trajectory sub-segments.
  - `mask_ratios`: a list of mask ratios that is randomly sampled from.
  - `mask_patterns`: a list of masking patterns that are randomly sampled from. See `MaskType` under `research/mtm/masks.py` for supported options.
  - `mode_weights`: (only applies to `AUTO_MASK`) a list of weights used to sample which mode is to be the "autoregressive" one. For example, if the mode order is `states, returns, actions` and `mode_weights = [0.2, 0.1, 0.7]`, then with probability 0.7 the action token and all future tokens will be masked out.
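To make the `mode_weights` behavior concrete, here is a minimal sketch of how the autoregressive mode could be sampled. The function name and signature are illustrative assumptions, not the repo's actual API.

```python
import numpy as np

def sample_autoregressive_mode(modes, mode_weights, rng=None):
    """Hypothetical helper: pick which mode becomes the "autoregressive" one.

    With probability mode_weights[i], modes[i] and all tokens after it
    would be masked out for that training example.
    """
    rng = rng or np.random.default_rng()
    assert len(modes) == len(mode_weights)
    idx = rng.choice(len(modes), p=mode_weights)
    return modes[idx]

# With weights [0.2, 0.1, 0.7], "actions" is chosen about 70% of the time.
modes = ["states", "returns", "actions"]
mode = sample_autoregressive_mode(modes, [0.2, 0.1, 0.7])
assert mode in modes
```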
The system supports multiple masking strategies for different evaluation scenarios:
- RANDOM: General random masking
- GOAL: Goal-reaching tasks (⚠️ hardcoded for ~55-step trajectories)
- ID: Inverse dynamics, predict actions from states (⚠️ hardcoded for first 24 steps)
- FD: Forward dynamics, predict future states
- BC/RCBC: Behavioral cloning variants
- FULL_RANDOM: Per-feature random masking
- AUTO_MASK: Autoregressive masking
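As a rough illustration of the simplest strategy above (RANDOM), the sketch below builds a boolean visibility mask over one trajectory. Shapes and conventions here are assumptions; the real patterns live in `research/mtm/masks.py`.

```python
import numpy as np

def random_mask(traj_length, mask_ratio, rng=None):
    """Illustrative RANDOM-style mask: True marks visible timesteps.

    mask_ratio is the fraction of timesteps hidden from the model.
    """
    rng = rng or np.random.default_rng()
    num_masked = int(traj_length * mask_ratio)
    mask = np.ones(traj_length, dtype=bool)
    # Hide a uniformly random subset of timesteps.
    hidden = rng.choice(traj_length, size=num_masked, replace=False)
    mask[hidden] = False
    return mask

mask = random_mask(16, 0.5)
assert mask.sum() == 8  # half the timesteps stay visible
```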
See `MASKING_PATTERNS.md` for a detailed reference and `research/mtm/TEST_README.md` for troubleshooting.
Pre-commit hooks are great: they automatically run some checking and formatting on each commit. To use the pre-commit hooks, run the following:

```
pip install pre-commit
pre-commit install
```
If you want to make a commit without running the pre-commit hooks, you can commit with the `-n` flag (i.e., `git commit -n ...`).
- All dataset code is located in the `/Trajectory_dataset/anomaly_traj_data` folder. All a dataset has to do is return a PyTorch dataset that outputs a dict (a named set of trajectories).
- A dataset should follow the `DatasetProtocol` specified in `research/mtm/datasets/base.py`.
- Each dataset should also have a corresponding `get_datasets` function where all the dataset-specific construction logic happens. This function can take anything as input (as specified in the corresponding `yaml` config) and outputs the train and val torch `Dataset`s.
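A minimal sketch of that contract, assuming trajectories of states and actions (the field names, dimensions, and `ToyTrajectoryDataset` class are placeholders, not part of the repo):

```python
import numpy as np
from torch.utils.data import Dataset

class ToyTrajectoryDataset(Dataset):
    """Toy dataset whose items are dicts, i.e. named sets of trajectories."""

    def __init__(self, num_traj=100, traj_length=32, state_dim=4, action_dim=2):
        self.states = np.random.randn(num_traj, traj_length, state_dim).astype(np.float32)
        self.actions = np.random.randn(num_traj, traj_length, action_dim).astype(np.float32)

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        # Return a dict of trajectory arrays, as the protocol expects.
        return {"states": self.states[idx], "actions": self.actions[idx]}

def get_datasets(num_traj=100, traj_length=32):
    """Dataset-specific construction logic; returns (train, val)."""
    train = ToyTrajectoryDataset(num_traj, traj_length)
    val = ToyTrajectoryDataset(num_traj // 10, traj_length)
    return train, val
```

The arguments to `get_datasets` would come from the dataset's `yaml` config.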
- All tokenizer code is found in the `research/mtm/tokenizers` folder. Each tokenizer should inherit from the `Tokenizer` abstract class, found in `research/mtm/tokenizers/base.py`. Tokenizers must define a `create` method, which can handle dataset-specific construction logic.
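A hedged sketch of what that interface could look like. Only the abstract class name and the `create` method come from the text above; `BinTokenizer`, `tokenize`, and the uniform-binning scheme are illustrative assumptions.

```python
from abc import ABC, abstractmethod
import numpy as np

class Tokenizer(ABC):
    @classmethod
    @abstractmethod
    def create(cls, key, train_dataset):
        """Handle dataset-specific construction logic."""

class BinTokenizer(Tokenizer):
    """Hypothetical tokenizer: discretizes values into uniform bins."""

    def __init__(self, low, high, num_bins=64):
        # Interior bin edges between low and high.
        self.edges = np.linspace(low, high, num_bins + 1)[1:-1]

    @classmethod
    def create(cls, key, train_dataset):
        # Fit the bin range from the training data for this key.
        data = np.stack([sample[key] for sample in train_dataset])
        return cls(data.min(), data.max())

    def tokenize(self, x):
        return np.digitize(x, self.edges)

tok = BinTokenizer(low=-1.0, high=1.0, num_bins=4)
tokens = tok.tokenize(np.array([-0.9, 0.1, 0.9]))
```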
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. This is not an official Meta product.
This project builds on top of or utilizes the following third party dependencies.
- FangchenLiu/MaskDP_public: Masked Decision Prediction, which this work builds upon
- ikostrikov/jaxrl: A fast JAX library for RL. We used its environment wrapping and data loading code for all D4RL experiments.
- denisyarats/exorl: ExORL provides datasets collected with unsupervised RL methods, which we use in our representation learning experiments.
- vikashplus/robohive: Provides the Adroit environment
- aravindr93/mjrl: Code for training the policy for generating data on Adroit
- brentyi/tyro: Argument parsing and configuration