fixes for json serialize bug
This commit is contained in:
19
config.py
19
config.py
@@ -25,14 +25,14 @@ class Config:
|
||||
self.set_pg_config()
|
||||
self.output_formatting()
|
||||
with open("./config.json", "w") as outfile:
|
||||
json.dump(self.__encode(), outfile)
|
||||
json.dump(self.__dict__, outfile)
|
||||
return self
|
||||
else:
|
||||
self.receive_data_path()
|
||||
self.set_pg_config()
|
||||
self.output_formatting()
|
||||
with open("./config.json", "w") as outfile:
|
||||
json.dump(self.__encode(), outfile)
|
||||
json.dump(self.__dict__, outfile)
|
||||
return self
|
||||
|
||||
# if a config file already exists, offer to use it instead
|
||||
@@ -42,9 +42,9 @@ class Config:
|
||||
|
||||
if response_for_prev_config == "y":
|
||||
with open("./config.json", "r") as infile:
|
||||
config_data = json.load(infile)
|
||||
config_data: Config = json.load(infile)
|
||||
|
||||
return Config(config_data)
|
||||
return config_data
|
||||
elif response_for_prev_config == "n":
|
||||
return None
|
||||
else:
|
||||
@@ -62,12 +62,6 @@ class Config:
|
||||
print("Got it!")
|
||||
self.data_path = data_path
|
||||
|
||||
def get_data_path(self):
|
||||
return self.data_path
|
||||
|
||||
def set_data_path(self, value):
|
||||
self.data_path = value
|
||||
|
||||
def set_pg_config(self):
|
||||
"""Determine if data should be associated with a PostgreSQL instance, and, if so, record the required connection info"""
|
||||
elect_for_pg = input("Connect this program to a PostgreSQL instance? y/n ").lower()
|
||||
@@ -117,7 +111,4 @@ class Config:
|
||||
self.sort_by_match_strength = False
|
||||
else:
|
||||
print("Invalid response.")
|
||||
self.output_formatting()
|
||||
|
||||
def __encode(self):
|
||||
return json.dumps(self, default=lambda x: x.__dict__)
|
||||
self.output_formatting()
|
||||
@@ -5,13 +5,13 @@ def format_result(app_config: Config, json_path):
|
||||
# dictionary to hold and later display our results
|
||||
insertions_by_label = {}
|
||||
|
||||
data_path = app_config.data_path
|
||||
data_path = app_config['data_path']
|
||||
|
||||
# if pg_config is not None, run the postgres prediction[0] of this code
|
||||
pg_config = app_config.pg_config
|
||||
pg_config = app_config['pg_config']
|
||||
|
||||
# if this is True, run the prediction[0] "for line in contents:" below
|
||||
sort_by_match_strength = app_config.sort_by_match_strength
|
||||
sort_by_match_strength = app_config['sort_by_match_strength']
|
||||
|
||||
weak_results = 0
|
||||
total_count = 0
|
||||
@@ -52,6 +52,9 @@ def format_result(app_config: Config, json_path):
|
||||
if not guess_label in insertions_by_label:
|
||||
insertions_by_label[guess_label] = 0
|
||||
|
||||
print(img_path)
|
||||
print("./predictions/" + guess_label)
|
||||
|
||||
# copy file to appropriate location, depending on if sorting
|
||||
if sort_by_match_strength:
|
||||
if (not os.path.exists("./predictions/" + match_strength + guess_label)):
|
||||
|
||||
18
main.py
18
main.py
@@ -14,11 +14,11 @@ print("\n\n")
|
||||
############################## CONFIG
|
||||
|
||||
print("Running app config...")
|
||||
appconfig: Config = Config()
|
||||
appconfig = Config()
|
||||
config_file = appconfig.run()
|
||||
|
||||
if (config_file.get_data_path()[-1] != "/"):
|
||||
config_file.set_data_path(config_file.get_data_path() + "/")
|
||||
if (config_file['data_path'][-1] != "/"):
|
||||
config_file['data_path'] = config_file['data_path'] + "/"
|
||||
|
||||
# create the target directory if it doesn't exist
|
||||
if (not os.path.exists("./predictions")):
|
||||
@@ -29,7 +29,7 @@ if (not os.path.exists("./predictions")):
|
||||
############################## SETUP
|
||||
############################## SETUP
|
||||
|
||||
files = os.listdir(config_file.data_path)
|
||||
files = os.listdir(config_file['data_path'])
|
||||
|
||||
# generate current time for use in identifying outfiles
|
||||
cur_time = str(int(time()))
|
||||
@@ -41,22 +41,22 @@ all_results = []
|
||||
############################## ANALYSIS
|
||||
############################## ANALYSIS
|
||||
|
||||
print("Attempting TF imports...\n\n")
|
||||
print("\nAttempting imports...\n")
|
||||
|
||||
from predict import predict
|
||||
from formatresult import format_result
|
||||
from keras.applications.vgg16 import VGG16
|
||||
|
||||
print("Success!")
|
||||
print("\nSuccess!\n")
|
||||
|
||||
# declare model to be used for each prediction
|
||||
model = VGG16(weights='imagenet')
|
||||
|
||||
print("Running image analysis. This may take some time...\n\n")
|
||||
print("\nRunning image analysis. This may take some time...\n\n")
|
||||
|
||||
# for each file in directory, append its prediction result to main list
|
||||
for file in files:
|
||||
result = predict(model, config_file.data_path + file)
|
||||
result = predict(model, config_file['data_path'] + file)
|
||||
if result is not None:
|
||||
all_results.append({ "path": file, "prediction": result })
|
||||
|
||||
@@ -74,6 +74,6 @@ print("Analysis complete! Beginning sort process...\n\n")
|
||||
############################## SORTING
|
||||
############################## SORTING
|
||||
|
||||
format_result(appconfig, json_path)
|
||||
format_result(config_file, json_path)
|
||||
|
||||
print("File sort successful! Process complete.")
|
||||
|
||||
Reference in New Issue
Block a user