Commit b84945f

Update GPT Detector (#15)
* edit gpt pipeline
* tab -> spaces
* fix lint
1 parent c775263

File tree: 5 files changed (+77, -23 lines)
sigllm/pipelines/detector/gpt_detector.json

Lines changed: 42 additions & 12 deletions

@@ -1,13 +1,15 @@
 {
     "primitives": [
         "mlstars.custom.timeseries_preprocessing.time_segments_aggregate",
+        "sklearn.impute.SimpleImputer",
         "sigllm.primitives.transformation.Float2Scalar",
-        "mlstars.custom.timeseries_preprocessing.rolling_window_sequences",
+        "sigllm.primitives.forecasting.custom.rolling_window_sequences",
         "sigllm.primitives.transformation.format_as_string",
         "sigllm.primitives.forecasting.gpt.GPT",
         "sigllm.primitives.transformation.format_as_integer",
         "sigllm.primitives.transformation.Scalar2Float",
         "sigllm.primitives.postprocessing.aggregate_rolling_window",
+        "numpy.reshape",
         "orion.primitives.timeseries_errors.regression_errors",
         "orion.primitives.timeseries_anomalies.find_anomalies"
     ],
@@ -21,45 +23,73 @@
             "decimal": 2,
             "rescale": true
         },
-        "mlstars.custom.timeseries_preprocessing.rolling_window_sequences#1": {
+        "sigllm.primitives.forecasting.custom.rolling_window_sequences#1": {
             "target_column": 0,
             "window_size": 140,
             "target_size": 1
         },
         "sigllm.primitives.transformation.format_as_string#1": {
             "space": true
         },
-        "sigllm.primitives.forecasting.gpt.GPT": {
+        "sigllm.primitives.forecasting.gpt.GPT#1": {
             "name": "gpt-3.5-turbo",
             "steps": 5
         },
         "sigllm.primitives.transformation.format_as_integer#1": {
-            "trunc": 1
+            "trunc": 1,
+            "errors": "coerce"
         },
         "sigllm.primitives.postprocessing.aggregate_rolling_window#1": {
-            "agg": "median"
+            "agg": "median",
+            "remove_outliers": true
         },
         "orion.primitives.timeseries_anomalies.find_anomalies#1": {
-            "window_size_portion": 0.33,
+            "window_size_portion": 0.3,
             "window_step_size_portion": 0.1,
             "fixed_threshold": true
         }
     },
     "input_names": {
+        "sigllm.primitives.transformation.Float2Scalar#1": {
+            "X": "y"
+        },
+        "sigllm.primitives.transformation.format_as_integer#1": {
+            "X": "y_hat"
+        },
+        "sigllm.primitives.transformation.Scalar2Float#1": {
+            "X": "y_hat"
+        },
         "sigllm.primitives.postprocessing.aggregate_rolling_window#1": {
             "y": "y_hat"
+        },
+        "numpy.reshape#1": {
+            "X": "y_hat"
+        },
+        "orion.primitives.timeseries_anomalies.find_anomalies#1": {
+            "index": "target_index"
         }
     },
     "output_names": {
-        "mlstars.custom.timeseries_preprocessing.rolling_window_sequences#1": {
-            "index": "X_index",
-            "target_index": "y_index"
+        "sklearn.impute.SimpleImputer#1": {
+            "X": "y"
        },
-        "sigllm.primitives.forecasting.huggingface.HF#1": {
-            "y": "yhat"
+        "sigllm.primitives.forecasting.gpt.GPT#1": {
+            "y": "y_hat"
+        },
+        "sigllm.primitives.transformation.format_as_integer#1": {
+            "X": "y_hat"
+        },
+        "sigllm.primitives.transformation.Scalar2Float#1": {
+            "X": "y_hat"
         },
         "sigllm.primitives.postprocessing.aggregate_rolling_window#1": {
-            "y": "yhat"
+            "y": "y_hat"
+        },
+        "numpy.reshape#1": {
+            "X": "y_hat"
+        },
+        "orion.primitives.timeseries_anomalies.find_anomalies#1": {
+            "y": "anomalies"
         }
     }
 }
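The JSON above is a complete mlblocks pipeline specification. As a rough usage sketch, not something this commit defines: assuming the file is saved locally as gpt_detector.json, OPENAI_API_KEY is exported, and the signal is a pandas DataFrame with timestamp and value columns, it could be driven the way Orion-style pipelines usually are:

```python
# Hedged sketch: loading and running the detector with mlblocks.
# The file name, the toy signal, and the fit/predict driving pattern
# are assumptions, not something this commit defines.
import pandas as pd
from mlblocks import MLPipeline

# Toy signal: 500 evenly spaced points with a repeating pattern.
data = pd.DataFrame({
    'timestamp': range(0, 50000, 100),
    'value': [float(i % 7) for i in range(500)],
})

pipeline = MLPipeline.load('gpt_detector.json')

# Orion-style detection pipelines are typically fit and run on the same
# signal; the final primitive (find_anomalies) emits anomalous intervals,
# exposed here under the "anomalies" output name.
pipeline.fit(data)
anomalies = pipeline.predict(data)
print(anomalies)
```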

sigllm/pipelines/detector/mistral_detector.json

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
     "sklearn.impute.SimpleImputer",
     "sigllm.primitives.transformation.Float2Scalar",
     "sigllm.primitives.forecasting.custom.rolling_window_sequences",
-    "sigllm.primitives.transformation.format_as_string",
+    "sigllm.primitives.transformation.format_as_string",
     "sigllm.primitives.forecasting.huggingface.HF",
     "sigllm.primitives.transformation.format_as_integer",
     "sigllm.primitives.transformation.Scalar2Float",

(The -/+ pair is a whitespace-only change: a tab indent replaced with spaces, per the commit message.)

sigllm/primitives/forecasting/gpt.py

Lines changed: 12 additions & 9 deletions
@@ -5,6 +5,7 @@

 import openai
 import tiktoken
+from openai import OpenAI
 from tqdm import tqdm

 PROMPT_PATH = os.path.join(
@@ -74,6 +75,8 @@ def __init__(self, name='gpt-3.5-turbo', chat=True, sep=',', steps=1, temp=1,
         valid_tokens.extend(self.tokenizer.encode(self.sep))
         self.logit_bias = {token: BIAS for token in valid_tokens}

+        self.client = OpenAI()
+
     def forecast(self, X, **kwargs):
         """Use GPT to forecast a signal.

@@ -86,21 +89,21 @@ def forecast(self, X, **kwargs):
             * List of forecasted signal values.
             * Optionally, a list of the output tokens' log probabilities.
         """
-        input_length = len(self.tokenizer.encode(X[0]))
-        average_length = (input_length + 1) // len(X[0].split(','))
-        max_tokens = average_length * self.steps
-
         all_responses, all_probs = [], []
         for text in tqdm(X):
+            input_length = len(self.tokenizer.encode(text))
+            average_length = (input_length + 1) // len(text.split(','))
+            max_tokens = average_length * self.steps
+
             if self.chat:
-                message = ' '.join(PROMPTS['user_message'], text, self.sep)
-                response = openai.ChatCompletion.create(
+                message = ' '.join([PROMPTS['user_message'], text, self.sep])
+                response = self.client.chat.completions.create(
                     model=self.name,
                     messages=[
                         {"role": "system", "content": PROMPTS['system_message']},
                         {"role": "user", "content": message}
                     ],
-                    max_tokens=max_tokens,
+                    max_completion_tokens=max_tokens,
                     temperature=self.temp,
                     top_p=self.top_p,
                     logprobs=self.logprobs,
@@ -111,7 +114,7 @@ def forecast(self, X, **kwargs):
                 responses = [choice.message.content for choice in response.choices]

             else:
-                message = ' '.join(text, self.sep)
+                message = ' '.join([text, self.sep])
                 response = openai.Completion.create(
                     model=self.name,
                     prompt=message,
@@ -135,4 +138,4 @@ def forecast(self, X, **kwargs):
         if self.logprobs:
             return all_responses, all_probs

-        return responses
+        return all_responses
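The substance of this diff is the migration from the legacy openai.ChatCompletion.create call to the openai>=1.0 client interface, together with moving the max_tokens computation inside the loop so the budget is derived from each window rather than only from X[0]. A stripped-down sketch of the new call pattern; the prompt strings and the steps value are illustrative placeholders, not the package's actual prompts:

```python
# Hedged sketch of the openai>=1.0 chat pattern the diff migrates to.
# Prompt text and steps=5 are illustrative placeholders.
import tiktoken
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment
enc = tiktoken.encoding_for_model('gpt-3.5-turbo')

text = '12,15,14,16,13'  # one formatted window
steps = 5                # forecast horizon, as in the pipeline config

# Same per-window budget the diff computes: average tokens per value
# times the number of values to forecast.
input_length = len(enc.encode(text))
average_length = (input_length + 1) // len(text.split(','))
max_tokens = average_length * steps

response = client.chat.completions.create(
    model='gpt-3.5-turbo',
    messages=[
        {'role': 'system', 'content': 'You are a time series forecaster.'},
        {'role': 'user', 'content': text + ' ,'},
    ],
    max_completion_tokens=max_tokens,
)
print(response.choices[0].message.content)
```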

sigllm/primitives/jsons/sigllm.primitives.postprocessing.aggregate_rolling_window.json

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@
         "agg": {
             "type": "str",
             "default": "median"
+        },
+        "remove_outliers": {
+            "type": "bool",
+            "default": false
         }
     }
 }

sigllm/primitives/postprocessing.py

Lines changed: 18 additions & 1 deletion
@@ -2,7 +2,19 @@
 import numpy as np


-def aggregate_rolling_window(y, step_size=1, agg="median"):
+def outliers(predictions):
+    Q1, Q3 = np.percentile(predictions, [25, 75])
+
+    IQR = Q3 - Q1
+    lower_bound = Q1 - 1.5 * IQR
+    upper_bound = Q3 + 1.5 * IQR
+
+    predictions[(predictions < lower_bound) | (predictions > upper_bound)] = np.nan
+
+    return predictions
+
+
+def aggregate_rolling_window(y, step_size=1, agg="median", remove_outliers=False):
     """Aggregate a rolling window sequence.

     Convert a rolling window sequence into a flattened time series.
@@ -15,6 +27,8 @@ def aggregate_rolling_window(y, step_size=1, agg="median"):
             Stride size used when creating the rolling windows.
         agg (string):
             String denoting the aggregation method to use. Default is "median".
+        remove_outliers (bool):
+            Whether to remove outliers from the predictions. Default is False.

     Return:
         ndarray:
@@ -23,6 +37,9 @@ def aggregate_rolling_window(y, step_size=1, agg="median"):
     num_windows, num_samples, pred_length = y.shape
     num_errors = pred_length + step_size * (num_windows - 1)

+    if remove_outliers:
+        y = outliers(y)
+
     method = getattr(np, agg)
     signal = []


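The new outliers helper applies the standard 1.5 × IQR rule: values below Q1 - 1.5*IQR or above Q3 + 1.5*IQR become NaN, so a single wild GPT sample cannot drag a window's aggregate. A self-contained sketch of that rule; the NaN-aware np.nanmedian at the end is an assumption, since aggregate_rolling_window itself resolves the aggregation via getattr(np, agg):

```python
# Hedged, standalone illustration of the 1.5*IQR masking rule added in
# this commit, followed by a NaN-aware median (an assumption: the
# pipeline resolves its aggregation via getattr(np, agg)).
import numpy as np

def mask_outliers(predictions):
    # Quartiles over all predictions, exactly as the new helper does.
    q1, q3 = np.percentile(predictions, [25, 75])
    iqr = q3 - q1
    lower = q1 - 1.5 * iqr
    upper = q3 + 1.5 * iqr
    predictions = predictions.astype(float)  # copy; float dtype allows NaN
    predictions[(predictions < lower) | (predictions > upper)] = np.nan
    return predictions

preds = np.array([10.0, 11.0, 9.5, 10.2, 250.0])  # one wild sample
masked = mask_outliers(preds)
print(np.nanmedian(masked))  # 10.1 -- the 250.0 no longer skews the median
```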