55# the WPILib BSD license file in the root directory of this project.
66#
77
8- from typing import Any , Callable , Union
9-
10- from .command import Command
11- from .subsystem import Subsystem
8+ from typing import Any , Callable , Generic , Union
129
1310from wpimath .controller import ProfiledPIDController , ProfiledPIDControllerRadians
1411from wpimath .trajectory import TrapezoidProfile , TrapezoidProfileRadians
1512
13+ from .command import Command
14+ from .subsystem import Subsystem
15+ from .typing import TProfiledPIDController , UseOutputFunction
16+
1617
17- class ProfiledPIDCommand (Command ):
18+ class ProfiledPIDCommand (Command , Generic [ TProfiledPIDController ] ):
1819 """A command that controls an output with a :class:`.ProfiledPIDController`. Runs forever by default -
1920 to add exit conditions and/or other behavior, subclass this class. The controller calculation and
2021 output are performed synchronously in the command's execute() method.
@@ -24,10 +25,10 @@ class ProfiledPIDCommand(Command):
2425
2526 def __init__ (
2627 self ,
27- controller ,
28+ controller : TProfiledPIDController ,
2829 measurementSource : Callable [[], float ],
2930 goalSource : Union [float , Callable [[], float ]],
30- useOutput : Callable [[ float , Any ], Any ] ,
31+ useOutput : UseOutputFunction ,
3132 * requirements : Subsystem ,
3233 ):
3334 """Creates a new ProfiledPIDCommand, which controls the given output with a ProfiledPIDController. Goal
@@ -40,14 +41,15 @@ def __init__(
4041 :param requirements: the subsystems required by this command
4142 """
4243
44+ super ().__init__ ()
4345 if isinstance (controller , ProfiledPIDController ):
4446 self ._stateCls = TrapezoidProfile .State
4547 elif isinstance (controller , ProfiledPIDControllerRadians ):
4648 self ._stateCls = TrapezoidProfileRadians .State
4749 else :
4850 raise ValueError (f"unknown controller type { controller !r} " )
4951
50- self ._controller = controller
52+ self ._controller : TProfiledPIDController = controller
5153 self ._useOutput = useOutput
5254 self ._measurement = measurementSource
5355 if isinstance (goalSource , (float , int )):
0 commit comments